# 71 lines, 2.2 KiB, Python
import unittest
|
|
from unittest.mock import patch, MagicMock
|
|
from src import ai_client
|
|
import time
|
|
|
|
def test_gemini_cache_tracking() -> None:
    """Files sent as context on a cache-eligible Gemini call are tracked.

    Drives ``ai_client.send`` with a mocked ``google.genai.Client`` and fake
    credentials, then checks that ``get_gemini_cache_stats()["cached_files"]``
    lists the paths of the supplied ``file_items`` in order, and that
    ``reset_session()`` empties that list again.
    """
    # Setup: start from a clean session and select the Gemini provider.
    ai_client.reset_session()
    ai_client.set_provider("gemini", "gemini-2.5-flash-lite")

    file_items = [
        {"path": "src/app.py", "content": "print('hello')", "mtime": 123.0},
        {"path": "src/utils.py", "content": "def util(): pass", "mtime": 456.0},
    ]

    # Provide fake credentials instead of the real loader's result.
    with patch("src.ai_client._load_credentials") as mock_creds:
        mock_creds.return_value = {"gemini": {"api_key": "fake-key"}}

        # Replace the SDK client class with a mock.
        with patch("google.genai.Client") as MockClient:
            mock_client = MagicMock()
            MockClient.return_value = mock_client

            # Mock count_tokens to return enough tokens for caching (>= 2048).
            mock_client.models.count_tokens.return_value = MagicMock(total_tokens=3000)

            # Mock caches.create
            mock_cache = MagicMock()
            mock_cache.name = "cached_contents/abc"
            mock_client.caches.create.return_value = mock_cache

            # `name` is a reserved Mock constructor argument (it only sets the
            # mock's repr), so MagicMock(name="STOP").name would be a child
            # mock, not the string "STOP". Assign the attribute afterwards.
            finish_reason = MagicMock()
            finish_reason.name = "STOP"

            # Mock chat creation and send_message
            mock_chat = MagicMock()
            mock_client.chats.create.return_value = mock_chat
            mock_chat.send_message.return_value = MagicMock(
                text="Response",
                candidates=[MagicMock(finish_reason=finish_reason)],
                usage_metadata=MagicMock(
                    prompt_token_count=100,
                    candidates_token_count=50,
                    total_token_count=150,
                ),
            )
            mock_chat._history = []

            # Mock caches.list for stats
            mock_client.caches.list.return_value = [MagicMock(size_bytes=5000)]

            # Act
            ai_client.send(
                md_content="Some long context that triggers caching",
                user_message="Hello",
                file_items=file_items,
            )

            # Assert: both file paths are recorded, in input order.
            stats = ai_client.get_gemini_cache_stats()
            assert stats["cached_files"] == ["src/app.py", "src/utils.py"]

            # reset_session() must forget the tracked paths.
            ai_client.reset_session()
            stats = ai_client.get_gemini_cache_stats()
            assert stats["cached_files"] == []
|
|
|
|
def test_gemini_cache_tracking_cleanup() -> None:
    """``ai_client.cleanup()`` must leave the cached-file path list empty."""
    # Seed the module-level tracking list with a stale entry.
    stale_paths = ["old.py"]
    ai_client._gemini_cached_file_paths = stale_paths

    ai_client.cleanup()

    remaining = ai_client._gemini_cached_file_paths
    assert remaining == []
|
|
|
|
if __name__ == "__main__":
    # Allow running the module directly, without pytest.
    for _test_fn in (test_gemini_cache_tracking, test_gemini_cache_tracking_cleanup):
        _test_fn()
    print("All tests passed!")
|