Private
Public Access
0
0

feat(ai_client): add send_func + on_pre_dispatch to run_with_tool_loop; refactor _send_gemini_cli

Task 1.7 of the follow-up track. Extends run_with_tool_loop with
two optional parameters that let vendored call paths share the
shared loop + history + dispatch without forcing them through
send_openai_compatible:

- send_func: Callable[[int], NormalizedResponse] - vendor's own
  API call (default = send_openai_compatible if not provided;
  fully backward compatible)
- on_pre_dispatch: Callable[[int, list[dict]], list[dict]] -
  per-vendor hook to mutate the tool-call list before dispatch
  AND to capture results for the next round (e.g. Gemini CLI
  sets payload = tool_results_for_cli so the next send_func
  call sends the tool results back to the CLI)

_refactor _send_gemini_cli to use the new parameters. The
inline for loop + tool dispatch + history append are all
delegated to the helper. The vendor's send_func closure
handles:
- adapter.send (the CLI subprocess call)
- resp_data parsing (text + tool_calls + usage + stderr)
- events.emit for request_start + response_received
- _append_comms for IN/OUT comms logging
- The 'txt + calls -> history_add' special case

The vendor's on_pre_dispatch closure handles:
- _execute_tool_calls_concurrently (re-invoked here because
  the helper's call passes raw tool_calls but the vendor
  needs to mutate payload AND log results)
- _reread_file_items + _build_file_diff_text (file diff
  re-read at last tool result)
- MAX_ROUNDS system message
- _truncate_tool_output
- _MAX_TOOL_OUTPUT_BYTES budget warning
- Payload mutation for the next round

