diff --git a/src/ai_client.py b/src/ai_client.py index b2431dfa..d560ea67 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -42,7 +42,7 @@ from src import mcp_client from src import mma_prompts from src import performance_monitor from src import project_manager -from src.openai_compatible import send_openai_compatible, OpenAICompatibleRequest +from src.openai_compatible import send_openai_compatible, OpenAICompatibleRequest, NormalizedResponse from src.vendor_capabilities import VendorCapabilities, get_capabilities # TODO(Ed): Eliminate these? @@ -807,7 +807,7 @@ def run_with_tool_loop( client: Any, request: Union[OpenAICompatibleRequest, Callable[[int], OpenAICompatibleRequest]], *, - capabilities: VendorCapabilities, + capabilities: Optional[VendorCapabilities] = None, pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, qa_callback: Optional[Callable[[str], str]] = None, stream_callback: Optional[Callable[[str], None]] = None, @@ -818,11 +818,17 @@ def run_with_tool_loop( history: Optional[list[dict[str, Any]]] = None, trim_func: Optional[Callable[[list[dict[str, Any]]], None]] = None, reasoning_extractor: Optional[Callable[[Any], str]] = None, + send_func: Optional[Callable[[int], NormalizedResponse]] = None, + on_pre_dispatch: Optional[Callable[[int, list[dict[str, Any]]], list[dict[str, Any]]]] = None, ) -> str: + def _default_send(_round_idx: int) -> NormalizedResponse: + assert capabilities is not None, "capabilities required when send_func is not provided" + return send_openai_compatible(client, request_builder(_round_idx), capabilities=capabilities) request_builder: Callable[[int], OpenAICompatibleRequest] = (request if callable(request) else (lambda _i: request)) + dispatch_send: Callable[[int], NormalizedResponse] = send_func or _default_send response_text: str = "" for _round_idx in range(MAX_TOOL_ROUNDS + 2): - response = send_openai_compatible(client, request_builder(_round_idx), capabilities=capabilities) + response = dispatch_send(_round_idx) reasoning_content: str = reasoning_extractor(response.raw_response) if reasoning_extractor else "" response_text = response.text or "" if history_lock is not None and history is not None: @@ -835,17 +841,21 @@ def run_with_tool_loop( history.append(msg) if not response.tool_calls: break + if on_pre_dispatch is not None: + _adjusted_calls = on_pre_dispatch(_round_idx, response.tool_calls) + else: + _adjusted_calls = response.tool_calls try: loop = asyncio.get_running_loop() results = asyncio.run_coroutine_threadsafe( _execute_tool_calls_concurrently( - response.tool_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback, + _adjusted_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback, ), loop, ).result() except RuntimeError: results = asyncio.run(_execute_tool_calls_concurrently( - response.tool_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback, + _adjusted_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback, )) if history_lock is not None and history is not None: with history_lock: @@ -1756,8 +1766,8 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, global _gemini_cli_adapter try: if _gemini_cli_adapter is None: - _gemini_cli_adapter = GeminiCliAdapter(binary_path="gemini") - adapter = _gemini_cli_adapter + _gemini_cli_adapter = GeminiCliAdapter(binary_path="gemini") + adapter = _gemini_cli_adapter mcp_client.configure(file_items or [], [base_dir]) sys_instr = f"{_get_combined_system_prompt()}\n\n\n{md_content}\n" safety_settings = [{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_ONLY_HIGH'}] @@ -1766,16 +1776,15 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, if discussion_history: payload = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}" all_text: list[str] = [] - _cumulative_tool_bytes = 0 - for r_idx in range(MAX_TOOL_ROUNDS + 2): + cumulative_tool_bytes = 0 + + def _send(r_idx: int) -> NormalizedResponse: if adapter is None: - break + return NormalizedResponse(text="(adapter unavailable)", tool_calls=[], usage_input_tokens=0, usage_output_tokens=0, usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=None) events.emit("request_start", payload={"provider": "gemini_cli", "model": _model, "round": r_idx}) if r_idx > 0: _append_comms("OUT", "request", {"message": f"[CLI] [round {r_idx}] [msg {len(payload)}]"}) - send_payload = payload - if isinstance(payload, list): - send_payload = json.dumps(payload) + send_payload: Any = json.dumps(payload) if isinstance(payload, list) else payload try: resp_data = adapter.send(cast(str, send_payload), safety_settings=safety_settings, system_instruction=sys_instr, model=_model, stream_callback=stream_callback) except Exception as e: @@ -1795,12 +1804,12 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, for c in calls: log_calls.append({"name": c.get("name"), "args": c.get("args"), "id": c.get("id")}) _append_comms("IN", "response", { - "round": r_idx, - "stop_reason": "TOOL_USE" if calls else "STOP", - "text": txt, - "tool_calls": log_calls, - "usage": usage - }) + "round": r_idx, + "stop_reason": "TOOL_USE" if calls else "STOP", + "text": txt, + "tool_calls": log_calls, + "usage": usage + }) if txt and calls: cb = get_comms_log_callback() if cb: @@ -1808,28 +1817,22 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, "ts": project_manager.now_ts(), "direction": "IN", "kind": "history_add", - "payload": { - "role": "AI", - "content": txt - } + "payload": {"role": "AI", "content": txt} }) - if not calls or r_idx > MAX_TOOL_ROUNDS: - break - - # Execute tools concurrently + return NormalizedResponse(text=txt, tool_calls=calls, usage_input_tokens=usage.get("prompt_tokens", 0), usage_output_tokens=usage.get("completion_tokens", 0), usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=resp_data) + + def _pre_dispatch(r_idx: int, calls: list[dict[str, Any]]) -> list[dict[str, Any]]: + nonlocal payload, cumulative_tool_bytes, file_items + tool_results_for_cli: list[dict[str, Any]] = [] + results_iter: list[tuple[str, str, str, str]] = [] + from src.ai_client import _execute_tool_calls_concurrently as _executor try: loop = asyncio.get_running_loop() - results = asyncio.run_coroutine_threadsafe( - _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback), - loop - ).result() + results_iter = loop.run_until_complete(_executor(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback)) if False else asyncio.run_coroutine_threadsafe(_executor(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback), loop).result() except RuntimeError: - results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback)) - - tool_results_for_cli: list[dict[str, Any]] = [] - for i, (name, call_id, out, _) in enumerate(results): - # Check if this is the last tool to trigger file refresh - if i == len(results) - 1: + results_iter = asyncio.run(_executor(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback)) + for i, (name, call_id, out, _) in enumerate(results_iter): + if i == len(results_iter) - 1: if file_items: file_items, changed = _reread_file_items(file_items) ctx = _build_file_diff_text(changed) @@ -1837,21 +1840,23 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, out += f"\n\n{_get_context_marker()}\n\n{ctx}" if r_idx == MAX_TOOL_ROUNDS: out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]" - out = _truncate_tool_output(out) - _cumulative_tool_bytes += len(out) - tool_results_for_cli.append({ - "role": "tool", - "tool_call_id": call_id, - "name": name, - "content": out - }) + cumulative_tool_bytes += len(out) + tool_results_for_cli.append({"role": "tool", "tool_call_id": call_id, "name": name, "content": out}) _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out}) events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx}) - payload = tool_results_for_cli - if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES: - _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"}) + if cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES: + _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {cumulative_tool_bytes} bytes]"}) + return calls + + run_with_tool_loop( + client=adapter, request=lambda _i: cast(OpenAICompatibleRequest, None), + base_dir=base_dir, vendor_name="gemini_cli", + pre_tool_callback=pre_tool_callback, qa_callback=qa_callback, + stream_callback=stream_callback, patch_callback=patch_callback, + send_func=_send, on_pre_dispatch=_pre_dispatch, + ) final_text = all_text[-1] if all_text else "(No text returned)" return final_text except Exception as e: diff --git a/tests/test_ai_client_tool_loop_send_func.py b/tests/test_ai_client_tool_loop_send_func.py new file mode 100644 index 00000000..d46501f9 --- /dev/null +++ b/tests/test_ai_client_tool_loop_send_func.py @@ -0,0 +1,47 @@ +"""Verify run_with_tool_loop supports a custom send_func for vendors +that don't use send_openai_compatible (gemini_cli, gemini, anthropic, +deepseek). The vendor provides a send_func that returns a +NormalizedResponse, and the helper handles history + dispatch. +""" +from __future__ import annotations +from typing import Any +from unittest.mock import MagicMock, patch +from src.openai_compatible import NormalizedResponse +from src.ai_client import run_with_tool_loop +from src.vendor_capabilities import VendorCapabilities + +def _make_normalized_response(text: str = "ok", tool_calls: list[dict[str, Any]] | None = None) -> NormalizedResponse: + return NormalizedResponse( + text=text, tool_calls=tool_calls or [], + usage_input_tokens=10, usage_output_tokens=5, + usage_cache_read_tokens=0, usage_cache_creation_tokens=0, + raw_response=None, + ) + +def test_run_with_tool_loop_uses_send_func_when_provided() -> None: + client = MagicMock() + def send_func(_round_idx: int) -> NormalizedResponse: + return _make_normalized_response(f"from-send-func-{_round_idx}") + result = run_with_tool_loop( + client, request=lambda _i: MagicMock(), # should be IGNORED + base_dir=".", vendor_name="custom", + send_func=send_func, + ) + assert result == "from-send-func-0" + +def test_run_with_tool_loop_dispatches_via_send_func() -> None: + client = MagicMock() + tool_resp = _make_normalized_response( + "first", tool_calls=[{"id": "c1", "type": "function", "function": {"name": "t", "arguments": "{}"}}] + ) + final = _make_normalized_response("done") + def send_func(round_idx: int) -> NormalizedResponse: + return [tool_resp, final][round_idx] + with patch("src.ai_client._execute_tool_calls_concurrently", return_value=[("t", "c1", "r", "")]) as dispatch: + result = run_with_tool_loop( + client, request=lambda _i: MagicMock(), + base_dir=".", vendor_name="custom", + send_func=send_func, + ) + assert result == "done" + assert dispatch.call_count == 1