async def _execute_tool_calls_concurrently(
    calls: list[Any],
    base_dir: str,
    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
    qa_callback: Optional[Callable[[str], str]],
    r_idx: int,
    provider: str,
) -> list[tuple[str, str, str, str]]:
    """
    Execute every tool call in *calls* concurrently.

    Returns one ``(tool_name, call_id, output, original_name)`` tuple per
    executed call, in the same order as *calls* (``asyncio.gather``
    preserves task order, which matters because callers pair results with
    tool_use IDs positionally).

    *provider* selects how a raw call object is unpacked:
      - ``"gemini"``:     ``types.Part`` function calls. Gemini 1.0.0 does
                          not expose call IDs on ``types.Part``, so the tool
                          name doubles as the call ID.
      - ``"gemini_cli"``: plain dicts with ``"name"``/``"args"``/``"id"``.
      - ``"anthropic"``:  Messages-API content blocks. The list may mix
                          text and tool_use blocks; only ``tool_use``
                          blocks carry name/input/id, so others are skipped.
      - ``"deepseek"``:   OpenAI-style dicts with a nested ``"function"``
                          payload whose ``"arguments"`` is a JSON string.
    Calls from an unknown provider are silently dropped.

    NOTE(review): callers drive this from synchronous code via
    ``asyncio.run`` / ``run_coroutine_threadsafe(...).result()`` — blocking
    on ``.result()`` from the loop's own thread would deadlock; verify the
    call sites.
    """
    tasks = []
    for fc in calls:
        if provider == "gemini":
            # No per-call IDs in this SDK version: reuse the tool name.
            name, args, call_id = fc.name, dict(fc.args), fc.name
        elif provider == "gemini_cli":
            name = cast(str, fc.get("name"))
            args = cast(dict[str, Any], fc.get("args", {}))
            call_id = cast(str, fc.get("id"))
        elif provider == "anthropic":
            # response.content mixes text and tool_use blocks; text blocks
            # have no name/input/id, so unpacking them would raise.
            if getattr(fc, "type", None) != "tool_use":
                continue
            name = cast(str, getattr(fc, "name"))
            args = cast(dict[str, Any], getattr(fc, "input"))
            call_id = cast(str, getattr(fc, "id"))
        elif provider == "deepseek":
            tool_info = fc.get("function", {})
            name = cast(str, tool_info.get("name"))
            call_id = cast(str, fc.get("id"))
            try:
                args = json.loads(cast(str, tool_info.get("arguments", "{}")))
            except (json.JSONDecodeError, TypeError):
                # Malformed arguments payload: run the tool with no args
                # rather than failing the whole round.
                args = {}
        else:
            continue

        tasks.append(_execute_single_tool_call_async(
            name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx))

    return await asyncio.gather(*tasks)


async def _execute_single_tool_call_async(
    name: str,
    args: dict[str, Any],
    call_id: str,
    base_dir: str,
    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
    qa_callback: Optional[Callable[[str], str]],
    r_idx: int,
) -> tuple[str, str, str, str]:
    """
    Execute one tool call and return ``(tool_name, call_id, output,
    original_name)``.

    Routing, in priority order:
      1. ``TOOL_NAME`` (run_powershell) with a pre-tool HITL callback:
         the callback vets the script; ``None`` means the user rejected it.
      2. A known MCP tool: mutating MCP tools are routed through the same
         HITL gate before ``mcp_client.async_dispatch``.
      3. ``TOOL_NAME`` without a HITL callback: executed via ``_run_script``.
      4. Anything else: an ``ERROR: unknown tool`` string as the output.

    Synchronous, potentially blocking callables (``pre_tool_callback``,
    ``_run_script``) are offloaded with ``asyncio.to_thread`` so sibling
    tool calls can make progress.
    NOTE(review): when several calls in one round are HITL-gated, their
    prompts may be presented concurrently — confirm the UI serializes them.
    """
    out = ""
    tool_executed = False
    events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})

    # Path 1: run_powershell gated by the HITL callback.
    if name == TOOL_NAME and pre_tool_callback:
        scr = cast(str, args.get("script", ""))
        _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
        # pre_tool_callback is synchronous and may block waiting for the user.
        res = await asyncio.to_thread(pre_tool_callback, scr, base_dir, qa_callback)
        out = "USER REJECTED: tool execution cancelled" if res is None else res
        tool_executed = True

    if not tool_executed:
        if name and name in mcp_client.TOOL_NAMES:
            _append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
            if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
                # Mutating MCP tools get the same HITL gate as scripts.
                desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
                _res = await asyncio.to_thread(pre_tool_callback, desc, base_dir, qa_callback)
                out = "USER REJECTED: tool execution cancelled" if _res is None else await mcp_client.async_dispatch(name, args)
            else:
                out = await mcp_client.async_dispatch(name, args)
        elif name == TOOL_NAME:
            scr = cast(str, args.get("script", ""))
            _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
            out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback)
        else:
            out = f"ERROR: unknown tool '{name}'"

    return (name, call_id, out, name)
