import pytest
import sys
import os
import hashlib
from unittest.mock import patch, MagicMock
from types import SimpleNamespace

# Put the parent directory of this file on sys.path so that `ai_client`
# can be imported below.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import ai_client


def _build_fake_gemini_response() -> MagicMock:
    """Return a mocked Gemini response carrying fixed token-usage numbers.

    The response has a single candidate with one text part ("Hello") and a
    finish reason named "STOP".
    """
    # spec=["text"] restricts the part to a `text` attribute only, so
    # hasattr(part, "function_call") evaluates to False for this text part.
    text_part = MagicMock(spec=["text"])
    text_part.text = "Hello"

    candidate = MagicMock()
    candidate.content.parts = [text_part]
    candidate.finish_reason.name = "STOP"

    response = MagicMock()
    # SimpleNamespace keeps the counters as real ints rather than auto-created
    # Mock attributes.
    response.usage_metadata = SimpleNamespace(
        prompt_token_count=100,
        candidates_token_count=50,
        cached_content_token_count=20,
    )
    response.candidates = [candidate]
    return response


def test_token_usage_tracking() -> None:
    """Check that Gemini usage metadata is recorded in the comms log.

    A mocked chat returns a response with known prompt/candidate/cached token
    counts; the first "response" entry of the comms log must report them as
    input_tokens, output_tokens and cache_read_input_tokens.
    """
    ai_client.reset_session()
    ai_client.clear_comms_log()
    ai_client.set_provider("gemini", "gemini-2.0-flash")

    context_text = "context"

    fake_chat = MagicMock()
    fake_chat.send_message.return_value = _build_fake_gemini_response()
    fake_chat._history = []

    # Credentials are mocked so no real credentials file is needed, and the
    # genai Client class is replaced so no real client is constructed.
    with patch("ai_client._load_credentials") as fake_creds, \
            patch("google.genai.Client") as fake_client_cls:
        fake_creds.return_value = {"gemini": {"api_key": "fake-key"}}

        fake_client = fake_client_cls.return_value
        # Give count_tokens a canned result so it does not need a real
        # backend if it is invoked while sending.
        fake_client.models.count_tokens.return_value = MagicMock(total_tokens=100)
        # Any chat created through the client is our mocked chat.
        fake_client.chats.create.return_value = fake_chat

        # Inject the mocks directly into the module-level Gemini state.
        ai_client._gemini_client = fake_client
        ai_client._gemini_chat = fake_chat
        # Pre-seed the stored hash with the MD5 of the context sent below;
        # this is intended to keep the injected chat from being reset.
        ai_client._gemini_cache_md_hash = hashlib.md5(context_text.encode()).hexdigest()

        ai_client.send(context_text, "hi", enable_tools=False)

    # The log may hold both 'request' and 'response' entries; keep responses.
    responses = [entry for entry in ai_client.get_comms_log() if entry["kind"] == "response"]
    assert len(responses) > 0

    usage = responses[0]["payload"]["usage"]
    expected_usage = {
        "input_tokens": 100,
        "output_tokens": 50,
        "cache_read_input_tokens": 20,
    }
    for field, expected in expected_usage.items():
        assert usage[field] == expected