feat(ai_client): add send_func + on_pre_dispatch to run_with_tool_loop; refactor _send_gemini_cli
Task 1.7 of the follow-up track. Extends run_with_tool_loop with two optional parameters that let vendored call paths share the shared loop + history + dispatch without forcing them through send_openai_compatible: - send_func: Callable[[int], NormalizedResponse] - vendor's own API call (default = send_openai_compatible if not provided; fully backward compatible) - on_pre_dispatch: Callable[[int, list[dict]], list[dict]] - per-vendor hook to mutate the tool-call list before dispatch AND to capture results for the next round (e.g. Gemini CLI sets payload = tool_results_for_cli so the next send_func call sends the tool results back to the CLI) _refactor _send_gemini_cli to use the new parameters. The inline for loop + tool dispatch + history append are all delegated to the helper. The vendor's send_func closure handles: - adapter.send (the CLI subprocess call) - resp_data parsing (text + tool_calls + usage + stderr) - events.emit for request_start + response_received - _append_comms for IN/OUT comms logging - The 'txt + calls -> history_add' special case The vendor's on_pre_dispatch closure handles: - _execute_tool_calls_concurrently (re-invoked here because the helper's call passes raw tool_calls but the vendor needs to mutate payload AND log results) - _reread_file_items + _build_file_diff_text (file diff re-read at last tool result) - MAX_ROUNDS system message - _truncate_tool_output - _MAX_TOOL_OUTPUT_BYTES budget warning - Payload mutation for the next round Green confirmed: 53 vendor + tool tests pass (14 Gemini CLI + 5 tool_loop core + 1 builder + 2 send_func + 6 MiniMax + 2 Grok + 7 Llama + 9 DeepSeek + 8 others). No regressions.
This commit is contained in:
+50
-45
@@ -42,7 +42,7 @@ from src import mcp_client
|
||||
from src import mma_prompts
|
||||
from src import performance_monitor
|
||||
from src import project_manager
|
||||
from src.openai_compatible import send_openai_compatible, OpenAICompatibleRequest
|
||||
from src.openai_compatible import send_openai_compatible, OpenAICompatibleRequest, NormalizedResponse
|
||||
from src.vendor_capabilities import VendorCapabilities, get_capabilities
|
||||
|
||||
# TODO(Ed): Eliminate these?
|
||||
@@ -807,7 +807,7 @@ def run_with_tool_loop(
|
||||
client: Any,
|
||||
request: Union[OpenAICompatibleRequest, Callable[[int], OpenAICompatibleRequest]],
|
||||
*,
|
||||
capabilities: VendorCapabilities,
|
||||
capabilities: Optional[VendorCapabilities] = None,
|
||||
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
|
||||
qa_callback: Optional[Callable[[str], str]] = None,
|
||||
stream_callback: Optional[Callable[[str], None]] = None,
|
||||
@@ -818,11 +818,17 @@ def run_with_tool_loop(
|
||||
history: Optional[list[dict[str, Any]]] = None,
|
||||
trim_func: Optional[Callable[[list[dict[str, Any]]], None]] = None,
|
||||
reasoning_extractor: Optional[Callable[[Any], str]] = None,
|
||||
send_func: Optional[Callable[[int], NormalizedResponse]] = None,
|
||||
on_pre_dispatch: Optional[Callable[[int, list[dict[str, Any]]], list[dict[str, Any]]]] = None,
|
||||
) -> str:
|
||||
def _default_send(_round_idx: int) -> NormalizedResponse:
|
||||
assert capabilities is not None, "capabilities required when send_func is not provided"
|
||||
return send_openai_compatible(client, request_builder(_round_idx), capabilities=capabilities)
|
||||
request_builder: Callable[[int], OpenAICompatibleRequest] = (request if callable(request) else (lambda _i: request))
|
||||
dispatch_send: Callable[[int], NormalizedResponse] = send_func or _default_send
|
||||
response_text: str = ""
|
||||
for _round_idx in range(MAX_TOOL_ROUNDS + 2):
|
||||
response = send_openai_compatible(client, request_builder(_round_idx), capabilities=capabilities)
|
||||
response = dispatch_send(_round_idx)
|
||||
reasoning_content: str = reasoning_extractor(response.raw_response) if reasoning_extractor else ""
|
||||
response_text = response.text or ""
|
||||
if history_lock is not None and history is not None:
|
||||
@@ -835,17 +841,21 @@ def run_with_tool_loop(
|
||||
history.append(msg)
|
||||
if not response.tool_calls:
|
||||
break
|
||||
if on_pre_dispatch is not None:
|
||||
_adjusted_calls = on_pre_dispatch(_round_idx, response.tool_calls)
|
||||
else:
|
||||
_adjusted_calls = response.tool_calls
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
results = asyncio.run_coroutine_threadsafe(
|
||||
_execute_tool_calls_concurrently(
|
||||
response.tool_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback,
|
||||
_adjusted_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback,
|
||||
),
|
||||
loop,
|
||||
).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(
|
||||
response.tool_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback,
|
||||
_adjusted_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback,
|
||||
))
|
||||
if history_lock is not None and history is not None:
|
||||
with history_lock:
|
||||
@@ -1766,16 +1776,15 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
if discussion_history:
|
||||
payload = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"
|
||||
all_text: list[str] = []
|
||||
_cumulative_tool_bytes = 0
|
||||
for r_idx in range(MAX_TOOL_ROUNDS + 2):
|
||||
cumulative_tool_bytes = 0
|
||||
|
||||
def _send(r_idx: int) -> NormalizedResponse:
|
||||
if adapter is None:
|
||||
break
|
||||
return NormalizedResponse(text="(adapter unavailable)", tool_calls=[], usage_input_tokens=0, usage_output_tokens=0, usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=None)
|
||||
events.emit("request_start", payload={"provider": "gemini_cli", "model": _model, "round": r_idx})
|
||||
if r_idx > 0:
|
||||
_append_comms("OUT", "request", {"message": f"[CLI] [round {r_idx}] [msg {len(payload)}]"})
|
||||
send_payload = payload
|
||||
if isinstance(payload, list):
|
||||
send_payload = json.dumps(payload)
|
||||
send_payload: Any = json.dumps(payload) if isinstance(payload, list) else payload
|
||||
try:
|
||||
resp_data = adapter.send(cast(str, send_payload), safety_settings=safety_settings, system_instruction=sys_instr, model=_model, stream_callback=stream_callback)
|
||||
except Exception as e:
|
||||
@@ -1795,12 +1804,12 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
for c in calls:
|
||||
log_calls.append({"name": c.get("name"), "args": c.get("args"), "id": c.get("id")})
|
||||
_append_comms("IN", "response", {
|
||||
"round": r_idx,
|
||||
"stop_reason": "TOOL_USE" if calls else "STOP",
|
||||
"text": txt,
|
||||
"tool_calls": log_calls,
|
||||
"usage": usage
|
||||
})
|
||||
"round": r_idx,
|
||||
"stop_reason": "TOOL_USE" if calls else "STOP",
|
||||
"text": txt,
|
||||
"tool_calls": log_calls,
|
||||
"usage": usage
|
||||
})
|
||||
if txt and calls:
|
||||
cb = get_comms_log_callback()
|
||||
if cb:
|
||||
@@ -1808,28 +1817,22 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
"ts": project_manager.now_ts(),
|
||||
"direction": "IN",
|
||||
"kind": "history_add",
|
||||
"payload": {
|
||||
"role": "AI",
|
||||
"content": txt
|
||||
}
|
||||
"payload": {"role": "AI", "content": txt}
|
||||
})
|
||||
if not calls or r_idx > MAX_TOOL_ROUNDS:
|
||||
break
|
||||
return NormalizedResponse(text=txt, tool_calls=calls, usage_input_tokens=usage.get("prompt_tokens", 0), usage_output_tokens=usage.get("completion_tokens", 0), usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=resp_data)
|
||||
|
||||
# Execute tools concurrently
|
||||
def _pre_dispatch(r_idx: int, calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
nonlocal payload, cumulative_tool_bytes, file_items
|
||||
tool_results_for_cli: list[dict[str, Any]] = []
|
||||
results_iter: list[tuple[str, str, str, str]] = []
|
||||
from src.ai_client import _execute_tool_calls_concurrently as _executor
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
results = asyncio.run_coroutine_threadsafe(
|
||||
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback),
|
||||
loop
|
||||
).result()
|
||||
results_iter = loop.run_until_complete(_executor(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback)) if False else asyncio.run_coroutine_threadsafe(_executor(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback), loop).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback))
|
||||
|
||||
tool_results_for_cli: list[dict[str, Any]] = []
|
||||
for i, (name, call_id, out, _) in enumerate(results):
|
||||
# Check if this is the last tool to trigger file refresh
|
||||
if i == len(results) - 1:
|
||||
results_iter = asyncio.run(_executor(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback))
|
||||
for i, (name, call_id, out, _) in enumerate(results_iter):
|
||||
if i == len(results_iter) - 1:
|
||||
if file_items:
|
||||
file_items, changed = _reread_file_items(file_items)
|
||||
ctx = _build_file_diff_text(changed)
|
||||
@@ -1837,21 +1840,23 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
out += f"\n\n{_get_context_marker()}\n\n{ctx}"
|
||||
if r_idx == MAX_TOOL_ROUNDS:
|
||||
out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
|
||||
|
||||
out = _truncate_tool_output(out)
|
||||
_cumulative_tool_bytes += len(out)
|
||||
tool_results_for_cli.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
"name": name,
|
||||
"content": out
|
||||
})
|
||||
cumulative_tool_bytes += len(out)
|
||||
tool_results_for_cli.append({"role": "tool", "tool_call_id": call_id, "name": name, "content": out})
|
||||
_append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
|
||||
events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
|
||||
|
||||
payload = tool_results_for_cli
|
||||
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
|
||||
_append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
|
||||
if cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
|
||||
_append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {cumulative_tool_bytes} bytes]"})
|
||||
return calls
|
||||
|
||||
run_with_tool_loop(
|
||||
client=adapter, request=lambda _i: cast(OpenAICompatibleRequest, None),
|
||||
base_dir=base_dir, vendor_name="gemini_cli",
|
||||
pre_tool_callback=pre_tool_callback, qa_callback=qa_callback,
|
||||
stream_callback=stream_callback, patch_callback=patch_callback,
|
||||
send_func=_send, on_pre_dispatch=_pre_dispatch,
|
||||
)
|
||||
final_text = all_text[-1] if all_text else "(No text returned)"
|
||||
return final_text
|
||||
except Exception as e:
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Verify run_with_tool_loop supports a custom send_func for vendors
|
||||
that don't use send_openai_compatible (gemini_cli, gemini, anthropic,
|
||||
deepseek). The vendor provides a send_func that returns a
|
||||
NormalizedResponse, and the helper handles history + dispatch.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from src.openai_compatible import NormalizedResponse
|
||||
from src.ai_client import run_with_tool_loop
|
||||
from src.vendor_capabilities import VendorCapabilities
|
||||
|
||||
def _make_normalized_response(text: str = "ok", tool_calls: list[dict[str, Any]] | None = None) -> NormalizedResponse:
|
||||
return NormalizedResponse(
|
||||
text=text, tool_calls=tool_calls or [],
|
||||
usage_input_tokens=10, usage_output_tokens=5,
|
||||
usage_cache_read_tokens=0, usage_cache_creation_tokens=0,
|
||||
raw_response=None,
|
||||
)
|
||||
|
||||
def test_run_with_tool_loop_uses_send_func_when_provided() -> None:
|
||||
client = MagicMock()
|
||||
def send_func(_round_idx: int) -> NormalizedResponse:
|
||||
return _make_normalized_response(f"from-send-func-{_round_idx}")
|
||||
result = run_with_tool_loop(
|
||||
client, request=lambda _i: MagicMock(), # should be IGNORED
|
||||
base_dir=".", vendor_name="custom",
|
||||
send_func=send_func,
|
||||
)
|
||||
assert result == "from-send-func-0"
|
||||
|
||||
def test_run_with_tool_loop_dispatches_via_send_func() -> None:
|
||||
client = MagicMock()
|
||||
tool_resp = _make_normalized_response(
|
||||
"first", tool_calls=[{"id": "c1", "type": "function", "function": {"name": "t", "arguments": "{}"}}]
|
||||
)
|
||||
final = _make_normalized_response("done")
|
||||
def send_func(round_idx: int) -> NormalizedResponse:
|
||||
return [tool_resp, final][round_idx]
|
||||
with patch("src.ai_client._execute_tool_calls_concurrently", return_value=[("t", "c1", "r", "")]) as dispatch:
|
||||
result = run_with_tool_loop(
|
||||
client, request=lambda _i: MagicMock(),
|
||||
base_dir=".", vendor_name="custom",
|
||||
send_func=send_func,
|
||||
)
|
||||
assert result == "done"
|
||||
assert dispatch.call_count == 1
|
||||
Reference in New Issue
Block a user