diff --git a/src/ai_client.py b/src/ai_client.py index 0e68150d..9d6f1d0b 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -40,6 +40,7 @@ from src import project_manager from src import file_cache from src import mcp_client from src import mcp_tool_specs +from src.openai_schemas import UsageStats from src import mma_prompts from src import performance_monitor from src import project_manager @@ -2051,7 +2052,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, def _send(r_idx: int) -> NormalizedResponse: if adapter is None: - return NormalizedResponse(text="(adapter unavailable)", tool_calls=[], usage_input_tokens=0, usage_output_tokens=0, usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=None) + return NormalizedResponse(text="(adapter unavailable)", tool_calls=(), usage=UsageStats(input_tokens=0, output_tokens=0), raw_response=None) send_result = _send_cli_round_result(r_idx, adapter, payload, safety_settings, sys_instr, stream_callback) if not send_result.ok: raise cast(Exception, send_result.errors[0].original) from None @@ -2085,7 +2086,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, "kind": "history_add", "payload": {"role": "AI", "content": txt} }) - return NormalizedResponse(text=txt, tool_calls=calls, usage_input_tokens=usage.get("prompt_tokens", 0), usage_output_tokens=usage.get("completion_tokens", 0), usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=resp_data) + return NormalizedResponse(text=txt, tool_calls=(), usage=UsageStats(input_tokens=usage.get("prompt_tokens", 0), output_tokens=usage.get("completion_tokens", 0)), raw_response=resp_data) def _pre_dispatch(r_idx: int, calls: list[Metadata]) -> list[Metadata]: nonlocal payload, cumulative_tool_bytes, file_items diff --git a/src/openai_compatible.py b/src/openai_compatible.py index d86246dd..3ab8a1fd 100644 --- a/src/openai_compatible.py +++ b/src/openai_compatible.py @@ -83,7 +83,7 @@ def send_openai_compatible( *, capabilities: Any, ) -> Result[NormalizedResponse]: - messages_dicts = [m.to_dict() for m in request.messages] + messages_dicts = [m.to_dict() if hasattr(m, "to_dict") else m for m in request.messages] kwargs: dict[str, Any] = { "model": request.model, "messages": messages_dicts, diff --git a/tests/test_ai_client_tool_loop.py b/tests/test_ai_client_tool_loop.py index eb576dc6..d09fa7f1 100644 --- a/tests/test_ai_client_tool_loop.py +++ b/tests/test_ai_client_tool_loop.py @@ -26,10 +26,10 @@ def caps() -> VendorCapabilities: return VendorCapabilities(vendor="test", model="test-model", tool_calling=True, context_window=8192) def _make_normalized_response(text: str = "ok", tool_calls: list[dict[str, Any]] | None = None) -> Result[NormalizedResponse]: + from src.openai_schemas import UsageStats return Result(data=NormalizedResponse( - text=text, tool_calls=tool_calls or [], - usage_input_tokens=10, usage_output_tokens=5, - usage_cache_read_tokens=0, usage_cache_creation_tokens=0, + text=text, tool_calls=tool_calls or (), + usage=UsageStats(input_tokens=10, output_tokens=5), raw_response=None, )) diff --git a/tests/test_ai_client_tool_loop_builder.py b/tests/test_ai_client_tool_loop_builder.py index e7fae125..05ed1cc8 100644 --- a/tests/test_ai_client_tool_loop_builder.py +++ b/tests/test_ai_client_tool_loop_builder.py @@ -13,10 +13,10 @@ from src.result_types import Result from src.vendor_capabilities import VendorCapabilities def _make_normalized_response(text: str = "ok", tool_calls: list[dict[str, Any]] | None = None) -> NormalizedResponse: + from src.openai_schemas import UsageStats return NormalizedResponse( - text=text, tool_calls=tool_calls or [], - usage_input_tokens=10, usage_output_tokens=5, - usage_cache_read_tokens=0, usage_cache_creation_tokens=0, + text=text, tool_calls=tool_calls or (), + usage=UsageStats(input_tokens=10, output_tokens=5), raw_response=None, ) diff --git a/tests/test_ai_client_tool_loop_send_func.py b/tests/test_ai_client_tool_loop_send_func.py index d46501f9..904124ec 100644 --- a/tests/test_ai_client_tool_loop_send_func.py +++ b/tests/test_ai_client_tool_loop_send_func.py @@ -11,10 +11,10 @@ from src.ai_client import run_with_tool_loop from src.vendor_capabilities import VendorCapabilities def _make_normalized_response(text: str = "ok", tool_calls: list[dict[str, Any]] | None = None) -> NormalizedResponse: + from src.openai_schemas import UsageStats return NormalizedResponse( - text=text, tool_calls=tool_calls or [], - usage_input_tokens=10, usage_output_tokens=5, - usage_cache_read_tokens=0, usage_cache_creation_tokens=0, + text=text, tool_calls=tool_calls or (), + usage=UsageStats(input_tokens=10, output_tokens=5), raw_response=None, ) diff --git a/tests/test_auto_whitelist.py b/tests/test_auto_whitelist.py index 5ad2c77d..20921535 100644 --- a/tests/test_auto_whitelist.py +++ b/tests/test_auto_whitelist.py @@ -17,7 +17,9 @@ def test_auto_whitelist_keywords(registry_setup: LogRegistry) -> None: reg.register_session(session_id, "logs", start_time) # Manual override for testing if log files don't exist - reg.data[session_id]["whitelisted"] = True + reg.update_session_metadata( + session_id, message_count=0, errors=0, size_kb=0, whitelisted=True, reason="manual override", + ) assert reg.is_session_whitelisted(session_id) is True def test_auto_whitelist_message_count(registry_setup: LogRegistry) -> None: diff --git a/tests/test_openai_compatible.py b/tests/test_openai_compatible.py index ff1dcaeb..0bf44625 100644 --- a/tests/test_openai_compatible.py +++ b/tests/test_openai_compatible.py @@ -5,6 +5,7 @@ from src.openai_compatible import ( OpenAICompatibleRequest, send_openai_compatible, ) +from src.openai_schemas import UsageStats from src.vendor_capabilities import VendorCapabilities, register @pytest.fixture @@ -58,8 +59,8 @@ def test_tool_call_detection_in_blocking_response(caps: VendorCapabilities) -> N kwargs = {"model": "m", "messages": [{"role": "user", "content": "ping"}], "temperature": 0.0, "top_p": 1.0, "max_tokens": 8192, "stream": False} response = _send_blocking(client, kwargs) assert len(response.tool_calls) == 1 - assert response.tool_calls[0]["function"]["name"] == "read_file" - assert response.tool_calls[0]["id"] == "call_1" + assert response.tool_calls[0].function.name == "read_file" + assert response.tool_calls[0].id == "call_1" def test_vision_multimodal_message(caps: VendorCapabilities) -> None: client = MagicMock() @@ -84,6 +85,6 @@ def test_error_classification_429_to_rate_limit(caps: VendorCapabilities) -> Non def test_normalized_response_is_frozen_dataclass() -> None: from dataclasses import FrozenInstanceError - r = NormalizedResponse(text="x", tool_calls=[], usage_input_tokens=0, usage_output_tokens=0, usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=None) + r = NormalizedResponse(text="x", tool_calls=(), usage=UsageStats(input_tokens=0, output_tokens=0), raw_response=None) with pytest.raises(FrozenInstanceError): r.text = "y"