break f_resps: list[types.Part] = [] log: list[dict[str, Any]] = [] - for i, fc in enumerate(calls): - name, args = fc.name, dict(fc.args) - out = "" - tool_executed = False - events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx}) - if name == TOOL_NAME and pre_tool_callback: - scr = cast(str, args.get("script", "")) - _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr}) - res = pre_tool_callback(scr, base_dir, qa_callback) - if res is None: - out = "USER REJECTED: tool execution cancelled" - else: - out = res - tool_executed = True - - if not tool_executed: - if name and name in mcp_client.TOOL_NAMES: - _append_comms("OUT", "tool_call", {"name": name, "args": args}) - if name in mcp_client.MUTATING_TOOLS and pre_tool_callback: - desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items()) - _res = pre_tool_callback(desc, base_dir, qa_callback) - out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(name, args) - else: - out = mcp_client.dispatch(cast(str, name), args) - elif name == TOOL_NAME: - scr = cast(str, args.get("script", "")) - _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr}) - out = _run_script(scr, base_dir, qa_callback) - else: out = f"ERROR: unknown tool '{name}'" - - if i == len(calls) - 1: + + # Execute tools concurrently + try: + loop = asyncio.get_running_loop() + results = asyncio.run_coroutine_threadsafe( + _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"), + loop + ).result() + except RuntimeError: + results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini")) + + for i, (name, call_id, out, _) in enumerate(results): + # Check if this is the last tool to trigger file refresh + if i == len(results) - 1: if file_items: file_items, changed = _reread_file_items(file_items) ctx = 
_build_file_diff_text(changed) if ctx: out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}" if r_idx == MAX_TOOL_ROUNDS: out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]" + out = _truncate_tool_output(out) _cumulative_tool_bytes += len(out) f_resps.append(types.Part(function_response=types.FunctionResponse(name=cast(str, name), response={"output": out}))) log.append({"tool_use_id": name, "content": out}) events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx}) + if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES: f_resps.append(types.Part(text= f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now." @@ -877,42 +941,21 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, }) if not calls or r_idx > MAX_TOOL_ROUNDS: break - tool_results_for_cli: list[dict[str, Any]] = [] - for i, fc in enumerate(calls): - name = cast(str, fc.get("name")) - args = cast(dict[str, Any], fc.get("args", {})) - call_id = cast(str, fc.get("id")) - out = "" - tool_executed = False - events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx}) - if name == TOOL_NAME and pre_tool_callback: - scr = cast(str, args.get("script", "")) - _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr}) - res = pre_tool_callback(scr, base_dir, qa_callback) - if res is None: - out = "USER REJECTED: tool execution cancelled" - else: - out = res - tool_executed = True - - if not tool_executed: - if name and name in mcp_client.TOOL_NAMES: + + # Execute tools concurrently + try: + loop = asyncio.get_running_loop() + results = asyncio.run_coroutine_threadsafe( + _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"), + loop + ).result() + except RuntimeError: + results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, 
qa_callback, r_idx, "gemini_cli")) - _append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args}) - if name in mcp_client.MUTATING_TOOLS and pre_tool_callback: - desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items()) - _res = pre_tool_callback(desc, base_dir, qa_callback) - out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(name, args) - else: - out = mcp_client.dispatch(name, args) - elif name == TOOL_NAME: - scr = cast(str, args.get("script", "")) - _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr}) - out = _run_script(scr, base_dir, qa_callback) - else: - out = f"ERROR: unknown tool '{name}'" - - if i == len(calls) - 1: + tool_results_for_cli: list[dict[str, Any]] = [] + for i, (name, call_id, out, _) in enumerate(results): + # Check if this is the last tool to trigger file refresh + if i == len(results) - 1: if file_items: file_items, changed = _reread_file_items(file_items) ctx = _build_file_diff_text(changed) @@ -920,6 +963,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}" if r_idx == MAX_TOOL_ROUNDS: out += "\n\n[SYSTEM: MAX ROUNDS. 
PROVIDE FINAL ANSWER.]" + out = _truncate_tool_output(out) _cumulative_tool_bytes += len(out) tool_results_for_cli.append({ @@ -930,6 +974,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, }) _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out}) events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx}) + payload = tool_results_for_cli if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES: _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"}) @@ -1217,61 +1262,29 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item break if round_idx > MAX_TOOL_ROUNDS: break - tool_results: list[dict[str, Any]] = [] - for block in response.content: - if getattr(block, "type", None) != "tool_use": - continue - b_name = cast(str, getattr(block, "name")) - b_id = cast(str, getattr(block, "id")) - b_input = cast(dict[str, Any], getattr(block, "input")) - output = "" - tool_executed = False - events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx}) - if b_name == TOOL_NAME and pre_tool_callback: - script = cast(str, b_input.get("script", "")) - _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": b_id, "script": script}) - res = pre_tool_callback(script, base_dir, qa_callback) - if res is None: - output = "USER REJECTED: tool execution cancelled" - else: - output = res - tool_executed = True - - if not tool_executed: - if b_name and b_name in mcp_client.TOOL_NAMES: + + # Execute tools concurrently + try: + loop = asyncio.get_running_loop() + results = asyncio.run_coroutine_threadsafe( + _execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic"), + loop + ).result() + except RuntimeError: + results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, 
pre_tool_callback, qa_callback, round_idx, "anthropic")) - _append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input}) - if b_name in mcp_client.MUTATING_TOOLS and pre_tool_callback: - desc = f"# MCP MUTATING TOOL: {b_name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in b_input.items()) - _res = pre_tool_callback(desc, base_dir, qa_callback) - output = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(b_name, b_input) - else: - output = mcp_client.dispatch(b_name, b_input) - _append_comms("IN", "tool_result", {"name": b_name, "id": b_id, "output": output}) - elif b_name == TOOL_NAME: - script = cast(str, b_input.get("script", "")) - _append_comms("OUT", "tool_call", { - "name": TOOL_NAME, - "id": b_id, - "script": script, - }) - output = _run_script(script, base_dir, qa_callback) - _append_comms("IN", "tool_result", { - "name": TOOL_NAME, - "id": b_id, - "output": output, - }) - else: - output = f"ERROR: unknown tool '{b_name}'" - - truncated = _truncate_tool_output(output) + tool_results: list[dict[str, Any]] = [] + for i, (name, call_id, out, _) in enumerate(results): + truncated = _truncate_tool_output(out) _cumulative_tool_bytes += len(truncated) tool_results.append({ "type": "tool_result", - "tool_use_id": b_id, + "tool_use_id": call_id, "content": truncated, }) - events.emit("tool_execution", payload={"status": "completed", "tool": b_name, "result": output, "round": round_idx}) + _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out}) + events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx}) + if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES: tool_results.append({ "type": "text", @@ -1448,62 +1461,39 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str, break if round_idx > MAX_TOOL_ROUNDS: break + + # Execute tools concurrently + try: + loop = asyncio.get_running_loop() + results = 
asyncio.run_coroutine_threadsafe( + _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"), + loop + ).result() + except RuntimeError: + results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek")) + tool_results_for_history: list[dict[str, Any]] = [] - for i, tc_raw in enumerate(tool_calls_raw): - tool_info = tc_raw.get("function", {}) - tool_name = cast(str, tool_info.get("name")) - tool_args_str = cast(str, tool_info.get("arguments", "{}")) - tool_id = cast(str, tc_raw.get("id")) - try: - tool_args = json.loads(tool_args_str) - except: - tool_args = {} - tool_output = "" - tool_executed = False - events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx}) - if tool_name == TOOL_NAME and pre_tool_callback: - script = cast(str, tool_args.get("script", "")) - _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script}) - res = pre_tool_callback(script, base_dir, qa_callback) - if res is None: - tool_output = "USER REJECTED: tool execution cancelled" - else: - tool_output = res - tool_executed = True - - if not tool_executed: - if tool_name in mcp_client.TOOL_NAMES: - _append_comms("OUT", "tool_call", {"name": tool_name, "id": tool_id, "args": tool_args}) - if tool_name in mcp_client.MUTATING_TOOLS and pre_tool_callback: - desc = f"# MCP MUTATING TOOL: {tool_name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in tool_args.items()) - _res = pre_tool_callback(desc, base_dir, qa_callback) - tool_output = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(tool_name, tool_args) - else: - tool_output = mcp_client.dispatch(tool_name, tool_args) - elif tool_name == TOOL_NAME: - script = cast(str, tool_args.get("script", "")) - _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script}) - tool_output 
= _run_script(script, base_dir, qa_callback) - else: - tool_output = f"ERROR: unknown tool '{tool_name}'" - - if i == len(tool_calls_raw) - 1: + for i, (name, call_id, out, _) in enumerate(results): + if i == len(results) - 1: if file_items: file_items, changed = _reread_file_items(file_items) ctx = _build_file_diff_text(changed) if ctx: - tool_output += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}" + out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}" if round_idx == MAX_TOOL_ROUNDS: - tool_output += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]" - tool_output = _truncate_tool_output(tool_output) - _cumulative_tool_bytes += len(tool_output) + out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]" + + truncated = _truncate_tool_output(out) + _cumulative_tool_bytes += len(truncated) tool_results_for_history.append({ "role": "tool", - "tool_call_id": tool_id, - "content": tool_output, + "tool_call_id": call_id, + "name": name, + "content": truncated, }) - _append_comms("IN", "tool_result", {"name": tool_name, "id": tool_id, "output": tool_output}) - events.emit("tool_execution", payload={"status": "completed", "tool": tool_name, "result": tool_output, "round": round_idx}) + _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out}) + events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx}) + if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES: tool_results_for_history.append({ "role": "user", @@ -1524,6 +1514,7 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str, except Exception as e: raise _classify_deepseek_error(e) from e + def run_tier4_analysis(stderr: str) -> str: if not stderr or not stderr.strip(): return ""