feat(ai_client): Refactor tool dispatch to use asyncio.gather for parallel tool execution.
This commit is contained in:
311
src/ai_client.py
311
src/ai_client.py
@@ -13,6 +13,7 @@ during chat creation to avoid massive history bloat.
|
||||
"""
|
||||
# ai_client.py
|
||||
import tomllib
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
@@ -473,6 +474,84 @@ def _gemini_tool_declaration() -> Optional[types.Tool]:
|
||||
))
|
||||
return types.Tool(function_declarations=declarations) if declarations else None
|
||||
|
||||
async def _execute_tool_calls_concurrently(
|
||||
calls: list[Any],
|
||||
base_dir: str,
|
||||
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
|
||||
qa_callback: Optional[Callable[[str], str]],
|
||||
r_idx: int,
|
||||
provider: str
|
||||
) -> list[tuple[str, str, str, str]]: # tool_name, call_id, output, original_name
|
||||
"""
|
||||
Executes multiple tool calls concurrently using asyncio.gather.
|
||||
Returns a list of (tool_name, call_id, output, original_name).
|
||||
"""
|
||||
tasks = []
|
||||
for fc in calls:
|
||||
if provider == "gemini":
|
||||
name, args, call_id = fc.name, dict(fc.args), fc.name # Gemini 1.0.0 doesn't have call IDs in types.Part
|
||||
elif provider == "gemini_cli":
|
||||
name, args, call_id = cast(str, fc.get("name")), cast(dict[str, Any], fc.get("args", {})), cast(str, fc.get("id"))
|
||||
elif provider == "anthropic":
|
||||
name, args, call_id = cast(str, getattr(fc, "name")), cast(dict[str, Any], getattr(fc, "input")), cast(str, getattr(fc, "id"))
|
||||
elif provider == "deepseek":
|
||||
tool_info = fc.get("function", {})
|
||||
name = cast(str, tool_info.get("name"))
|
||||
tool_args_str = cast(str, tool_info.get("arguments", "{}"))
|
||||
call_id = cast(str, fc.get("id"))
|
||||
try: args = json.loads(tool_args_str)
|
||||
except: args = {}
|
||||
else:
|
||||
continue
|
||||
|
||||
tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx))
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
return results
|
||||
|
||||
async def _execute_single_tool_call_async(
|
||||
name: str,
|
||||
args: dict[str, Any],
|
||||
call_id: str,
|
||||
base_dir: str,
|
||||
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
|
||||
qa_callback: Optional[Callable[[str], str]],
|
||||
r_idx: int
|
||||
) -> tuple[str, str, str, str]:
|
||||
out = ""
|
||||
tool_executed = False
|
||||
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
|
||||
|
||||
# Check for run_powershell
|
||||
if name == TOOL_NAME and pre_tool_callback:
|
||||
scr = cast(str, args.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
|
||||
# pre_tool_callback is synchronous and might block for HITL
|
||||
res = await asyncio.to_thread(pre_tool_callback, scr, base_dir, qa_callback)
|
||||
if res is None:
|
||||
out = "USER REJECTED: tool execution cancelled"
|
||||
else:
|
||||
out = res
|
||||
tool_executed = True
|
||||
|
||||
if not tool_executed:
|
||||
if name and name in mcp_client.TOOL_NAMES:
|
||||
_append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
|
||||
if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
|
||||
desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
|
||||
_res = await asyncio.to_thread(pre_tool_callback, desc, base_dir, qa_callback)
|
||||
out = "USER REJECTED: tool execution cancelled" if _res is None else await mcp_client.async_dispatch(name, args)
|
||||
else:
|
||||
out = await mcp_client.async_dispatch(name, args)
|
||||
elif name == TOOL_NAME:
|
||||
scr = cast(str, args.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
|
||||
out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback)
|
||||
else:
|
||||
out = f"ERROR: unknown tool '{name}'"
|
||||
|
||||
return (name, call_id, out, name)
|
||||
|
||||
def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str:
|
||||
if confirm_and_run_callback is None:
|
||||
return "ERROR: no confirmation handler registered"
|
||||
@@ -762,48 +841,33 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
||||
if not calls or r_idx > MAX_TOOL_ROUNDS: break
|
||||
f_resps: list[types.Part] = []
|
||||
log: list[dict[str, Any]] = []
|
||||
for i, fc in enumerate(calls):
|
||||
name, args = fc.name, dict(fc.args)
|
||||
out = ""
|
||||
tool_executed = False
|
||||
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
|
||||
if name == TOOL_NAME and pre_tool_callback:
|
||||
scr = cast(str, args.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr})
|
||||
res = pre_tool_callback(scr, base_dir, qa_callback)
|
||||
if res is None:
|
||||
out = "USER REJECTED: tool execution cancelled"
|
||||
else:
|
||||
out = res
|
||||
tool_executed = True
|
||||
|
||||
if not tool_executed:
|
||||
if name and name in mcp_client.TOOL_NAMES:
|
||||
_append_comms("OUT", "tool_call", {"name": name, "args": args})
|
||||
if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
|
||||
desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
|
||||
_res = pre_tool_callback(desc, base_dir, qa_callback)
|
||||
out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(name, args)
|
||||
else:
|
||||
out = mcp_client.dispatch(cast(str, name), args)
|
||||
elif name == TOOL_NAME:
|
||||
scr = cast(str, args.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr})
|
||||
out = _run_script(scr, base_dir, qa_callback)
|
||||
else: out = f"ERROR: unknown tool '{name}'"
|
||||
# Execute tools concurrently
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
results = asyncio.run_coroutine_threadsafe(
|
||||
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"),
|
||||
loop
|
||||
).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"))
|
||||
|
||||
if i == len(calls) - 1:
|
||||
for i, (name, call_id, out, _) in enumerate(results):
|
||||
# Check if this is the last tool to trigger file refresh
|
||||
if i == len(results) - 1:
|
||||
if file_items:
|
||||
file_items, changed = _reread_file_items(file_items)
|
||||
ctx = _build_file_diff_text(changed)
|
||||
if ctx:
|
||||
out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
|
||||
if r_idx == MAX_TOOL_ROUNDS: out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
|
||||
|
||||
out = _truncate_tool_output(out)
|
||||
_cumulative_tool_bytes += len(out)
|
||||
f_resps.append(types.Part(function_response=types.FunctionResponse(name=cast(str, name), response={"output": out})))
|
||||
log.append({"tool_use_id": name, "content": out})
|
||||
events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
|
||||
|
||||
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
|
||||
f_resps.append(types.Part(text=
|
||||
f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
|
||||
@@ -877,42 +941,21 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
})
|
||||
if not calls or r_idx > MAX_TOOL_ROUNDS:
|
||||
break
|
||||
|
||||
# Execute tools concurrently
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
results = asyncio.run_coroutine_threadsafe(
|
||||
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"),
|
||||
loop
|
||||
).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"))
|
||||
|
||||
tool_results_for_cli: list[dict[str, Any]] = []
|
||||
for i, fc in enumerate(calls):
|
||||
name = cast(str, fc.get("name"))
|
||||
args = cast(dict[str, Any], fc.get("args", {}))
|
||||
call_id = cast(str, fc.get("id"))
|
||||
out = ""
|
||||
tool_executed = False
|
||||
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
|
||||
if name == TOOL_NAME and pre_tool_callback:
|
||||
scr = cast(str, args.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
|
||||
res = pre_tool_callback(scr, base_dir, qa_callback)
|
||||
if res is None:
|
||||
out = "USER REJECTED: tool execution cancelled"
|
||||
else:
|
||||
out = res
|
||||
tool_executed = True
|
||||
|
||||
if not tool_executed:
|
||||
if name and name in mcp_client.TOOL_NAMES:
|
||||
|
||||
_append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
|
||||
if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
|
||||
desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
|
||||
_res = pre_tool_callback(desc, base_dir, qa_callback)
|
||||
out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(name, args)
|
||||
else:
|
||||
out = mcp_client.dispatch(name, args)
|
||||
elif name == TOOL_NAME:
|
||||
scr = cast(str, args.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
|
||||
out = _run_script(scr, base_dir, qa_callback)
|
||||
else:
|
||||
out = f"ERROR: unknown tool '{name}'"
|
||||
|
||||
if i == len(calls) - 1:
|
||||
for i, (name, call_id, out, _) in enumerate(results):
|
||||
# Check if this is the last tool to trigger file refresh
|
||||
if i == len(results) - 1:
|
||||
if file_items:
|
||||
file_items, changed = _reread_file_items(file_items)
|
||||
ctx = _build_file_diff_text(changed)
|
||||
@@ -920,6 +963,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
|
||||
if r_idx == MAX_TOOL_ROUNDS:
|
||||
out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
|
||||
|
||||
out = _truncate_tool_output(out)
|
||||
_cumulative_tool_bytes += len(out)
|
||||
tool_results_for_cli.append({
|
||||
@@ -930,6 +974,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
})
|
||||
_append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
|
||||
events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
|
||||
|
||||
payload = tool_results_for_cli
|
||||
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
|
||||
_append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
|
||||
@@ -1217,61 +1262,29 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
|
||||
break
|
||||
if round_idx > MAX_TOOL_ROUNDS:
|
||||
break
|
||||
|
||||
# Execute tools concurrently
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
results = asyncio.run_coroutine_threadsafe(
|
||||
_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic"),
|
||||
loop
|
||||
).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic"))
|
||||
|
||||
tool_results: list[dict[str, Any]] = []
|
||||
for block in response.content:
|
||||
if getattr(block, "type", None) != "tool_use":
|
||||
continue
|
||||
b_name = cast(str, getattr(block, "name"))
|
||||
b_id = cast(str, getattr(block, "id"))
|
||||
b_input = cast(dict[str, Any], getattr(block, "input"))
|
||||
output = ""
|
||||
tool_executed = False
|
||||
events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx})
|
||||
if b_name == TOOL_NAME and pre_tool_callback:
|
||||
script = cast(str, b_input.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": b_id, "script": script})
|
||||
res = pre_tool_callback(script, base_dir, qa_callback)
|
||||
if res is None:
|
||||
output = "USER REJECTED: tool execution cancelled"
|
||||
else:
|
||||
output = res
|
||||
tool_executed = True
|
||||
|
||||
if not tool_executed:
|
||||
if b_name and b_name in mcp_client.TOOL_NAMES:
|
||||
|
||||
_append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input})
|
||||
if b_name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
|
||||
desc = f"# MCP MUTATING TOOL: {b_name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in b_input.items())
|
||||
_res = pre_tool_callback(desc, base_dir, qa_callback)
|
||||
output = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(b_name, b_input)
|
||||
else:
|
||||
output = mcp_client.dispatch(b_name, b_input)
|
||||
_append_comms("IN", "tool_result", {"name": b_name, "id": b_id, "output": output})
|
||||
elif b_name == TOOL_NAME:
|
||||
script = cast(str, b_input.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {
|
||||
"name": TOOL_NAME,
|
||||
"id": b_id,
|
||||
"script": script,
|
||||
})
|
||||
output = _run_script(script, base_dir, qa_callback)
|
||||
_append_comms("IN", "tool_result", {
|
||||
"name": TOOL_NAME,
|
||||
"id": b_id,
|
||||
"output": output,
|
||||
})
|
||||
else:
|
||||
output = f"ERROR: unknown tool '{b_name}'"
|
||||
|
||||
truncated = _truncate_tool_output(output)
|
||||
for i, (name, call_id, out, _) in enumerate(results):
|
||||
truncated = _truncate_tool_output(out)
|
||||
_cumulative_tool_bytes += len(truncated)
|
||||
tool_results.append({
|
||||
"type": "tool_result",
|
||||
"tool_use_id": b_id,
|
||||
"tool_use_id": call_id,
|
||||
"content": truncated,
|
||||
})
|
||||
events.emit("tool_execution", payload={"status": "completed", "tool": b_name, "result": output, "round": round_idx})
|
||||
_append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
|
||||
events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})
|
||||
|
||||
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
|
||||
tool_results.append({
|
||||
"type": "text",
|
||||
@@ -1448,62 +1461,39 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
|
||||
break
|
||||
if round_idx > MAX_TOOL_ROUNDS:
|
||||
break
|
||||
|
||||
# Execute tools concurrently
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
results = asyncio.run_coroutine_threadsafe(
|
||||
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"),
|
||||
loop
|
||||
).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"))
|
||||
|
||||
tool_results_for_history: list[dict[str, Any]] = []
|
||||
for i, tc_raw in enumerate(tool_calls_raw):
|
||||
tool_info = tc_raw.get("function", {})
|
||||
tool_name = cast(str, tool_info.get("name"))
|
||||
tool_args_str = cast(str, tool_info.get("arguments", "{}"))
|
||||
tool_id = cast(str, tc_raw.get("id"))
|
||||
try:
|
||||
tool_args = json.loads(tool_args_str)
|
||||
except:
|
||||
tool_args = {}
|
||||
tool_output = ""
|
||||
tool_executed = False
|
||||
events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx})
|
||||
if tool_name == TOOL_NAME and pre_tool_callback:
|
||||
script = cast(str, tool_args.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script})
|
||||
res = pre_tool_callback(script, base_dir, qa_callback)
|
||||
if res is None:
|
||||
tool_output = "USER REJECTED: tool execution cancelled"
|
||||
else:
|
||||
tool_output = res
|
||||
tool_executed = True
|
||||
|
||||
if not tool_executed:
|
||||
if tool_name in mcp_client.TOOL_NAMES:
|
||||
_append_comms("OUT", "tool_call", {"name": tool_name, "id": tool_id, "args": tool_args})
|
||||
if tool_name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
|
||||
desc = f"# MCP MUTATING TOOL: {tool_name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in tool_args.items())
|
||||
_res = pre_tool_callback(desc, base_dir, qa_callback)
|
||||
tool_output = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(tool_name, tool_args)
|
||||
else:
|
||||
tool_output = mcp_client.dispatch(tool_name, tool_args)
|
||||
elif tool_name == TOOL_NAME:
|
||||
script = cast(str, tool_args.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script})
|
||||
tool_output = _run_script(script, base_dir, qa_callback)
|
||||
else:
|
||||
tool_output = f"ERROR: unknown tool '{tool_name}'"
|
||||
|
||||
if i == len(tool_calls_raw) - 1:
|
||||
for i, (name, call_id, out, _) in enumerate(results):
|
||||
if i == len(results) - 1:
|
||||
if file_items:
|
||||
file_items, changed = _reread_file_items(file_items)
|
||||
ctx = _build_file_diff_text(changed)
|
||||
if ctx:
|
||||
tool_output += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
|
||||
out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
|
||||
if round_idx == MAX_TOOL_ROUNDS:
|
||||
tool_output += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
|
||||
tool_output = _truncate_tool_output(tool_output)
|
||||
_cumulative_tool_bytes += len(tool_output)
|
||||
out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
|
||||
|
||||
truncated = _truncate_tool_output(out)
|
||||
_cumulative_tool_bytes += len(truncated)
|
||||
tool_results_for_history.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"content": tool_output,
|
||||
"tool_call_id": call_id,
|
||||
"name": name,
|
||||
"content": truncated,
|
||||
})
|
||||
_append_comms("IN", "tool_result", {"name": tool_name, "id": tool_id, "output": tool_output})
|
||||
events.emit("tool_execution", payload={"status": "completed", "tool": tool_name, "result": tool_output, "round": round_idx})
|
||||
_append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
|
||||
events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})
|
||||
|
||||
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
|
||||
tool_results_for_history.append({
|
||||
"role": "user",
|
||||
@@ -1524,6 +1514,7 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
|
||||
except Exception as e:
|
||||
raise _classify_deepseek_error(e) from e
|
||||
|
||||
|
||||
def run_tier4_analysis(stderr: str) -> str:
|
||||
if not stderr or not stderr.strip():
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user