test(core): Replace pytest.fail with functional assertions in token_usage and agent_capabilities
This commit is contained in:
@@ -1,10 +1,43 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
# Ensure project root is in path
|
# Ensure project root is in path
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
|
import ai_client
|
||||||
|
|
||||||
def test_agent_capabilities_listing() -> None:
|
def test_agent_capabilities_listing() -> None:
|
||||||
pytest.fail("TODO: Implement assertions")
|
# Mock credentials
|
||||||
|
with patch("ai_client._load_credentials") as mock_creds:
|
||||||
|
mock_creds.return_value = {"gemini": {"api_key": "fake-key"}}
|
||||||
|
|
||||||
|
# Mock the google-genai Client and models.list
|
||||||
|
with patch("google.genai.Client") as mock_client_class:
|
||||||
|
mock_client = mock_client_class.return_value
|
||||||
|
|
||||||
|
# Create a list of mock models
|
||||||
|
mock_models = []
|
||||||
|
model_names = [
|
||||||
|
"models/gemini-2.0-flash",
|
||||||
|
"models/gemini-2.0-flash-lite",
|
||||||
|
"models/gemini-1.5-pro",
|
||||||
|
"models/gemini-1.5-flash",
|
||||||
|
"models/gemini-pro",
|
||||||
|
"models/gemini-ultra"
|
||||||
|
]
|
||||||
|
for name in model_names:
|
||||||
|
m = MagicMock()
|
||||||
|
m.name = name
|
||||||
|
mock_models.append(m)
|
||||||
|
|
||||||
|
mock_client.models.list.return_value = mock_models
|
||||||
|
|
||||||
|
models = ai_client.list_models("gemini")
|
||||||
|
|
||||||
|
assert len(models) >= 6
|
||||||
|
assert "gemini-2.0-flash" in models
|
||||||
|
# Check that it stripped "models/" prefix
|
||||||
|
for m in models:
|
||||||
|
assert not m.startswith("models/")
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
import pytest
|
import pytest
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
import hashlib
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
# Ensure project root is in path
|
# Ensure project root is in path
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
@@ -9,6 +12,54 @@ import ai_client
|
|||||||
|
|
||||||
def test_token_usage_tracking() -> None:
|
def test_token_usage_tracking() -> None:
|
||||||
ai_client.reset_session()
|
ai_client.reset_session()
|
||||||
# Mock an API response with token usage
|
ai_client.clear_comms_log()
|
||||||
# This would test the internal accumulator in ai_client
|
ai_client.set_provider("gemini", "gemini-2.0-flash")
|
||||||
pytest.fail("TODO: Implement assertions")
|
|
||||||
|
# Mock credentials so we don't need a real file
|
||||||
|
with patch("ai_client._load_credentials") as mock_creds:
|
||||||
|
mock_creds.return_value = {"gemini": {"api_key": "fake-key"}}
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
# Use SimpleNamespace to ensure attributes are real values, not Mocks
|
||||||
|
mock_resp.usage_metadata = SimpleNamespace(
|
||||||
|
prompt_token_count=100,
|
||||||
|
candidates_token_count=50,
|
||||||
|
cached_content_token_count=20
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup candidates
|
||||||
|
mock_candidate = MagicMock()
|
||||||
|
# Use spec to ensure hasattr(p, "function_call") is False for text parts
|
||||||
|
mock_part = MagicMock(spec=["text"])
|
||||||
|
mock_part.text = "Hello"
|
||||||
|
mock_candidate.content.parts = [mock_part]
|
||||||
|
mock_candidate.finish_reason.name = "STOP"
|
||||||
|
mock_resp.candidates = [mock_candidate]
|
||||||
|
|
||||||
|
mock_chat = MagicMock()
|
||||||
|
mock_chat.send_message.return_value = mock_resp
|
||||||
|
mock_chat._history = []
|
||||||
|
|
||||||
|
# Mock the client creation and storage
|
||||||
|
with patch("google.genai.Client") as mock_client_class:
|
||||||
|
mock_client_instance = mock_client_class.return_value
|
||||||
|
# Mock count_tokens to avoid call during send_gemini
|
||||||
|
mock_client_instance.models.count_tokens.return_value = MagicMock(total_tokens=100)
|
||||||
|
# Mock chats.create to return our mock_chat
|
||||||
|
mock_client_instance.chats.create.return_value = mock_chat
|
||||||
|
|
||||||
|
ai_client._gemini_client = mock_client_instance
|
||||||
|
ai_client._gemini_chat = mock_chat
|
||||||
|
# Set the hash to prevent chat reset
|
||||||
|
ai_client._gemini_cache_md_hash = hashlib.md5("context".encode()).hexdigest()
|
||||||
|
|
||||||
|
ai_client.send("context", "hi", enable_tools=False)
|
||||||
|
|
||||||
|
log = ai_client.get_comms_log()
|
||||||
|
# The log might have 'request' and 'response' entries
|
||||||
|
response_entries = [e for e in log if e["kind"] == "response"]
|
||||||
|
assert len(response_entries) > 0
|
||||||
|
usage = response_entries[0]["payload"]["usage"]
|
||||||
|
assert usage["input_tokens"] == 100
|
||||||
|
assert usage["output_tokens"] == 50
|
||||||
|
assert usage["cache_read_input_tokens"] == 20
|
||||||
|
|||||||
Reference in New Issue
Block a user