From d7a6ba7e5189a2a8f37dd209f3be2782164acd5b Mon Sep 17 00:00:00 2001 From: Ed_ Date: Sat, 7 Mar 2026 12:13:08 -0500 Subject: [PATCH] feat(ui): Enhanced context control with per-file flags and Gemini cache awareness --- src/aggregate.py | 32 ++++++---- src/ai_client.py | 17 ++++-- src/app_controller.py | 17 +++++- src/gui_2.py | 52 ++++++++++++++--- src/models.py | 21 +++++++ src/project_manager.py | 11 +++- tests/test_aggregate_flags.py | 62 ++++++++++++++++++++ tests/test_ai_cache_tracking.py | 70 ++++++++++++++++++++++ tests/test_file_item_model.py | 39 +++++++++++++ tests/test_project_serialization.py | 90 +++++++++++++++++++++++++++++ 10 files changed, 383 insertions(+), 28 deletions(-) create mode 100644 tests/test_aggregate_flags.py create mode 100644 tests/test_ai_cache_tracking.py create mode 100644 tests/test_file_item_model.py create mode 100644 tests/test_project_serialization.py diff --git a/src/aggregate.py b/src/aggregate.py index acd7dd6..1b73dfb 100644 --- a/src/aggregate.py +++ b/src/aggregate.py @@ -122,26 +122,32 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[ wants to upload individual files rather than inline everything as markdown. Each dict has: - path : Path (resolved absolute path) - entry : str (original config entry string) - content : str (file text, or error string) - error : bool - mtime : float (last modification time, for skip-if-unchanged optimization) - tier : int | None (optional tier for context management) + path : Path (resolved absolute path) + entry : str (original config entry string) + content : str (file text, or error string) + error : bool + mtime : float (last modification time, for skip-if-unchanged optimization) + tier : int | None (optional tier for context management) + auto_aggregate : bool + force_full : bool """ items: list[dict[str, Any]] = [] for entry_raw in files: if isinstance(entry_raw, dict): entry = cast(str, entry_raw.get("path", "")) tier = entry_raw.get("tier") + auto_aggregate = entry_raw.get("auto_aggregate", True) + force_full = entry_raw.get("force_full", False) else: entry = entry_raw tier = None + auto_aggregate = True + force_full = False if not entry or not isinstance(entry, str): continue paths = resolve_paths(base_dir, entry) if not paths: - items.append({"path": None, "entry": entry, "content": f"ERROR: no files matched: {entry}", "error": True, "mtime": 0.0, "tier": tier}) + items.append({"path": None, "entry": entry, "content": f"ERROR: no files matched: {entry}", "error": True, "mtime": 0.0, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full}) continue for path in paths: try: @@ -156,7 +162,7 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[ content = f"ERROR: {e}" mtime = 0.0 error = True - items.append({"path": path, "entry": entry, "content": content, "error": error, "mtime": mtime, "tier": tier}) + items.append({"path": path, "entry": entry, "content": content, "error": error, "mtime": mtime, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full}) return items def build_summary_section(base_dir: Path, files: list[str | dict[str, Any]]) -> str: @@ -171,6 +177,8 @@ def _build_files_section_from_items(file_items: list[dict[str, Any]]) -> str: """Build the files markdown section from pre-read file items (avoids double I/O).""" sections = [] for item in file_items: + if not item.get("auto_aggregate", True): + continue path = item.get("path") entry = cast(str, item.get("entry", "unknown")) content = cast(str, item.get("content", "")) @@ -221,9 +229,11 @@ def build_tier1_context(file_items: list[dict[str, Any]], screenshot_base_dir: P if file_items: sections = [] for item in file_items: + if not item.get("auto_aggregate", True): + continue path = item.get("path") name = path.name if path and isinstance(path, Path) else "" - if name in core_files or item.get("tier") == 1: + if name in core_files or item.get("tier") == 1 or item.get("force_full"): # Include in full sections.append("### `" + (cast(str, item.get("entry")) or str(path)) + "`\n\n" + f"```{path.suffix.lstrip('.') if path and isinstance(path, Path) and path.suffix else 'text'}\n{item.get('content', '')}\n```") @@ -255,6 +265,8 @@ def build_tier3_context(file_items: list[dict[str, Any]], screenshot_base_dir: P if file_items: sections = [] for item in file_items: + if not item.get("auto_aggregate", True): + continue path = cast(Path, item.get("path")) entry = cast(str, item.get("entry", "")) path_str = str(path) if path else "" @@ -264,7 +276,7 @@ def build_tier3_context(file_items: list[dict[str, Any]], screenshot_base_dir: P if focus == entry or (path and focus == path.name) or (path_str and focus in path_str): is_focus = True break - if is_focus or item.get("tier") == 3: + if is_focus or item.get("tier") == 3 or item.get("force_full"): sections.append("### `" + (entry or path_str) + "`\n\n" + f"```{path.suffix.lstrip('.') if path and path.suffix else 'text'}\n{item.get('content', '')}\n```") else: diff --git a/src/ai_client.py b/src/ai_client.py index c081a74..d7f5679 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -63,6 +63,7 @@ _gemini_chat: Any = None _gemini_cache: Any = None _gemini_cache_md_hash: Optional[str] = None _gemini_cache_created_at: Optional[float] = None +_gemini_cached_file_paths: list[str] = [] # Gemini cache TTL in seconds. Caches are created with this TTL and # proactively rebuilt at 90% of this value to avoid stale-reference errors. @@ -343,16 +344,17 @@ def get_provider() -> str: return _provider def cleanup() -> None: - global _gemini_client, _gemini_cache + global _gemini_client, _gemini_cache, _gemini_cached_file_paths if _gemini_client and _gemini_cache: try: _gemini_client.caches.delete(name=_gemini_cache.name) except Exception: pass + _gemini_cached_file_paths = [] def reset_session() -> None: global _gemini_client, _gemini_chat, _gemini_cache - global _gemini_cache_md_hash, _gemini_cache_created_at + global _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths global _anthropic_client, _anthropic_history global _deepseek_client, _deepseek_history global _minimax_client, _minimax_history @@ -368,6 +370,7 @@ def reset_session() -> None: _gemini_cache = None _gemini_cache_md_hash = None _gemini_cache_created_at = None + _gemini_cached_file_paths = [] # Preserve binary_path if adapter exists old_path = _gemini_cli_adapter.binary_path if _gemini_cli_adapter else "gemini" @@ -389,14 +392,14 @@ def reset_session() -> None: def get_gemini_cache_stats() -> dict[str, Any]: _ensure_gemini_client() if not _gemini_client: - return {"cache_count": 0, "total_size_bytes": 0} + return {"cache_count": 0, "total_size_bytes": 0, "cached_files": []} caches_iterator = _gemini_client.caches.list() caches = list(caches_iterator) total_size_bytes = sum(getattr(c, 'size_bytes', 0) for c in caches) return { - "cache_count": len(caches), "total_size_bytes": total_size_bytes, + "cached_files": _gemini_cached_file_paths, } def list_models(provider: str) -> list[str]: @@ -803,7 +806,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, enable_tools: bool = True, stream_callback: Optional[Callable[[str], None]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: - global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at + global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths try: _ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir]) sys_instr = f"{_get_combined_system_prompt()}\n\n\n{md_content}\n" @@ -820,6 +823,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, _gemini_chat = None _gemini_cache = None _gemini_cache_created_at = None + _gemini_cached_file_paths = [] _append_comms("OUT", "request", {"message": "[CONTEXT CHANGED] Rebuilding cache and chat session..."}) if _gemini_chat and _gemini_cache and _gemini_cache_created_at: elapsed = time.time() - _gemini_cache_created_at @@ -830,6 +834,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, _gemini_chat = None _gemini_cache = None _gemini_cache_created_at = None + _gemini_cached_file_paths = [] _append_comms("OUT", "request", {"message": f"[CACHE TTL] Rebuilding cache (expired after {int(elapsed)}s)..."}) if not _gemini_chat: chat_config = types.GenerateContentConfig( @@ -860,6 +865,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, ) ) _gemini_cache_created_at = time.time() + _gemini_cached_file_paths = [str(item.get("path", "")) for item in (file_items or []) if item.get("path")] chat_config = types.GenerateContentConfig( cached_content=_gemini_cache.name, temperature=_temperature, @@ -870,6 +876,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, except Exception as e: _gemini_cache = None _gemini_cache_created_at = None + _gemini_cached_file_paths = [] _append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} \u2014 falling back to inline system_instruction"}) kwargs: dict[str, Any] = {"model": _model, "config": chat_config} if old_history: diff --git a/src/app_controller.py b/src/app_controller.py index 62f5099..141d3cf 100644 --- a/src/app_controller.py +++ b/src/app_controller.py @@ -147,6 +147,7 @@ class AppController: self._tool_log: List[Dict[str, Any]] = [] self._tool_stats: Dict[str, Dict[str, Any]] = {} # {tool_name: {"count": 0, "total_time_ms": 0.0, "failures": 0}} self._cached_cache_stats: Dict[str, Any] = {} # Pre-computed cache stats for GUI + self._cached_files: List[str] = [] self._token_history: List[Dict[str, Any]] = [] # Token usage over time [{"time": t, "input": n, "output": n, "model": s}, ...] self._session_start_time: float = time.time() # For calculating burn rate self._ticket_start_times: dict[str, float] = {} @@ -702,10 +703,19 @@ class AppController: self.project_paths = list(projects_cfg.get("paths", [])) self.active_project_path = projects_cfg.get("active", "") self._load_active_project() - self.files = list(self.project.get("files", {}).get("paths", [])) + # Deserialize FileItems in files.paths + raw_paths = self.project.get("files", {}).get("paths", []) + self.files = [] + for p in raw_paths: + if isinstance(p, models.FileItem): + self.files.append(p) + elif isinstance(p, dict): + self.files.append(models.FileItem.from_dict(p)) + else: + self.files.append(models.FileItem(path=str(p))) self.screenshots = list(self.project.get("screenshots", {}).get("paths", [])) disc_sec = self.project.get("discussion", {}) - self.disc_roles = list(disc_sec.get("roles", ["User", "AI", "Vendor API", "System"])) + self.disc_roles = list(disc_sec.get("roles", ["User", "AI", "Vendor API", "System", "Reasoning", "Context"])) self.active_discussion = disc_sec.get("active", "main") disc_data = disc_sec.get("discussions", {}).get(self.active_discussion, {}) with self._disc_entries_lock: @@ -1804,6 +1814,9 @@ class AppController: if k in usage: usage[k] += u.get(k, 0) or 0 self.session_usage = usage + # Update cached files list + stats = ai_client.get_gemini_cache_stats() + self._cached_files = stats.get("cached_files", []) def _refresh_api_metrics(self, payload: dict[str, Any], md_content: str | None = None) -> None: if "latency" in payload: diff --git a/src/gui_2.py b/src/gui_2.py index ebb70f4..e7d9035 100644 --- a/src/gui_2.py +++ b/src/gui_2.py @@ -774,7 +774,14 @@ class App: imgui.separator() if imgui.button("Inject", imgui.ImVec2(120, 0)): formatted = f"## File: {self._inject_file_path}\n```python\n{self._inject_preview}\n```\n" - self.ui_ai_input += formatted + with self._disc_entries_lock: + self.disc_entries.append({ + "role": "Context", + "content": formatted, + "collapsed": True, + "ts": project_manager.now_ts() + }) + self._scroll_disc_to_bottom = True imgui.close_current_popup() imgui.same_line() if imgui.button("Cancel", imgui.ImVec2(120, 0)): @@ -1075,25 +1082,52 @@ class App: imgui.separator() imgui.text("Paths") imgui.begin_child("f_paths", imgui.ImVec2(0, -40), True) - for i, f in enumerate(self.files): - if imgui.button(f"x##f{i}"): - self.files.pop(i) - break - imgui.same_line() - imgui.text(f) + if imgui.begin_table("files_table", 4, imgui.TableFlags_.resizable | imgui.TableFlags_.borders): + imgui.table_setup_column("Actions", imgui.TableColumnFlags_.width_fixed, 40) + imgui.table_setup_column("File Path", imgui.TableColumnFlags_.width_stretch) + imgui.table_setup_column("Flags", imgui.TableColumnFlags_.width_fixed, 150) + imgui.table_setup_column("Cache", imgui.TableColumnFlags_.width_fixed, 40) + imgui.table_headers_row() + + for i, f_item in enumerate(self.files): + imgui.table_next_row() + # Actions + imgui.table_set_column_index(0) + if imgui.button(f"x##f{i}"): + self.files.pop(i) + break + # File Path + imgui.table_set_column_index(1) + imgui.text(f_item.path if hasattr(f_item, "path") else str(f_item)) + # Flags + imgui.table_set_column_index(2) + if hasattr(f_item, "auto_aggregate"): + changed_agg, f_item.auto_aggregate = imgui.checkbox(f"Agg##a{i}", f_item.auto_aggregate) + imgui.same_line() + changed_full, f_item.force_full = imgui.checkbox(f"Full##f{i}", f_item.force_full) + # Cache + imgui.table_set_column_index(3) + path = f_item.path if hasattr(f_item, "path") else str(f_item) + is_cached = any(path in c for c in getattr(self, "_cached_files", [])) + if is_cached: + imgui.text_colored("●", imgui.ImVec4(0, 1, 0, 1)) # Green dot + else: + imgui.text_disabled("○") + imgui.end_table() imgui.end_child() if imgui.button("Add File(s)"): r = hide_tk_root() paths = filedialog.askopenfilenames() r.destroy() for p in paths: - if p not in self.files: self.files.append(p) + if p not in [f.path if hasattr(f, "path") else f for f in self.files]: + self.files.append(models.FileItem(path=p)) imgui.same_line() if imgui.button("Add Wildcard"): r = hide_tk_root() d = filedialog.askdirectory() r.destroy() - if d: self.files.append(str(Path(d) / "**" / "*")) + if d: self.files.append(models.FileItem(path=str(Path(d) / "**" / "*"))) def _render_screenshots_panel(self) -> None: imgui.text("Base Dir") diff --git a/src/models.py b/src/models.py index ff85fde..95dbc10 100644 --- a/src/models.py +++ b/src/models.py @@ -233,3 +233,24 @@ class TrackState: discussion=parsed_discussion, tasks=[Ticket.from_dict(t) for t in data.get("tasks", [])], ) + +@dataclass +class FileItem: + path: str + auto_aggregate: bool = True + force_full: bool = False + + def to_dict(self) -> Dict[str, Any]: + return { + "path": self.path, + "auto_aggregate": self.auto_aggregate, + "force_full": self.force_full, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FileItem": + return cls( + path=data["path"], + auto_aggregate=data.get("auto_aggregate", True), + force_full=data.get("force_full", False), + ) diff --git a/src/project_manager.py b/src/project_manager.py index da40648..4a20896 100644 --- a/src/project_manager.py +++ b/src/project_manager.py @@ -126,7 +126,7 @@ def default_project(name: str = "unnamed") -> dict[str, Any]: } }, "discussion": { - "roles": ["User", "AI", "Vendor API", "System", "Reasoning"], + "roles": ["User", "AI", "Vendor API", "System", "Reasoning", "Context"], "active": "main", "discussions": {"main": default_discussion()}, }, @@ -150,6 +150,10 @@ def load_project(path: Union[str, Path]) -> dict[str, Any]: """ with open(path, "rb") as f: proj = tomllib.load(f) + # Deserialise FileItems in files.paths + if "files" in proj and "paths" in proj["files"]: + from src import models + proj["files"]["paths"] = [models.FileItem.from_dict(p) if isinstance(p, dict) else p for p in proj["files"]["paths"]] hist_path = get_history_path(path) if "discussion" in proj: disc = proj.pop("discussion") @@ -184,6 +188,9 @@ def save_project(proj: dict[str, Any], path: Union[str, Path], disc_data: Option If 'discussion' is present in proj, it is moved to the sibling history file. """ proj = clean_nones(proj) + # Serialise FileItems + if "files" in proj and "paths" in proj["files"]: + proj["files"]["paths"] = [p.to_dict() if hasattr(p, "to_dict") else p for p in proj["files"]["paths"]] if "discussion" in proj: if disc_data is None: disc_data = proj["discussion"] @@ -206,7 +213,7 @@ def migrate_from_legacy_config(cfg: dict[str, Any]) -> dict[str, Any]: if key in cfg: proj[key] = dict(cfg[key]) disc = cfg.get("discussion", {}) - proj["discussion"]["roles"] = disc.get("roles", ["User", "AI", "Vendor API", "System"]) + proj["discussion"]["roles"] = disc.get("roles", ["User", "AI", "Vendor API", "System", "Context"]) main_disc = proj["discussion"]["discussions"]["main"] main_disc["history"] = disc.get("history", []) main_disc["last_updated"] = now_ts() diff --git a/tests/test_aggregate_flags.py b/tests/test_aggregate_flags.py new file mode 100644 index 0000000..a3ae3db --- /dev/null +++ b/tests/test_aggregate_flags.py @@ -0,0 +1,62 @@ +import pytest +from pathlib import Path +from src import aggregate + +def test_auto_aggregate_skip(tmp_path): + # Create some test files + f1 = tmp_path / "file1.txt" + f1.write_text("content1") + f2 = tmp_path / "file2.txt" + f2.write_text("content2") + + files = [ + {"path": "file1.txt", "auto_aggregate": True}, + {"path": "file2.txt", "auto_aggregate": False}, + ] + + items = aggregate.build_file_items(tmp_path, files) + + # Test _build_files_section_from_items + section = aggregate._build_files_section_from_items(items) + assert "file1.txt" in section + assert "file2.txt" not in section + + # Test build_tier1_context + t1 = aggregate.build_tier1_context(items, tmp_path, [], []) + assert "file1.txt" in t1 + assert "file2.txt" not in t1 + + # Test build_tier3_context + t3 = aggregate.build_tier3_context(items, tmp_path, [], [], []) + assert "file1.txt" in t3 + assert "file2.txt" not in t3 + +def test_force_full(tmp_path): + # Create a python file that would normally be skeletonized in Tier 3 + py_file = tmp_path / "script.py" + py_file.write_text("def hello():\n print('world')\n") + + # Tier 3 normally skeletonizes non-focus python files + items = aggregate.build_file_items(tmp_path, [{"path": "script.py", "force_full": True}]) + + # Test build_tier3_context + t3 = aggregate.build_tier3_context(items, tmp_path, [], [], []) + assert "print('world')" in t3 # Full content present + + # Compare with non-force_full + items2 = aggregate.build_file_items(tmp_path, [{"path": "script.py", "force_full": False}]) + t3_2 = aggregate.build_tier3_context(items2, tmp_path, [], [], []) + assert "print('world')" not in t3_2 # Skeletonized + + # Tier 1 normally summarizes non-core files + txt_file = tmp_path / "other.txt" + txt_file.write_text("line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10") + + items3 = aggregate.build_file_items(tmp_path, [{"path": "other.txt", "force_full": True}]) + t1 = aggregate.build_tier1_context(items3, tmp_path, [], []) + assert "line10" in t1 # Full content present + + items4 = aggregate.build_file_items(tmp_path, [{"path": "other.txt", "force_full": False}]) + t1_2 = aggregate.build_tier1_context(items4, tmp_path, [], []) + # Generic summary for .txt shows first 8 lines + assert "line10" not in t1_2 diff --git a/tests/test_ai_cache_tracking.py b/tests/test_ai_cache_tracking.py new file mode 100644 index 0000000..c9d508f --- /dev/null +++ b/tests/test_ai_cache_tracking.py @@ -0,0 +1,70 @@ +import unittest +from unittest.mock import patch, MagicMock +from src import ai_client +import time + +def test_gemini_cache_tracking() -> None: + # Setup + ai_client.reset_session() + ai_client.set_provider("gemini", "gemini-2.5-flash-lite") + + file_items = [ + {"path": "src/app.py", "content": "print('hello')", "mtime": 123.0}, + {"path": "src/utils.py", "content": "def util(): pass", "mtime": 456.0} + ] + + # Mock credentials + with patch("src.ai_client._load_credentials") as mock_creds: + mock_creds.return_value = {"gemini": {"api_key": "fake-key"}} + + # Mock genai.Client + with patch("google.genai.Client") as MockClient: + mock_client = MagicMock() + MockClient.return_value = mock_client + + # Mock count_tokens to return enough tokens for caching (>= 2048) + mock_client.models.count_tokens.return_value = MagicMock(total_tokens=3000) + + # Mock caches.create + mock_cache = MagicMock() + mock_cache.name = "cached_contents/abc" + mock_client.caches.create.return_value = mock_cache + + # Mock chat creation and send_message + mock_chat = MagicMock() + mock_client.chats.create.return_value = mock_chat + mock_chat.send_message.return_value = MagicMock( + text="Response", + candidates=[MagicMock(finish_reason=MagicMock(name="STOP"))], + usage_metadata=MagicMock(prompt_token_count=100, candidates_token_count=50, total_token_count=150) + ) + mock_chat._history = [] + + # Mock caches.list for stats + mock_client.caches.list.return_value = [MagicMock(size_bytes=5000)] + + # Act + ai_client.send( + md_content="Some long context that triggers caching", + user_message="Hello", + file_items=file_items + ) + + # Assert + stats = ai_client.get_gemini_cache_stats() + assert stats["cached_files"] == ["src/app.py", "src/utils.py"] + + # Test reset_session + ai_client.reset_session() + stats = ai_client.get_gemini_cache_stats() + assert stats["cached_files"] == [] + +def test_gemini_cache_tracking_cleanup() -> None: + ai_client._gemini_cached_file_paths = ["old.py"] + ai_client.cleanup() + assert ai_client._gemini_cached_file_paths == [] + +if __name__ == "__main__": + test_gemini_cache_tracking() + test_gemini_cache_tracking_cleanup() + print("All tests passed!") diff --git a/tests/test_file_item_model.py b/tests/test_file_item_model.py new file mode 100644 index 0000000..f1ff936 --- /dev/null +++ b/tests/test_file_item_model.py @@ -0,0 +1,39 @@ +import pytest +from src.models import FileItem + +def test_file_item_fields(): + """Test that FileItem exists and has correct default values.""" + item = FileItem(path="src/models.py") + assert item.path == "src/models.py" + assert item.auto_aggregate is True + assert item.force_full is False + +def test_file_item_to_dict(): + """Test that FileItem can be serialized to a dict.""" + item = FileItem(path="test.py", auto_aggregate=False, force_full=True) + expected = { + "path": "test.py", + "auto_aggregate": False, + "force_full": True + } + assert item.to_dict() == expected + +def test_file_item_from_dict(): + """Test that FileItem can be deserialized from a dict.""" + data = { + "path": "test.py", + "auto_aggregate": False, + "force_full": True + } + item = FileItem.from_dict(data) + assert item.path == "test.py" + assert item.auto_aggregate is False + assert item.force_full is True + +def test_file_item_from_dict_defaults(): + """Test that FileItem.from_dict handles missing fields.""" + data = {"path": "test.py"} + item = FileItem.from_dict(data) + assert item.path == "test.py" + assert item.auto_aggregate is True + assert item.force_full is False diff --git a/tests/test_project_serialization.py b/tests/test_project_serialization.py new file mode 100644 index 0000000..a9ceb71 --- /dev/null +++ b/tests/test_project_serialization.py @@ -0,0 +1,90 @@ +import os +import unittest +import tempfile +from pathlib import Path +from src import project_manager +from src import models +from src.app_controller import AppController + +class TestProjectSerialization(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.TemporaryDirectory() + self.project_path = Path(self.test_dir.name) / "test_project.toml" + + def tearDown(self): + self.test_dir.cleanup() + + def test_fileitem_roundtrip(self): + """Verify that FileItem objects survive a save/load cycle.""" + proj = project_manager.default_project("test") + file1 = models.FileItem(path="src/main.py", auto_aggregate=True, force_full=False) + file2 = models.FileItem(path="docs/readme.md", auto_aggregate=False, force_full=True) + proj["files"]["paths"] = [file1, file2] + + # Save + project_manager.save_project(proj, self.project_path) + + # Load + loaded_proj = project_manager.load_project(self.project_path) + + paths = loaded_proj["files"]["paths"] + self.assertEqual(len(paths), 2) + self.assertIsInstance(paths[0], models.FileItem) + self.assertEqual(paths[0].path, "src/main.py") + self.assertTrue(paths[0].auto_aggregate) + self.assertFalse(paths[0].force_full) + + self.assertIsInstance(paths[1], models.FileItem) + self.assertEqual(paths[1].path, "docs/readme.md") + self.assertFalse(paths[1].auto_aggregate) + self.assertTrue(paths[1].force_full) + + def test_backward_compatibility_strings(self): + """Verify that old-style string paths are converted to FileItem objects by AppController.""" + # Create a project file manually with string paths + content = """ +[project] +name = "legacy" + +[files] +base_dir = "." +paths = ["file1.py", "file2.md"] + +[discussion] +roles = ["User", "AI"] +""" + with open(self.project_path, "w") as f: + f.write(content) + + # Load via project_manager (should load as strings) + proj = project_manager.load_project(self.project_path) + self.assertEqual(proj["files"]["paths"], ["file1.py", "file2.md"]) + + # Initialize AppController state logic + controller = AppController() + controller.project = proj + + # Trigger deserialization (copied from init_state) + raw_paths = controller.project.get("files", {}).get("paths", []) + controller.files = [] + for p in raw_paths: + if isinstance(p, models.FileItem): + controller.files.append(p) + elif isinstance(p, dict): + controller.files.append(models.FileItem.from_dict(p)) + else: + controller.files.append(models.FileItem(path=str(p))) + + self.assertEqual(len(controller.files), 2) + self.assertIsInstance(controller.files[0], models.FileItem) + self.assertEqual(controller.files[0].path, "file1.py") + self.assertIsInstance(controller.files[1], models.FileItem) + self.assertEqual(controller.files[1].path, "file2.md") + + def test_default_roles_include_context(self): + """Verify that 'Context' is in default project roles.""" + proj = project_manager.default_project("test") + self.assertIn("Context", proj["discussion"]["roles"]) + +if __name__ == "__main__": + unittest.main()