fix(conductor): Implement Gemini cache metrics
This change corrects the implementation of get_gemini_cache_stats to use the Gemini client instance and updates the corresponding test to use proper mocking.
This commit is contained in:
61
ai_client.py
61
ai_client.py
@@ -18,6 +18,7 @@ import datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import file_cache
|
import file_cache
|
||||||
import mcp_client
|
import mcp_client
|
||||||
|
import google.genai
|
||||||
|
|
||||||
_provider: str = "gemini"
|
_provider: str = "gemini"
|
||||||
_model: str = "gemini-2.5-flash"
|
_model: str = "gemini-2.5-flash"
|
||||||
@@ -241,6 +242,22 @@ def reset_session():
|
|||||||
_CACHED_ANTHROPIC_TOOLS = None
|
_CACHED_ANTHROPIC_TOOLS = None
|
||||||
file_cache.reset_client()
|
file_cache.reset_client()
|
||||||
|
|
||||||
|
def get_gemini_cache_stats() -> dict:
|
||||||
|
"""
|
||||||
|
Retrieves statistics about the Gemini caches, such as count and total size.
|
||||||
|
"""
|
||||||
|
_ensure_gemini_client()
|
||||||
|
|
||||||
|
|
||||||
|
caches_iterator = _gemini_client.caches.list()
|
||||||
|
caches = list(caches_iterator)
|
||||||
|
|
||||||
|
total_size_bytes = sum(c.size_bytes for c in caches)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"cache_count": len(list(caches)),
|
||||||
|
"total_size_bytes": total_size_bytes,
|
||||||
|
}
|
||||||
|
|
||||||
# ------------------------------------------------------------------ model listing
|
# ------------------------------------------------------------------ model listing
|
||||||
|
|
||||||
@@ -254,9 +271,9 @@ def list_models(provider: str) -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _list_gemini_models(api_key: str) -> list[str]:
|
def _list_gemini_models(api_key: str) -> list[str]:
|
||||||
from google import genai
|
# from google import genai # Removed
|
||||||
try:
|
try:
|
||||||
client = genai.Client(api_key=api_key)
|
client = google.genai.Client(api_key=api_key)
|
||||||
models = []
|
models = []
|
||||||
for m in client.models.list():
|
for m in client.models.list():
|
||||||
name = m.name
|
name = m.name
|
||||||
@@ -348,7 +365,7 @@ def _get_anthropic_tools() -> list[dict]:
|
|||||||
|
|
||||||
|
|
||||||
def _gemini_tool_declaration():
|
def _gemini_tool_declaration():
|
||||||
from google.genai import types
|
# from google.genai import types # Removed
|
||||||
|
|
||||||
declarations = []
|
declarations = []
|
||||||
|
|
||||||
@@ -358,15 +375,15 @@ def _gemini_tool_declaration():
|
|||||||
continue
|
continue
|
||||||
props = {}
|
props = {}
|
||||||
for pname, pdef in spec["parameters"].get("properties", {}).items():
|
for pname, pdef in spec["parameters"].get("properties", {}).items():
|
||||||
props[pname] = types.Schema(
|
props[pname] = google.genai.types.Schema(
|
||||||
type=types.Type.STRING,
|
type=google.genai.types.Type.STRING,
|
||||||
description=pdef.get("description", ""),
|
description=pdef.get("description", ""),
|
||||||
)
|
)
|
||||||
declarations.append(types.FunctionDeclaration(
|
declarations.append(google.genai.types.FunctionDeclaration(
|
||||||
name=spec["name"],
|
name=spec["name"],
|
||||||
description=spec["description"],
|
description=spec["description"],
|
||||||
parameters=types.Schema(
|
parameters=google.genai.types.Schema(
|
||||||
type=types.Type.OBJECT,
|
type=google.genai.types.Type.OBJECT,
|
||||||
properties=props,
|
properties=props,
|
||||||
required=spec["parameters"].get("required", []),
|
required=spec["parameters"].get("required", []),
|
||||||
),
|
),
|
||||||
@@ -374,7 +391,7 @@ def _gemini_tool_declaration():
|
|||||||
|
|
||||||
# PowerShell tool
|
# PowerShell tool
|
||||||
if _agent_tools.get(TOOL_NAME, True):
|
if _agent_tools.get(TOOL_NAME, True):
|
||||||
declarations.append(types.FunctionDeclaration(
|
declarations.append(google.genai.types.FunctionDeclaration(
|
||||||
name=TOOL_NAME,
|
name=TOOL_NAME,
|
||||||
description=(
|
description=(
|
||||||
"Run a PowerShell script within the project base_dir. "
|
"Run a PowerShell script within the project base_dir. "
|
||||||
@@ -382,11 +399,11 @@ def _gemini_tool_declaration():
|
|||||||
"The working directory is set to base_dir automatically. "
|
"The working directory is set to base_dir automatically. "
|
||||||
"stdout and stderr are returned to you as the result."
|
"stdout and stderr are returned to you as the result."
|
||||||
),
|
),
|
||||||
parameters=types.Schema(
|
parameters=google.genai.types.Schema(
|
||||||
type=types.Type.OBJECT,
|
type=google.genai.types.Type.OBJECT,
|
||||||
properties={
|
properties={
|
||||||
"script": types.Schema(
|
"script": google.genai.types.Schema(
|
||||||
type=types.Type.STRING,
|
type=google.genai.types.Type.STRING,
|
||||||
description="The PowerShell script to execute."
|
description="The PowerShell script to execute."
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
@@ -394,7 +411,7 @@ def _gemini_tool_declaration():
|
|||||||
),
|
),
|
||||||
))
|
))
|
||||||
|
|
||||||
return types.Tool(function_declarations=declarations) if declarations else None
|
return google.genai.types.Tool(function_declarations=declarations) if declarations else None
|
||||||
|
|
||||||
|
|
||||||
def _run_script(script: str, base_dir: str) -> str:
|
def _run_script(script: str, base_dir: str) -> str:
|
||||||
@@ -489,9 +506,9 @@ def _content_block_to_dict(block) -> dict:
|
|||||||
def _ensure_gemini_client():
|
def _ensure_gemini_client():
|
||||||
global _gemini_client
|
global _gemini_client
|
||||||
if _gemini_client is None:
|
if _gemini_client is None:
|
||||||
from google import genai
|
# from google import genai # Removed
|
||||||
creds = _load_credentials()
|
creds = _load_credentials()
|
||||||
_gemini_client = genai.Client(api_key=creds["gemini"]["api_key"])
|
_gemini_client = google.genai.Client(api_key=creds["gemini"]["api_key"])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -508,7 +525,7 @@ def _get_gemini_history_list(chat):
|
|||||||
|
|
||||||
def _send_gemini(md_content: str, user_message: str, base_dir: str, file_items: list[dict] | None = None) -> str:
|
def _send_gemini(md_content: str, user_message: str, base_dir: str, file_items: list[dict] | None = None) -> str:
|
||||||
global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at
|
global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at
|
||||||
from google.genai import types
|
# from google.genai import types # Removed
|
||||||
try:
|
try:
|
||||||
_ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir])
|
_ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir])
|
||||||
sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
|
sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
|
||||||
@@ -541,29 +558,29 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, file_items:
|
|||||||
_append_comms("OUT", "request", {"message": f"[CACHE TTL] Rebuilding cache (expired after {int(elapsed)}s)..."})
|
_append_comms("OUT", "request", {"message": f"[CACHE TTL] Rebuilding cache (expired after {int(elapsed)}s)..."})
|
||||||
|
|
||||||
if not _gemini_chat:
|
if not _gemini_chat:
|
||||||
chat_config = types.GenerateContentConfig(
|
chat_config = google.genai.types.GenerateContentConfig(
|
||||||
system_instruction=sys_instr,
|
system_instruction=sys_instr,
|
||||||
tools=tools_decl,
|
tools=tools_decl,
|
||||||
temperature=_temperature,
|
temperature=_temperature,
|
||||||
max_output_tokens=_max_tokens,
|
max_output_tokens=_max_tokens,
|
||||||
safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
|
safety_settings=[google.genai.types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# Gemini requires 1024 (Flash) or 4096 (Pro) tokens to cache.
|
# Gemini requires 1024 (Flash) or 4096 (Pro) tokens to cache.
|
||||||
_gemini_cache = _gemini_client.caches.create(
|
_gemini_cache = _gemini_client.caches.create(
|
||||||
model=_model,
|
model=_model,
|
||||||
config=types.CreateCachedContentConfig(
|
config=google.genai.types.CreateCachedContentConfig(
|
||||||
system_instruction=sys_instr,
|
system_instruction=sys_instr,
|
||||||
tools=tools_decl,
|
tools=tools_decl,
|
||||||
ttl=f"{_GEMINI_CACHE_TTL}s",
|
ttl=f"{_GEMINI_CACHE_TTL}s",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
_gemini_cache_created_at = time.time()
|
_gemini_cache_created_at = time.time()
|
||||||
chat_config = types.GenerateContentConfig(
|
chat_config = google.genai.types.GenerateContentConfig(
|
||||||
cached_content=_gemini_cache.name,
|
cached_content=_gemini_cache.name,
|
||||||
temperature=_temperature,
|
temperature=_temperature,
|
||||||
max_output_tokens=_max_tokens,
|
max_output_tokens=_max_tokens,
|
||||||
safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
|
safety_settings=[google.genai.types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
|
||||||
)
|
)
|
||||||
_append_comms("OUT", "request", {"message": f"[CACHE CREATED] {_gemini_cache.name}"})
|
_append_comms("OUT", "request", {"message": f"[CACHE CREATED] {_gemini_cache.name}"})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
45
tests/test_gemini_metrics.py
Normal file
45
tests/test_gemini_metrics.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
# Import the necessary functions from ai_client, including the reset helper
|
||||||
|
from ai_client import get_gemini_cache_stats, reset_session
|
||||||
|
|
||||||
|
def test_get_gemini_cache_stats_with_mock_client():
|
||||||
|
"""
|
||||||
|
Test that get_gemini_cache_stats correctly processes cache lists
|
||||||
|
from a mocked client instance.
|
||||||
|
"""
|
||||||
|
# Ensure a clean state before the test by resetting the session
|
||||||
|
reset_session()
|
||||||
|
|
||||||
|
# 1. Create a mock for the cache object that the client will return
|
||||||
|
mock_cache = MagicMock()
|
||||||
|
mock_cache.name = "cachedContents/test-cache"
|
||||||
|
mock_cache.display_name = "Test Cache"
|
||||||
|
mock_cache.model = "models/gemini-1.5-pro-001"
|
||||||
|
mock_cache.size_bytes = 1024
|
||||||
|
|
||||||
|
# 2. Create a mock for the client instance
|
||||||
|
mock_client_instance = MagicMock()
|
||||||
|
# Configure its `caches.list` method to return our mock cache
|
||||||
|
mock_client_instance.caches.list.return_value = [mock_cache]
|
||||||
|
|
||||||
|
# 3. Patch the Client constructor to return our mock instance
|
||||||
|
# This intercepts the `_ensure_gemini_client` call inside the function
|
||||||
|
with patch('google.genai.Client', return_value=mock_client_instance) as mock_client_constructor:
|
||||||
|
|
||||||
|
# 4. Call the function under test
|
||||||
|
stats = get_gemini_cache_stats()
|
||||||
|
|
||||||
|
# 5. Assert that the function behaved as expected
|
||||||
|
|
||||||
|
# It should have constructed the client
|
||||||
|
mock_client_constructor.assert_called_once()
|
||||||
|
# It should have called the `list` method on the `caches` attribute
|
||||||
|
mock_client_instance.caches.list.assert_called_once()
|
||||||
|
|
||||||
|
# The returned stats dictionary should be correct
|
||||||
|
assert "cache_count" in stats
|
||||||
|
assert "total_size_bytes" in stats
|
||||||
|
assert stats["cache_count"] == 1
|
||||||
|
assert stats["total_size_bytes"] == 1024
|
||||||
Reference in New Issue
Block a user