feat(ui): Enhanced context control with per-file flags and Gemini cache awareness

New file: tests/test_aggregate_flags.py (62 lines)
@@ -0,0 +1,62 @@
import pytest
|
||||
from pathlib import Path
|
||||
from src import aggregate
|
||||
|
||||
def test_auto_aggregate_skip(tmp_path):
    """Files flagged auto_aggregate=False must be excluded from every context tier."""
    # Two files on disk; only file1 opts into aggregation.
    for name, body in (("file1.txt", "content1"), ("file2.txt", "content2")):
        (tmp_path / name).write_text(body)

    specs = [
        {"path": "file1.txt", "auto_aggregate": True},
        {"path": "file2.txt", "auto_aggregate": False},
    ]
    items = aggregate.build_file_items(tmp_path, specs)

    # The raw files section honours the flag.
    section = aggregate._build_files_section_from_items(items)
    assert "file1.txt" in section
    assert "file2.txt" not in section

    # Tier 1 and Tier 3 contexts honour it too.
    tier_contexts = (
        aggregate.build_tier1_context(items, tmp_path, [], []),
        aggregate.build_tier3_context(items, tmp_path, [], [], []),
    )
    for ctx in tier_contexts:
        assert "file1.txt" in ctx
        assert "file2.txt" not in ctx
|
||||
|
||||
def test_force_full(tmp_path):
    """force_full=True bypasses Tier-3 skeletonization and Tier-1 summarization."""
    # A python file that Tier 3 would normally skeletonize (it is not a focus file).
    script = tmp_path / "script.py"
    script.write_text("def hello():\n print('world')\n")

    # With the flag set, the full body survives into the Tier 3 context.
    forced = aggregate.build_file_items(
        tmp_path, [{"path": "script.py", "force_full": True}]
    )
    t3_forced = aggregate.build_tier3_context(forced, tmp_path, [], [], [])
    assert "print('world')" in t3_forced

    # Without it, the body is skeletonized away.
    plain = aggregate.build_file_items(
        tmp_path, [{"path": "script.py", "force_full": False}]
    )
    t3_plain = aggregate.build_tier3_context(plain, tmp_path, [], [], [])
    assert "print('world')" not in t3_plain

    # Tier 1 normally summarizes non-core files; force_full keeps all lines.
    other = tmp_path / "other.txt"
    other.write_text(
        "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10"
    )

    forced_txt = aggregate.build_file_items(
        tmp_path, [{"path": "other.txt", "force_full": True}]
    )
    t1_forced = aggregate.build_tier1_context(forced_txt, tmp_path, [], [])
    assert "line10" in t1_forced

    plain_txt = aggregate.build_file_items(
        tmp_path, [{"path": "other.txt", "force_full": False}]
    )
    t1_plain = aggregate.build_tier1_context(plain_txt, tmp_path, [], [])
    # The generic summary for .txt shows only the first 8 lines, so line10 is dropped.
    assert "line10" not in t1_plain
|
||||
New file: tests/test_ai_cache_tracking.py (70 lines)
@@ -0,0 +1,70 @@
|
||||
import unittest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from src import ai_client
|
||||
import time
|
||||
|
||||
def test_gemini_cache_tracking() -> None:
    """send() against the mocked Gemini client records which files were cached,
    and reset_session() clears that record.

    Everything network-facing (credentials, google.genai.Client) is patched, so
    the test exercises only ai_client's bookkeeping.
    """
    # Start from a clean session on the Gemini provider.
    ai_client.reset_session()
    ai_client.set_provider("gemini", "gemini-2.5-flash-lite")

    file_items = [
        {"path": "src/app.py", "content": "print('hello')", "mtime": 123.0},
        {"path": "src/utils.py", "content": "def util(): pass", "mtime": 456.0}
    ]

    # Mock credentials so no real API key is required.
    with patch("src.ai_client._load_credentials") as mock_creds:
        mock_creds.return_value = {"gemini": {"api_key": "fake-key"}}

        # Mock the genai client itself.
        with patch("google.genai.Client") as MockClient:
            mock_client = MagicMock()
            MockClient.return_value = mock_client

            # count_tokens must report enough tokens (>= 2048) for caching to kick in.
            mock_client.models.count_tokens.return_value = MagicMock(total_tokens=3000)

            # Mock caches.create; .name is assigned after construction because
            # the `name` kwarg would be consumed by the MagicMock constructor.
            mock_cache = MagicMock()
            mock_cache.name = "cached_contents/abc"
            mock_client.caches.create.return_value = mock_cache

            # BUGFIX: the original used MagicMock(name="STOP"), but `name` is a
            # reserved constructor argument that names the mock itself — reading
            # finish_reason.name would yield a child mock, not "STOP". Assign it
            # post-construction, matching how mock_cache.name is set above.
            finish_reason = MagicMock()
            finish_reason.name = "STOP"

            # Mock chat creation and send_message.
            mock_chat = MagicMock()
            mock_client.chats.create.return_value = mock_chat
            mock_chat.send_message.return_value = MagicMock(
                text="Response",
                candidates=[MagicMock(finish_reason=finish_reason)],
                usage_metadata=MagicMock(prompt_token_count=100, candidates_token_count=50, total_token_count=150)
            )
            mock_chat._history = []

            # Mock caches.list for the stats call.
            mock_client.caches.list.return_value = [MagicMock(size_bytes=5000)]

            # Act
            ai_client.send(
                md_content="Some long context that triggers caching",
                user_message="Hello",
                file_items=file_items
            )

            # Assert: both files were recorded as cached.
            stats = ai_client.get_gemini_cache_stats()
            assert stats["cached_files"] == ["src/app.py", "src/utils.py"]

            # reset_session() must clear the record.
            ai_client.reset_session()
            stats = ai_client.get_gemini_cache_stats()
            assert stats["cached_files"] == []
