diff --git a/ai_client.py b/ai_client.py index 3de4e62..1581083 100644 --- a/ai_client.py +++ b/ai_client.py @@ -18,6 +18,7 @@ import datetime from pathlib import Path import file_cache import mcp_client +import google.genai _provider: str = "gemini" _model: str = "gemini-2.5-flash" @@ -241,6 +242,22 @@ def reset_session(): _CACHED_ANTHROPIC_TOOLS = None file_cache.reset_client() +def get_gemini_cache_stats() -> dict: + """ + Retrieves statistics about the Gemini caches, such as count and total size. + """ + _ensure_gemini_client() + + + caches_iterator = _gemini_client.caches.list() + caches = list(caches_iterator) + + total_size_bytes = sum(c.size_bytes for c in caches) + + return { + "cache_count": len(list(caches)), + "total_size_bytes": total_size_bytes, + } # ------------------------------------------------------------------ model listing @@ -254,9 +271,9 @@ def list_models(provider: str) -> list[str]: def _list_gemini_models(api_key: str) -> list[str]: - from google import genai + # from google import genai # Removed try: - client = genai.Client(api_key=api_key) + client = google.genai.Client(api_key=api_key) models = [] for m in client.models.list(): name = m.name @@ -348,7 +365,7 @@ def _get_anthropic_tools() -> list[dict]: def _gemini_tool_declaration(): - from google.genai import types + # from google.genai import types # Removed declarations = [] @@ -358,15 +375,15 @@ def _gemini_tool_declaration(): continue props = {} for pname, pdef in spec["parameters"].get("properties", {}).items(): - props[pname] = types.Schema( - type=types.Type.STRING, + props[pname] = google.genai.types.Schema( + type=google.genai.types.Type.STRING, description=pdef.get("description", ""), ) - declarations.append(types.FunctionDeclaration( + declarations.append(google.genai.types.FunctionDeclaration( name=spec["name"], description=spec["description"], - parameters=types.Schema( - type=types.Type.OBJECT, + parameters=google.genai.types.Schema( + type=google.genai.types.Type.OBJECT, properties=props, required=spec["parameters"].get("required", []), ), @@ -374,7 +391,7 @@ def _gemini_tool_declaration(): # PowerShell tool if _agent_tools.get(TOOL_NAME, True): - declarations.append(types.FunctionDeclaration( + declarations.append(google.genai.types.FunctionDeclaration( name=TOOL_NAME, description=( "Run a PowerShell script within the project base_dir. " @@ -382,11 +399,11 @@ def _gemini_tool_declaration(): "The working directory is set to base_dir automatically. " "stdout and stderr are returned to you as the result." ), - parameters=types.Schema( - type=types.Type.OBJECT, + parameters=google.genai.types.Schema( + type=google.genai.types.Type.OBJECT, properties={ - "script": types.Schema( - type=types.Type.STRING, + "script": google.genai.types.Schema( + type=google.genai.types.Type.STRING, description="The PowerShell script to execute." ) }, @@ -394,7 +411,7 @@ def _gemini_tool_declaration(): ), )) - return types.Tool(function_declarations=declarations) if declarations else None + return google.genai.types.Tool(function_declarations=declarations) if declarations else None def _run_script(script: str, base_dir: str) -> str: @@ -489,9 +506,9 @@ def _content_block_to_dict(block) -> dict: def _ensure_gemini_client(): global _gemini_client if _gemini_client is None: - from google import genai + # from google import genai # Removed creds = _load_credentials() - _gemini_client = genai.Client(api_key=creds["gemini"]["api_key"]) + _gemini_client = google.genai.Client(api_key=creds["gemini"]["api_key"]) @@ -508,7 +525,7 @@ def _get_gemini_history_list(chat): def _send_gemini(md_content: str, user_message: str, base_dir: str, file_items: list[dict] | None = None) -> str: global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at - from google.genai import types + # from google.genai import types # Removed try: _ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir]) sys_instr = f"{_get_combined_system_prompt()}\n\n\n{md_content}\n" @@ -541,29 +558,29 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, file_items: _append_comms("OUT", "request", {"message": f"[CACHE TTL] Rebuilding cache (expired after {int(elapsed)}s)..."}) if not _gemini_chat: - chat_config = types.GenerateContentConfig( + chat_config = google.genai.types.GenerateContentConfig( system_instruction=sys_instr, tools=tools_decl, temperature=_temperature, max_output_tokens=_max_tokens, - safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")] + safety_settings=[google.genai.types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")] ) try: # Gemini requires 1024 (Flash) or 4096 (Pro) tokens to cache. _gemini_cache = _gemini_client.caches.create( model=_model, - config=types.CreateCachedContentConfig( + config=google.genai.types.CreateCachedContentConfig( system_instruction=sys_instr, tools=tools_decl, ttl=f"{_GEMINI_CACHE_TTL}s", ) ) _gemini_cache_created_at = time.time() - chat_config = types.GenerateContentConfig( + chat_config = google.genai.types.GenerateContentConfig( cached_content=_gemini_cache.name, temperature=_temperature, max_output_tokens=_max_tokens, - safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")] + safety_settings=[google.genai.types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")] ) _append_comms("OUT", "request", {"message": f"[CACHE CREATED] {_gemini_cache.name}"}) except Exception as e: diff --git a/tests/test_gemini_metrics.py b/tests/test_gemini_metrics.py new file mode 100644 index 0000000..409096c --- /dev/null +++ b/tests/test_gemini_metrics.py @@ -0,0 +1,45 @@ +import pytest +from unittest.mock import MagicMock, patch + +# Import the necessary functions from ai_client, including the reset helper +from ai_client import get_gemini_cache_stats, reset_session + +def test_get_gemini_cache_stats_with_mock_client(): + """ + Test that get_gemini_cache_stats correctly processes cache lists + from a mocked client instance. + """ + # Ensure a clean state before the test by resetting the session + reset_session() + + # 1. Create a mock for the cache object that the client will return + mock_cache = MagicMock() + mock_cache.name = "cachedContents/test-cache" + mock_cache.display_name = "Test Cache" + mock_cache.model = "models/gemini-1.5-pro-001" + mock_cache.size_bytes = 1024 + + # 2. Create a mock for the client instance + mock_client_instance = MagicMock() + # Configure its `caches.list` method to return our mock cache + mock_client_instance.caches.list.return_value = [mock_cache] + + # 3. Patch the Client constructor to return our mock instance + # This intercepts the `_ensure_gemini_client` call inside the function + with patch('google.genai.Client', return_value=mock_client_instance) as mock_client_constructor: + + # 4. Call the function under test + stats = get_gemini_cache_stats() + + # 5. Assert that the function behaved as expected + + # It should have constructed the client + mock_client_constructor.assert_called_once() + # It should have called the `list` method on the `caches` attribute + mock_client_instance.caches.list.assert_called_once() + + # The returned stats dictionary should be correct + assert "cache_count" in stats + assert "total_size_bytes" in stats + assert stats["cache_count"] == 1 + assert stats["total_size_bytes"] == 1024