feat(ai_client): Emit API lifecycle and tool execution events
This commit is contained in:
12
ai_client.py
12
ai_client.py
@@ -19,6 +19,7 @@ from pathlib import Path
|
|||||||
import file_cache
|
import file_cache
|
||||||
import mcp_client
|
import mcp_client
|
||||||
import google.genai
|
import google.genai
|
||||||
|
from google.genai import types
|
||||||
from events import EventEmitter
|
from events import EventEmitter
|
||||||
|
|
||||||
_provider: str = "gemini"
|
_provider: str = "gemini"
|
||||||
@@ -620,6 +621,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, file_items:
|
|||||||
r["output"] = val
|
r["output"] = val
|
||||||
|
|
||||||
for r_idx in range(MAX_TOOL_ROUNDS + 2):
|
for r_idx in range(MAX_TOOL_ROUNDS + 2):
|
||||||
|
events.emit("request_start", payload={"provider": "gemini", "model": _model, "round": r_idx})
|
||||||
resp = _gemini_chat.send_message(payload)
|
resp = _gemini_chat.send_message(payload)
|
||||||
txt = "\n".join(p.text for c in resp.candidates if getattr(c, "content", None) for p in c.content.parts if hasattr(p, "text") and p.text)
|
txt = "\n".join(p.text for c in resp.candidates if getattr(c, "content", None) for p in c.content.parts if hasattr(p, "text") and p.text)
|
||||||
if txt: all_text.append(txt)
|
if txt: all_text.append(txt)
|
||||||
@@ -629,6 +631,8 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, file_items:
|
|||||||
cached_tokens = getattr(resp.usage_metadata, "cached_content_token_count", None)
|
cached_tokens = getattr(resp.usage_metadata, "cached_content_token_count", None)
|
||||||
if cached_tokens:
|
if cached_tokens:
|
||||||
usage["cache_read_input_tokens"] = cached_tokens
|
usage["cache_read_input_tokens"] = cached_tokens
|
||||||
|
|
||||||
|
events.emit("response_received", payload={"provider": "gemini", "model": _model, "usage": usage, "round": r_idx})
|
||||||
reason = resp.candidates[0].finish_reason.name if resp.candidates and hasattr(resp.candidates[0], "finish_reason") else "STOP"
|
reason = resp.candidates[0].finish_reason.name if resp.candidates and hasattr(resp.candidates[0], "finish_reason") else "STOP"
|
||||||
|
|
||||||
_append_comms("IN", "response", {"round": r_idx, "stop_reason": reason, "text": txt, "tool_calls": [{"name": c.name, "args": dict(c.args)} for c in calls], "usage": usage})
|
_append_comms("IN", "response", {"round": r_idx, "stop_reason": reason, "text": txt, "tool_calls": [{"name": c.name, "args": dict(c.args)} for c in calls], "usage": usage})
|
||||||
@@ -662,6 +666,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, file_items:
|
|||||||
f_resps, log = [], []
|
f_resps, log = [], []
|
||||||
for i, fc in enumerate(calls):
|
for i, fc in enumerate(calls):
|
||||||
name, args = fc.name, dict(fc.args)
|
name, args = fc.name, dict(fc.args)
|
||||||
|
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
|
||||||
if name in mcp_client.TOOL_NAMES:
|
if name in mcp_client.TOOL_NAMES:
|
||||||
_append_comms("OUT", "tool_call", {"name": name, "args": args})
|
_append_comms("OUT", "tool_call", {"name": name, "args": args})
|
||||||
out = mcp_client.dispatch(name, args)
|
out = mcp_client.dispatch(name, args)
|
||||||
@@ -681,6 +686,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, file_items:
|
|||||||
|
|
||||||
f_resps.append(types.Part.from_function_response(name=name, response={"output": out}))
|
f_resps.append(types.Part.from_function_response(name=name, response={"output": out}))
|
||||||
log.append({"tool_use_id": name, "content": out})
|
log.append({"tool_use_id": name, "content": out})
|
||||||
|
events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
|
||||||
|
|
||||||
_append_comms("OUT", "tool_result_send", {"results": log})
|
_append_comms("OUT", "tool_result_send", {"results": log})
|
||||||
payload = f_resps
|
payload = f_resps
|
||||||
@@ -998,6 +1004,7 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
|
|||||||
def _strip_private_keys(history):
|
def _strip_private_keys(history):
|
||||||
return [{k: v for k, v in m.items() if not k.startswith("_")} for m in history]
|
return [{k: v for k, v in m.items() if not k.startswith("_")} for m in history]
|
||||||
|
|
||||||
|
events.emit("request_start", payload={"provider": "anthropic", "model": _model, "round": round_idx})
|
||||||
response = _anthropic_client.messages.create(
|
response = _anthropic_client.messages.create(
|
||||||
model=_model,
|
model=_model,
|
||||||
max_tokens=_max_tokens,
|
max_tokens=_max_tokens,
|
||||||
@@ -1036,6 +1043,8 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
|
|||||||
if cache_read is not None:
|
if cache_read is not None:
|
||||||
usage_dict["cache_read_input_tokens"] = cache_read
|
usage_dict["cache_read_input_tokens"] = cache_read
|
||||||
|
|
||||||
|
events.emit("response_received", payload={"provider": "anthropic", "model": _model, "usage": usage_dict, "round": round_idx})
|
||||||
|
|
||||||
_append_comms("IN", "response", {
|
_append_comms("IN", "response", {
|
||||||
"round": round_idx,
|
"round": round_idx,
|
||||||
"stop_reason": response.stop_reason,
|
"stop_reason": response.stop_reason,
|
||||||
@@ -1059,6 +1068,7 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
|
|||||||
b_name = getattr(block, "name", None)
|
b_name = getattr(block, "name", None)
|
||||||
b_id = getattr(block, "id", "")
|
b_id = getattr(block, "id", "")
|
||||||
b_input = getattr(block, "input", {})
|
b_input = getattr(block, "input", {})
|
||||||
|
events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx})
|
||||||
if b_name in mcp_client.TOOL_NAMES:
|
if b_name in mcp_client.TOOL_NAMES:
|
||||||
_append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input})
|
_append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input})
|
||||||
output = mcp_client.dispatch(b_name, b_input)
|
output = mcp_client.dispatch(b_name, b_input)
|
||||||
@@ -1068,6 +1078,7 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
|
|||||||
"tool_use_id": b_id,
|
"tool_use_id": b_id,
|
||||||
"content": output,
|
"content": output,
|
||||||
})
|
})
|
||||||
|
events.emit("tool_execution", payload={"status": "completed", "tool": b_name, "result": output, "round": round_idx})
|
||||||
elif b_name == TOOL_NAME:
|
elif b_name == TOOL_NAME:
|
||||||
script = b_input.get("script", "")
|
script = b_input.get("script", "")
|
||||||
_append_comms("OUT", "tool_call", {
|
_append_comms("OUT", "tool_call", {
|
||||||
@@ -1086,6 +1097,7 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
|
|||||||
"tool_use_id": b_id,
|
"tool_use_id": b_id,
|
||||||
"content": output,
|
"content": output,
|
||||||
})
|
})
|
||||||
|
events.emit("tool_execution", payload={"status": "completed", "tool": b_name, "result": output, "round": round_idx})
|
||||||
|
|
||||||
# Refresh file context after tool calls — only inject CHANGED files
|
# Refresh file context after tool calls — only inject CHANGED files
|
||||||
if file_items:
|
if file_items:
|
||||||
|
|||||||
@@ -18,3 +18,97 @@ def test_event_emission():
|
|||||||
ai_client.events.emit("request_start", payload={"model": "test"})
|
ai_client.events.emit("request_start", payload={"model": "test"})
|
||||||
|
|
||||||
mock_callback.assert_called_once_with(payload={"model": "test"})
|
mock_callback.assert_called_once_with(payload={"model": "test"})
|
||||||
|
|
||||||
|
def test_send_emits_events():
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
# We need to mock _ensure_gemini_client and the chat object it creates
|
||||||
|
with patch("ai_client._ensure_gemini_client"), \
|
||||||
|
patch("ai_client._gemini_client") as mock_client, \
|
||||||
|
patch("ai_client._gemini_chat") as mock_chat:
|
||||||
|
|
||||||
|
# Setup mock response
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.candidates = []
|
||||||
|
# Explicitly set usage_metadata as a mock with integer values
|
||||||
|
mock_usage = MagicMock()
|
||||||
|
mock_usage.prompt_token_count = 10
|
||||||
|
mock_usage.candidates_token_count = 5
|
||||||
|
mock_usage.cached_content_token_count = None
|
||||||
|
mock_response.usage_metadata = mock_usage
|
||||||
|
mock_chat.send_message.return_value = mock_response
|
||||||
|
mock_client.chats.create.return_value = mock_chat
|
||||||
|
|
||||||
|
ai_client.set_provider("gemini", "gemini-flash")
|
||||||
|
|
||||||
|
start_callback = MagicMock()
|
||||||
|
response_callback = MagicMock()
|
||||||
|
|
||||||
|
ai_client.events.on("request_start", start_callback)
|
||||||
|
ai_client.events.on("response_received", response_callback)
|
||||||
|
|
||||||
|
# We need to bypass the context changed check or set it up
|
||||||
|
ai_client.send("context", "message")
|
||||||
|
|
||||||
|
assert start_callback.called
|
||||||
|
assert response_callback.called
|
||||||
|
|
||||||
|
# Check payload
|
||||||
|
args, kwargs = start_callback.call_args
|
||||||
|
assert kwargs['payload']['provider'] == 'gemini'
|
||||||
|
|
||||||
|
def test_send_emits_tool_events():
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
with patch("ai_client._ensure_gemini_client"), \
|
||||||
|
patch("ai_client._gemini_client") as mock_client, \
|
||||||
|
patch("ai_client._gemini_chat") as mock_chat, \
|
||||||
|
patch("mcp_client.dispatch") as mock_dispatch:
|
||||||
|
|
||||||
|
# 1. Setup mock response with a tool call
|
||||||
|
mock_fc = MagicMock()
|
||||||
|
mock_fc.name = "read_file"
|
||||||
|
mock_fc.args = {"path": "test.txt"}
|
||||||
|
|
||||||
|
mock_response_with_tool = MagicMock()
|
||||||
|
mock_response_with_tool.candidates = [MagicMock()]
|
||||||
|
mock_part = MagicMock()
|
||||||
|
mock_part.text = "tool call text"
|
||||||
|
mock_part.function_call = mock_fc
|
||||||
|
mock_response_with_tool.candidates[0].content.parts = [mock_part]
|
||||||
|
mock_response_with_tool.candidates[0].finish_reason.name = "STOP"
|
||||||
|
|
||||||
|
# Setup mock usage
|
||||||
|
mock_usage = MagicMock()
|
||||||
|
mock_usage.prompt_token_count = 10
|
||||||
|
mock_usage.candidates_token_count = 5
|
||||||
|
mock_usage.cached_content_token_count = None
|
||||||
|
mock_response_with_tool.usage_metadata = mock_usage
|
||||||
|
|
||||||
|
# 2. Setup second mock response (final answer)
|
||||||
|
mock_response_final = MagicMock()
|
||||||
|
mock_response_final.candidates = []
|
||||||
|
mock_response_final.usage_metadata = mock_usage
|
||||||
|
|
||||||
|
mock_chat.send_message.side_effect = [mock_response_with_tool, mock_response_final]
|
||||||
|
mock_dispatch.return_value = "file content"
|
||||||
|
|
||||||
|
ai_client.set_provider("gemini", "gemini-flash")
|
||||||
|
|
||||||
|
tool_callback = MagicMock()
|
||||||
|
ai_client.events.on("tool_execution", tool_callback)
|
||||||
|
|
||||||
|
ai_client.send("context", "message")
|
||||||
|
|
||||||
|
# Should be called twice: once for 'started', once for 'completed'
|
||||||
|
assert tool_callback.call_count == 2
|
||||||
|
|
||||||
|
# Check 'started' call
|
||||||
|
args, kwargs = tool_callback.call_args_list[0]
|
||||||
|
assert kwargs['payload']['status'] == 'started'
|
||||||
|
assert kwargs['payload']['tool'] == 'read_file'
|
||||||
|
|
||||||
|
# Check 'completed' call
|
||||||
|
args, kwargs = tool_callback.call_args_list[1]
|
||||||
|
assert kwargs['payload']['status'] == 'completed'
|
||||||
|
assert kwargs['payload']['result'] == 'file content'
|
||||||
|
|||||||
Reference in New Issue
Block a user