diff --git a/src/ai_client.py b/src/ai_client.py index 00fe1da3..06001c49 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -42,6 +42,8 @@ from src import mcp_client from src import mma_prompts from src import performance_monitor from src import project_manager +from src.openai_compatible import send_openai_compatible +from src.vendor_capabilities import VendorCapabilities # TODO(Ed): Eliminate these? from src.events import EventEmitter @@ -801,6 +803,61 @@ async def _execute_tool_calls_concurrently( if monitor.enabled: monitor.end_component("ai_client._execute_tool_calls_concurrently") return results +def run_with_tool_loop( + client: Any, + request: OpenAICompatibleRequest, + *, + capabilities: VendorCapabilities, + pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, + qa_callback: Optional[Callable[[str], str]] = None, + stream_callback: Optional[Callable[[str], None]] = None, + patch_callback: Optional[Callable[[str, str], Optional[str]]] = None, + base_dir: str, + vendor_name: str, + history_lock: Optional[threading.Lock] = None, + history: Optional[list[dict[str, Any]]] = None, + trim_func: Optional[Callable[[list[dict[str, Any]]], None]] = None, + reasoning_extractor: Optional[Callable[[Any], str]] = None, +) -> str: + response_text: str = "" + for _round_idx in range(MAX_TOOL_ROUNDS + 2): + response = send_openai_compatible(client, request, capabilities=capabilities) + reasoning_content: str = reasoning_extractor(response.raw_response) if reasoning_extractor else "" + response_text = response.text or "" + if history_lock is not None and history is not None: + with history_lock: + msg: dict[str, Any] = {"role": "assistant", "content": response.text or None} + if reasoning_content: + msg["reasoning_content"] = reasoning_content + if response.tool_calls: + msg["tool_calls"] = response.tool_calls + history.append(msg) + if not response.tool_calls: + break + try: + loop = asyncio.get_running_loop() + results = asyncio.run_coroutine_threadsafe( + _execute_tool_calls_concurrently( + response.tool_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback, + ), + loop, + ).result() + except RuntimeError: + results = asyncio.run(_execute_tool_calls_concurrently( + response.tool_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback, + )) + if history_lock is not None and history is not None: + with history_lock: + for _i, (tool_name, call_id, out, _err) in enumerate(results): + history.append({ + "role": "tool", + "tool_call_id": call_id, + "content": str(out) if out else "", + }) + if trim_func is not None: + trim_func(history) + return response_text + async def _execute_single_tool_call_async( name: str, args: dict[str, Any], @@ -812,11 +869,7 @@ async def _execute_single_tool_call_async( tier: str | None = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None ) -> tuple[str, str, str, str]: - """ - [C: tests/test_external_mcp_e2e.py:test_external_mcp_e2e_refresh_and_call, tests/test_external_mcp_hitl.py:test_external_mcp_hitl_approval, tests/test_external_mcp_hitl.py:test_external_mcp_hitl_rejection, tests/test_tool_presets_execution.py:test_tool_ask_approval, tests/test_tool_presets_execution.py:test_tool_auto_approval, tests/test_tool_presets_execution.py:test_tool_rejection] - """ - if tier: - set_current_tier(tier) + set_current_tier(tier) out = "" tool_executed = False events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})