|
||||
|
||||
def test_gemini_cache_tracking_cleanup() -> None:
    """cleanup() must drop any remembered cached-file paths from the module state."""
    # Seed the private tracking list directly, then verify cleanup empties it.
    ai_client._gemini_cached_file_paths = ["old.py"]
    ai_client.cleanup()
    assert ai_client._gemini_cached_file_paths == []
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this module directly, without pytest.
    for _case in (test_gemini_cache_tracking, test_gemini_cache_tracking_cleanup):
        _case()
    print("All tests passed!")
|
||||
New file: tests/test_file_item_model.py (39 lines)
@@ -0,0 +1,39 @@
|
||||
import pytest
|
||||
from src.models import FileItem
|
||||
|
||||
def test_file_item_fields():
    """FileItem stores its path and ships with the expected flag defaults."""
    item = FileItem(path="src/models.py")
    assert item.path == "src/models.py"
    # Defaults: included in aggregation, not forced to full content.
    assert item.auto_aggregate is True
    assert item.force_full is False
|
||||
|
||||
def test_file_item_to_dict():
    """to_dict() serializes every field under its own key."""
    item = FileItem(path="test.py", auto_aggregate=False, force_full=True)
    assert item.to_dict() == {
        "path": "test.py",
        "auto_aggregate": False,
        "force_full": True,
    }
|
||||
|
||||
def test_file_item_from_dict():
    """from_dict() restores every explicitly-set field."""
    payload = {"path": "test.py", "auto_aggregate": False, "force_full": True}
    item = FileItem.from_dict(payload)
    assert item.path == "test.py"
    # Identity checks: the flags must be real booleans, not merely falsy/truthy.
    assert item.auto_aggregate is False
    assert item.force_full is True
|
||||
|
||||
def test_file_item_from_dict_defaults():
    """Flags missing from the input dict fall back to the class defaults."""
    item = FileItem.from_dict({"path": "test.py"})
    assert item.path == "test.py"
    assert item.auto_aggregate is True
    assert item.force_full is False
|
||||
New file: tests/test_project_serialization.py (90 lines)
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
import unittest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from src import project_manager
|
||||
from src import models
|
||||
from src.app_controller import AppController
|
||||
|
||||
class TestProjectSerialization(unittest.TestCase):
    """Round-trip and backward-compatibility checks for project persistence."""

    def setUp(self):
        # Each test writes its project file into a throwaway directory.
        self.test_dir = tempfile.TemporaryDirectory()
        self.project_path = Path(self.test_dir.name) / "test_project.toml"

    def tearDown(self):
        self.test_dir.cleanup()

    def test_fileitem_roundtrip(self):
        """FileItem objects survive a save/load cycle with all flags intact."""
        proj = project_manager.default_project("test")
        proj["files"]["paths"] = [
            models.FileItem(path="src/main.py", auto_aggregate=True, force_full=False),
            models.FileItem(path="docs/readme.md", auto_aggregate=False, force_full=True),
        ]

        project_manager.save_project(proj, self.project_path)
        loaded_proj = project_manager.load_project(self.project_path)

        paths = loaded_proj["files"]["paths"]
        self.assertEqual(len(paths), 2)

        first, second = paths
        self.assertIsInstance(first, models.FileItem)
        self.assertEqual(first.path, "src/main.py")
        self.assertTrue(first.auto_aggregate)
        self.assertFalse(first.force_full)

        self.assertIsInstance(second, models.FileItem)
        self.assertEqual(second.path, "docs/readme.md")
        self.assertFalse(second.auto_aggregate)
        self.assertTrue(second.force_full)

    def test_backward_compatibility_strings(self):
        """Old-style string paths load as plain strings and are upgraded to
        FileItem objects by the AppController deserialization pass."""
        # Hand-write a legacy project file whose paths are bare strings.
        content = """
[project]
name = "legacy"

[files]
base_dir = "."
paths = ["file1.py", "file2.md"]

[discussion]
roles = ["User", "AI"]
"""
        with open(self.project_path, "w") as f:
            f.write(content)

        # project_manager itself performs no conversion.
        proj = project_manager.load_project(self.project_path)
        self.assertEqual(proj["files"]["paths"], ["file1.py", "file2.md"])

        controller = AppController()
        controller.project = proj

        # Mirror of the deserialization logic in AppController.init_state.
        raw_paths = controller.project.get("files", {}).get("paths", [])
        controller.files = []
        for entry in raw_paths:
            if isinstance(entry, models.FileItem):
                controller.files.append(entry)
            elif isinstance(entry, dict):
                controller.files.append(models.FileItem.from_dict(entry))
            else:
                controller.files.append(models.FileItem(path=str(entry)))

        self.assertEqual(len(controller.files), 2)
        for upgraded, expected_path in zip(controller.files, ("file1.py", "file2.md")):
            self.assertIsInstance(upgraded, models.FileItem)
            self.assertEqual(upgraded.path, expected_path)

    def test_default_roles_include_context(self):
        """'Context' must be among the default discussion roles."""
        proj = project_manager.default_project("test")
        self.assertIn("Context", proj["discussion"]["roles"])
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly, outside pytest.
    unittest.main()
|
||||
Reference in New Issue
Block a user