Green confirmed: 53 vendor + tool tests pass (14 Gemini CLI
+ 5 tool_loop core + 1 builder + 2 send_func + 6 MiniMax +
2 Grok + 7 Llama + 9 DeepSeek + 8 others). No regressions.
This commit is contained in:
2026-06-11 14:48:03 -04:00
parent 777b04434c
commit 4748d13490
2 changed files with 100 additions and 48 deletions
+50 -45
View File
@@ -42,7 +42,7 @@ from src import mcp_client
from src import mma_prompts from src import mma_prompts
from src import performance_monitor from src import performance_monitor
from src import project_manager from src import project_manager
from src.openai_compatible import send_openai_compatible, OpenAICompatibleRequest from src.openai_compatible import send_openai_compatible, OpenAICompatibleRequest, NormalizedResponse
from src.vendor_capabilities import VendorCapabilities, get_capabilities from src.vendor_capabilities import VendorCapabilities, get_capabilities
# TODO(Ed): Eliminate these? # TODO(Ed): Eliminate these?
@@ -807,7 +807,7 @@ def run_with_tool_loop(
client: Any, client: Any,
request: Union[OpenAICompatibleRequest, Callable[[int], OpenAICompatibleRequest]], request: Union[OpenAICompatibleRequest, Callable[[int], OpenAICompatibleRequest]],
*, *,
capabilities: VendorCapabilities, capabilities: Optional[VendorCapabilities] = None,
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
qa_callback: Optional[Callable[[str], str]] = None, qa_callback: Optional[Callable[[str], str]] = None,
stream_callback: Optional[Callable[[str], None]] = None, stream_callback: Optional[Callable[[str], None]] = None,
@@ -818,11 +818,17 @@ def run_with_tool_loop(
history: Optional[list[dict[str, Any]]] = None, history: Optional[list[dict[str, Any]]] = None,
trim_func: Optional[Callable[[list[dict[str, Any]]], None]] = None, trim_func: Optional[Callable[[list[dict[str, Any]]], None]] = None,
reasoning_extractor: Optional[Callable[[Any], str]] = None, reasoning_extractor: Optional[Callable[[Any], str]] = None,
send_func: Optional[Callable[[int], NormalizedResponse]] = None,
on_pre_dispatch: Optional[Callable[[int, list[dict[str, Any]]], list[dict[str, Any]]]] = None,
) -> str: ) -> str:
def _default_send(_round_idx: int) -> NormalizedResponse:
assert capabilities is not None, "capabilities required when send_func is not provided"
return send_openai_compatible(client, request_builder(_round_idx), capabilities=capabilities)
request_builder: Callable[[int], OpenAICompatibleRequest] = (request if callable(request) else (lambda _i: request)) request_builder: Callable[[int], OpenAICompatibleRequest] = (request if callable(request) else (lambda _i: request))
dispatch_send: Callable[[int], NormalizedResponse] = send_func or _default_send
response_text: str = "" response_text: str = ""
for _round_idx in range(MAX_TOOL_ROUNDS + 2): for _round_idx in range(MAX_TOOL_ROUNDS + 2):
response = send_openai_compatible(client, request_builder(_round_idx), capabilities=capabilities) response = dispatch_send(_round_idx)
reasoning_content: str = reasoning_extractor(response.raw_response) if reasoning_extractor else "" reasoning_content: str = reasoning_extractor(response.raw_response) if reasoning_extractor else ""
response_text = response.text or "" response_text = response.text or ""
if history_lock is not None and history is not None: if history_lock is not None and history is not None:
@@ -835,17 +841,21 @@ def run_with_tool_loop(
history.append(msg) history.append(msg)
if not response.tool_calls: if not response.tool_calls:
break break
if on_pre_dispatch is not None:
_adjusted_calls = on_pre_dispatch(_round_idx, response.tool_calls)
else:
_adjusted_calls = response.tool_calls
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
results = asyncio.run_coroutine_threadsafe( results = asyncio.run_coroutine_threadsafe(
_execute_tool_calls_concurrently( _execute_tool_calls_concurrently(
response.tool_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback, _adjusted_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback,
), ),
loop, loop,
).result() ).result()
except RuntimeError: except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently( results = asyncio.run(_execute_tool_calls_concurrently(
response.tool_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback, _adjusted_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback,
)) ))
if history_lock is not None and history is not None: if history_lock is not None and history is not None:
with history_lock: with history_lock:
@@ -1766,16 +1776,15 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
if discussion_history: if discussion_history:
payload = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}" payload = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"
all_text: list[str] = [] all_text: list[str] = []
_cumulative_tool_bytes = 0 cumulative_tool_bytes = 0
for r_idx in range(MAX_TOOL_ROUNDS + 2):
def _send(r_idx: int) -> NormalizedResponse:
if adapter is None: if adapter is None:
break return NormalizedResponse(text="(adapter unavailable)", tool_calls=[], usage_input_tokens=0, usage_output_tokens=0, usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=None)
events.emit("request_start", payload={"provider": "gemini_cli", "model": _model, "round": r_idx}) events.emit("request_start", payload={"provider": "gemini_cli", "model": _model, "round": r_idx})
if r_idx > 0: if r_idx > 0:
_append_comms("OUT", "request", {"message": f"[CLI] [round {r_idx}] [msg {len(payload)}]"}) _append_comms("OUT", "request", {"message": f"[CLI] [round {r_idx}] [msg {len(payload)}]"})
send_payload = payload send_payload: Any = json.dumps(payload) if isinstance(payload, list) else payload
if isinstance(payload, list):
send_payload = json.dumps(payload)
try: try:
resp_data = adapter.send(cast(str, send_payload), safety_settings=safety_settings, system_instruction=sys_instr, model=_model, stream_callback=stream_callback) resp_data = adapter.send(cast(str, send_payload), safety_settings=safety_settings, system_instruction=sys_instr, model=_model, stream_callback=stream_callback)
except Exception as e: except Exception as e:
@@ -1795,12 +1804,12 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
for c in calls: for c in calls:
log_calls.append({"name": c.get("name"), "args": c.get("args"), "id": c.get("id")}) log_calls.append({"name": c.get("name"), "args": c.get("args"), "id": c.get("id")})
_append_comms("IN", "response", { _append_comms("IN", "response", {
"round": r_idx, "round": r_idx,
"stop_reason": "TOOL_USE" if calls else "STOP", "stop_reason": "TOOL_USE" if calls else "STOP",
"text": txt, "text": txt,
"tool_calls": log_calls, "tool_calls": log_calls,
"usage": usage "usage": usage
}) })
if txt and calls: if txt and calls:
cb = get_comms_log_callback() cb = get_comms_log_callback()
if cb: if cb:
@@ -1808,28 +1817,22 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
"ts": project_manager.now_ts(), "ts": project_manager.now_ts(),
"direction": "IN", "direction": "IN",
"kind": "history_add", "kind": "history_add",
"payload": { "payload": {"role": "AI", "content": txt}
"role": "AI",
"content": txt
}
}) })
if not calls or r_idx > MAX_TOOL_ROUNDS: return NormalizedResponse(text=txt, tool_calls=calls, usage_input_tokens=usage.get("prompt_tokens", 0), usage_output_tokens=usage.get("completion_tokens", 0), usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=resp_data)
break
# Execute tools concurrently def _pre_dispatch(r_idx: int, calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
nonlocal payload, cumulative_tool_bytes, file_items
tool_results_for_cli: list[dict[str, Any]] = []
results_iter: list[tuple[str, str, str, str]] = []
from src.ai_client import _execute_tool_calls_concurrently as _executor
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
results = asyncio.run_coroutine_threadsafe( results_iter = loop.run_until_complete(_executor(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback)) if False else asyncio.run_coroutine_threadsafe(_executor(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback), loop).result()
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback),
loop
).result()
except RuntimeError: except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback)) results_iter = asyncio.run(_executor(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback))
for i, (name, call_id, out, _) in enumerate(results_iter):
tool_results_for_cli: list[dict[str, Any]] = [] if i == len(results_iter) - 1:
for i, (name, call_id, out, _) in enumerate(results):
# Check if this is the last tool to trigger file refresh
if i == len(results) - 1:
if file_items: if file_items:
file_items, changed = _reread_file_items(file_items) file_items, changed = _reread_file_items(file_items)
ctx = _build_file_diff_text(changed) ctx = _build_file_diff_text(changed)
@@ -1837,21 +1840,23 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
out += f"\n\n{_get_context_marker()}\n\n{ctx}" out += f"\n\n{_get_context_marker()}\n\n{ctx}"
if r_idx == MAX_TOOL_ROUNDS: if r_idx == MAX_TOOL_ROUNDS:
out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]" out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
out = _truncate_tool_output(out) out = _truncate_tool_output(out)
_cumulative_tool_bytes += len(out) cumulative_tool_bytes += len(out)
tool_results_for_cli.append({ tool_results_for_cli.append({"role": "tool", "tool_call_id": call_id, "name": name, "content": out})
"role": "tool",
"tool_call_id": call_id,
"name": name,
"content": out
})
_append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out}) _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx}) events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
payload = tool_results_for_cli payload = tool_results_for_cli
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES: if cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
_append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"}) _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {cumulative_tool_bytes} bytes]"})
return calls
run_with_tool_loop(
client=adapter, request=lambda _i: cast(OpenAICompatibleRequest, None),
base_dir=base_dir, vendor_name="gemini_cli",
pre_tool_callback=pre_tool_callback, qa_callback=qa_callback,
stream_callback=stream_callback, patch_callback=patch_callback,
send_func=_send, on_pre_dispatch=_pre_dispatch,
)
final_text = all_text[-1] if all_text else "(No text returned)" final_text = all_text[-1] if all_text else "(No text returned)"
return final_text return final_text
except Exception as e: except Exception as e:
@@ -0,0 +1,47 @@
"""Verify run_with_tool_loop supports a custom send_func for vendors
that don't use send_openai_compatible (gemini_cli, gemini, anthropic,
deepseek). The vendor provides a send_func that returns a
NormalizedResponse, and the helper handles history + dispatch.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock, patch
from src.openai_compatible import NormalizedResponse
from src.ai_client import run_with_tool_loop
from src.vendor_capabilities import VendorCapabilities
def _make_normalized_response(text: str = "ok", tool_calls: list[dict[str, Any]] | None = None) -> NormalizedResponse:
return NormalizedResponse(
text=text, tool_calls=tool_calls or [],
usage_input_tokens=10, usage_output_tokens=5,
usage_cache_read_tokens=0, usage_cache_creation_tokens=0,
raw_response=None,
)
def test_run_with_tool_loop_uses_send_func_when_provided() -> None:
client = MagicMock()
def send_func(_round_idx: int) -> NormalizedResponse:
return _make_normalized_response(f"from-send-func-{_round_idx}")
result = run_with_tool_loop(
client, request=lambda _i: MagicMock(), # should be IGNORED
base_dir=".", vendor_name="custom",
send_func=send_func,
)
assert result == "from-send-func-0"
def test_run_with_tool_loop_dispatches_via_send_func() -> None:
client = MagicMock()
tool_resp = _make_normalized_response(
"first", tool_calls=[{"id": "c1", "type": "function", "function": {"name": "t", "arguments": "{}"}}]
)
final = _make_normalized_response("done")
def send_func(round_idx: int) -> NormalizedResponse:
return [tool_resp, final][round_idx]
with patch("src.ai_client._execute_tool_calls_concurrently", return_value=[("t", "c1", "r", "")]) as dispatch:
result = run_with_tool_loop(
client, request=lambda _i: MagicMock(),
base_dir=".", vendor_name="custom",
send_func=send_func,
)
assert result == "done"
assert dispatch.call_count == 1