diff --git a/tests/test_agent_capabilities.py b/tests/test_agent_capabilities.py index 0b6a098..c9e40a3 100644 --- a/tests/test_agent_capabilities.py +++ b/tests/test_agent_capabilities.py @@ -1,10 +1,43 @@ import pytest import sys import os +from unittest.mock import patch, MagicMock # Ensure project root is in path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +import ai_client def test_agent_capabilities_listing() -> None: - pytest.fail("TODO: Implement assertions") + # Mock credentials + with patch("ai_client._load_credentials") as mock_creds: + mock_creds.return_value = {"gemini": {"api_key": "fake-key"}} + + # Mock the google-genai Client and models.list + with patch("google.genai.Client") as mock_client_class: + mock_client = mock_client_class.return_value + + # Create a list of mock models + mock_models = [] + model_names = [ + "models/gemini-2.0-flash", + "models/gemini-2.0-flash-lite", + "models/gemini-1.5-pro", + "models/gemini-1.5-flash", + "models/gemini-pro", + "models/gemini-ultra" + ] + for name in model_names: + m = MagicMock() + m.name = name + mock_models.append(m) + + mock_client.models.list.return_value = mock_models + + models = ai_client.list_models("gemini") + + assert len(models) >= 6 + assert "gemini-2.0-flash" in models + # Check that it stripped "models/" prefix + for m in models: + assert not m.startswith("models/") diff --git a/tests/test_token_usage.py b/tests/test_token_usage.py index 70717c9..7d8fc00 100644 --- a/tests/test_token_usage.py +++ b/tests/test_token_usage.py @@ -1,6 +1,9 @@ import pytest import sys import os +import hashlib +from unittest.mock import patch, MagicMock +from types import SimpleNamespace # Ensure project root is in path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -9,6 +12,54 @@ import ai_client def test_token_usage_tracking() -> None: ai_client.reset_session() - # Mock an API response with token usage - # This would test the internal accumulator in ai_client - pytest.fail("TODO: Implement assertions") + ai_client.clear_comms_log() + ai_client.set_provider("gemini", "gemini-2.0-flash") + + # Mock credentials so we don't need a real file + with patch("ai_client._load_credentials") as mock_creds: + mock_creds.return_value = {"gemini": {"api_key": "fake-key"}} + + mock_resp = MagicMock() + # Use SimpleNamespace to ensure attributes are real values, not Mocks + mock_resp.usage_metadata = SimpleNamespace( + prompt_token_count=100, + candidates_token_count=50, + cached_content_token_count=20 + ) + + # Setup candidates + mock_candidate = MagicMock() + # Use spec to ensure hasattr(p, "function_call") is False for text parts + mock_part = MagicMock(spec=["text"]) + mock_part.text = "Hello" + mock_candidate.content.parts = [mock_part] + mock_candidate.finish_reason.name = "STOP" + mock_resp.candidates = [mock_candidate] + + mock_chat = MagicMock() + mock_chat.send_message.return_value = mock_resp + mock_chat._history = [] + + # Mock the client creation and storage + with patch("google.genai.Client") as mock_client_class: + mock_client_instance = mock_client_class.return_value + # Mock count_tokens to avoid call during send_gemini + mock_client_instance.models.count_tokens.return_value = MagicMock(total_tokens=100) + # Mock chats.create to return our mock_chat + mock_client_instance.chats.create.return_value = mock_chat + + ai_client._gemini_client = mock_client_instance + ai_client._gemini_chat = mock_chat + # Set the hash to prevent chat reset + ai_client._gemini_cache_md_hash = hashlib.md5("context".encode()).hexdigest() + + ai_client.send("context", "hi", enable_tools=False) + + log = ai_client.get_comms_log() + # The log might have 'request' and 'response' entries + response_entries = [e for e in log if e["kind"] == "response"] + assert len(response_entries) > 0 + usage = response_entries[0]["payload"]["usage"] + assert usage["input_tokens"] == 100 + assert usage["output_tokens"] == 50 + assert usage["cache_read_input_tokens"] == 20