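# Tests for discussion-history aggregation, file blacklisting, legacy project
# loading, history-file separation on save, and AI-client history-bleed stats.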
import sys
|
|
import os
|
|
import tomli_w
|
|
import tomllib
|
|
from pathlib import Path
|
|
|
|
# Put the project root (the parent of the directory containing this file) on
# sys.path so that the ``from src import ...`` lines below can be resolved.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from src import aggregate
|
|
from src import project_manager
|
|
from src import ai_client
|
|
|
|
def test_aggregate_includes_segregated_history() -> None:
    """Check that ``aggregate.build_discussion_text`` renders every history entry.

    One ``User`` entry and one ``AI`` entry are passed in; the rendered text
    must contain each of them in ``"<role>: <content>"`` form.
    """
    entries = [
        {"role": "User", "content": "Hello", "ts": "2024-01-01T00:00:00"},
        {"role": "AI", "content": "Hi there", "ts": "2024-01-01T00:00:01"},
    ]

    rendered = aggregate.build_discussion_text(entries)

    for expected in ("User: Hello", "AI: Hi there"):
        assert expected in rendered
|
|
|
|
def test_mcp_blacklist() -> None:
    """Check that ``mcp_client._is_allowed`` admits source files but rejects the config file.

    An ordinary source file (``src/gui_2.py``) must be readable by the AI,
    while the project config file at ``CONFIG_PATH`` must not be.
    """
    from src import mcp_client
    from src.models import CONFIG_PATH

    source_file = Path("src/gui_2.py")
    # CONFIG_PATH is presumably something like 'config.toml'.
    config_file = Path(CONFIG_PATH)

    assert mcp_client._is_allowed(source_file) is True
    assert mcp_client._is_allowed(config_file) is False
|
|
|
|
def test_aggregate_blacklist() -> None:
    """Pin how ``aggregate.build_markdown_no_history`` treats a blacklisted file.

    ``test_mcp_blacklist`` shows that ``mcp_client._is_allowed`` rejects the
    project config file. The assertions below, however, expect a
    ``config.toml`` item's content to appear in the generated markdown. This
    test therefore verifies that the aggregate step renders every explicitly
    supplied file item and does NOT apply the MCP read blacklist.

    NOTE(review): the test name suggests exclusion may have been the original
    intent. If aggregate is meant to drop blacklisted files, the second
    assertion should become ``not in``. Confirm against ``aggregate``.
    """
    file_items = [
        {"path": "src/gui_2.py", "name": "gui_2.py", "content": "print('hello')"},
        {"path": "config.toml", "name": "config.toml", "content": "secret = 123"},
    ]

    md = aggregate.build_markdown_no_history(file_items, Path("."), [])

    # An ordinary source file's content must be rendered.
    assert "print('hello')" in md, "source file content missing from markdown"
    # The config file's content is currently expected to be rendered too.
    assert "secret = 123" in md, "config.toml content missing from markdown"
|
|
|
|
def test_migration_on_load(tmp_path: Path) -> None:
    """Check that ``project_manager.load_project`` accepts a legacy-format project file.

    The legacy layout stores history as a flat top-level ``discussion_history``
    string. After loading, the result must expose history under either the
    newer ``discussion`` key or the legacy ``discussion_history`` key.

    NOTE(review): the unmigrated legacy dict already contains
    ``discussion_history``, so this assertion also passes when
    ``load_project`` performs no migration at all. Tighten it once the
    migrated layout is settled.
    """
    legacy_path = tmp_path / "legacy.toml"
    # dumps(...).encode() produces the same UTF-8 bytes that tomli_w.dump writes.
    legacy_path.write_bytes(
        tomli_w.dumps(
            {
                "project": {"name": "Legacy"},
                "files": ["file1.py"],
                "discussion_history": "User: Hello\nAI: Hi",
            }
        ).encode()
    )

    migrated = project_manager.load_project(str(legacy_path))

    # Depending on the implementation, history may end up under "discussion"
    # or remain as a top-level "discussion_history".
    assert any(key in migrated for key in ("discussion", "discussion_history"))
|
|
|
|
def test_save_separation(tmp_path: Path) -> None:
    """Check that ``save_project`` moves discussion data into a separate history file.

    The main project TOML must not contain a ``discussion`` table. The file
    returned by ``get_history_path`` must exist and hold the ``discussions``
    mapping, including the single entry appended to the active discussion.
    """

    def read_toml(path: Path) -> dict:
        # Local helper: parse a TOML file opened in binary mode, as tomllib requires.
        with path.open("rb") as fh:
            return tomllib.load(fh)

    project_path = tmp_path / "project.toml"
    project = project_manager.default_project("Test")

    discussion = project["discussion"]
    active_name = discussion["active"]
    discussion["discussions"][active_name]["history"].append(
        {"role": "User", "content": "Test", "ts": "2024-01-01T00:00:00"}
    )

    project_manager.save_project(project, str(project_path))

    # The main project file must not carry the discussion table.
    assert "discussion" not in read_toml(project_path)

    # The companion history file must hold the entire discussion dict.
    history_path = project_manager.get_history_path(project_path)
    assert history_path.exists()
    history_doc = read_toml(history_path)
    assert "discussions" in history_doc
    assert active_name in history_doc["discussions"]
    assert len(history_doc["discussions"][active_name]["history"]) == 1
|
|
|
|
def test_history_persistence_across_turns(tmp_path: Path) -> None:
    """Tests that discussion history is persisted exactly across save/load cycles.

    Each turn appends one entry, saves, and reloads. The reloaded history must
    contain exactly the entries written so far. Exact counts (rather than
    ``>=``) catch a save/load cycle that duplicates entries, consistent with
    the ``== 1`` check in ``test_save_separation``.
    """
    project_path = tmp_path / "project.toml"
    project_data = project_manager.default_project("Test")

    def active_history(project: dict) -> list:
        # Return the history list of the project's currently active discussion.
        discussion = project["discussion"]
        return discussion["discussions"][discussion["active"]]["history"]

    # Turn 1: append a user entry and persist it.
    active_history(project_data).append(
        {"role": "User", "content": "Turn 1", "ts": "2024-01-01T00:00:00"}
    )
    project_manager.save_project(project_data, str(project_path))

    # Reload: exactly the one entry written so far must come back.
    loaded = project_manager.load_project(str(project_path))
    h = active_history(loaded)
    assert len(h) == 1, f"expected 1 history entry after first reload, got {len(h)}"
    # Matching on str(entry) keeps the check independent of the exact entry
    # representation that load_project returns.
    assert any("Turn 1" in str(entry) for entry in h)

    # Turn 2: append an AI reply to the reloaded history and persist again.
    h.append({"role": "AI", "content": "Response 1", "ts": "2024-01-01T00:00:01"})
    project_manager.save_project(loaded, str(project_path))

    # Reload again: both turns must be present, with no duplicates.
    reloaded = project_manager.load_project(str(project_path))
    h2 = active_history(reloaded)
    assert len(h2) == 2, f"expected 2 history entries after second reload, got {len(h2)}"
    assert any("Turn 1" in str(entry) for entry in h2)
    assert any("Response 1" in str(entry) for entry in h2)
|
|
|
|
def test_get_history_bleed_stats_basic() -> None:
    """Check the shape of the dict returned by ``ai_client.get_history_bleed_stats()``.

    With the Gemini provider selected, the stats must name the provider and
    contain ``current`` and ``limit`` keys. After ``set_model_params`` is
    called with ``500`` as its third argument, ``limit`` must equal 500 and
    ``current`` must be a non-negative int.

    NOTE(review): this test changes the provider and model parameters on
    ``ai_client`` and does not restore them. That state is presumably visible
    to later tests; confirm whether a fixture should reset it.
    """
    ai_client.set_provider("gemini", "gemini-2.5-flash-lite")

    # Before any message is sent, "current" may be 0 or reflect an empty context.
    initial = ai_client.get_history_bleed_stats()
    assert "provider" in initial
    assert initial["provider"] == "gemini"
    assert "current" in initial
    assert "limit" in initial, "Stats dictionary should contain 'limit'"

    # The third positional argument is presumably the history limit (500 here).
    ai_client.set_model_params(0.0, 8192, 500)
    updated = ai_client.get_history_bleed_stats()
    assert "current" in updated, "Stats dictionary should contain 'current' token usage"
    assert "limit" in updated, "Stats dictionary should contain 'limit'"
    assert updated["limit"] == 500
    current_tokens = updated["current"]
    assert isinstance(current_tokens, int) and current_tokens >= 0
|