feat(ui): Enhanced context control with per-file flags and Gemini cache awareness

This commit is contained in:
2026-03-07 12:13:08 -05:00
parent 61f331aee6
commit d7a6ba7e51
10 changed files with 383 additions and 28 deletions

View File

@@ -122,26 +122,32 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[
wants to upload individual files rather than inline everything as markdown. wants to upload individual files rather than inline everything as markdown.
Each dict has: Each dict has:
path : Path (resolved absolute path) path : Path (resolved absolute path)
entry : str (original config entry string) entry : str (original config entry string)
content : str (file text, or error string) content : str (file text, or error string)
error : bool error : bool
mtime : float (last modification time, for skip-if-unchanged optimization) mtime : float (last modification time, for skip-if-unchanged optimization)
tier : int | None (optional tier for context management) tier : int | None (optional tier for context management)
auto_aggregate : bool
force_full : bool
""" """
items: list[dict[str, Any]] = [] items: list[dict[str, Any]] = []
for entry_raw in files: for entry_raw in files:
if isinstance(entry_raw, dict): if isinstance(entry_raw, dict):
entry = cast(str, entry_raw.get("path", "")) entry = cast(str, entry_raw.get("path", ""))
tier = entry_raw.get("tier") tier = entry_raw.get("tier")
auto_aggregate = entry_raw.get("auto_aggregate", True)
force_full = entry_raw.get("force_full", False)
else: else:
entry = entry_raw entry = entry_raw
tier = None tier = None
auto_aggregate = True
force_full = False
if not entry or not isinstance(entry, str): if not entry or not isinstance(entry, str):
continue continue
paths = resolve_paths(base_dir, entry) paths = resolve_paths(base_dir, entry)
if not paths: if not paths:
items.append({"path": None, "entry": entry, "content": f"ERROR: no files matched: {entry}", "error": True, "mtime": 0.0, "tier": tier}) items.append({"path": None, "entry": entry, "content": f"ERROR: no files matched: {entry}", "error": True, "mtime": 0.0, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full})
continue continue
for path in paths: for path in paths:
try: try:
@@ -156,7 +162,7 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[
content = f"ERROR: {e}" content = f"ERROR: {e}"
mtime = 0.0 mtime = 0.0
error = True error = True
items.append({"path": path, "entry": entry, "content": content, "error": error, "mtime": mtime, "tier": tier}) items.append({"path": path, "entry": entry, "content": content, "error": error, "mtime": mtime, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full})
return items return items
def build_summary_section(base_dir: Path, files: list[str | dict[str, Any]]) -> str: def build_summary_section(base_dir: Path, files: list[str | dict[str, Any]]) -> str:
@@ -171,6 +177,8 @@ def _build_files_section_from_items(file_items: list[dict[str, Any]]) -> str:
"""Build the files markdown section from pre-read file items (avoids double I/O).""" """Build the files markdown section from pre-read file items (avoids double I/O)."""
sections = [] sections = []
for item in file_items: for item in file_items:
if not item.get("auto_aggregate", True):
continue
path = item.get("path") path = item.get("path")
entry = cast(str, item.get("entry", "unknown")) entry = cast(str, item.get("entry", "unknown"))
content = cast(str, item.get("content", "")) content = cast(str, item.get("content", ""))
@@ -221,9 +229,11 @@ def build_tier1_context(file_items: list[dict[str, Any]], screenshot_base_dir: P
if file_items: if file_items:
sections = [] sections = []
for item in file_items: for item in file_items:
if not item.get("auto_aggregate", True):
continue
path = item.get("path") path = item.get("path")
name = path.name if path and isinstance(path, Path) else "" name = path.name if path and isinstance(path, Path) else ""
if name in core_files or item.get("tier") == 1: if name in core_files or item.get("tier") == 1 or item.get("force_full"):
# Include in full # Include in full
sections.append("### `" + (cast(str, item.get("entry")) or str(path)) + "`\n\n" + sections.append("### `" + (cast(str, item.get("entry")) or str(path)) + "`\n\n" +
f"```{path.suffix.lstrip('.') if path and isinstance(path, Path) and path.suffix else 'text'}\n{item.get('content', '')}\n```") f"```{path.suffix.lstrip('.') if path and isinstance(path, Path) and path.suffix else 'text'}\n{item.get('content', '')}\n```")
@@ -255,6 +265,8 @@ def build_tier3_context(file_items: list[dict[str, Any]], screenshot_base_dir: P
if file_items: if file_items:
sections = [] sections = []
for item in file_items: for item in file_items:
if not item.get("auto_aggregate", True):
continue
path = cast(Path, item.get("path")) path = cast(Path, item.get("path"))
entry = cast(str, item.get("entry", "")) entry = cast(str, item.get("entry", ""))
path_str = str(path) if path else "" path_str = str(path) if path else ""
@@ -264,7 +276,7 @@ def build_tier3_context(file_items: list[dict[str, Any]], screenshot_base_dir: P
if focus == entry or (path and focus == path.name) or (path_str and focus in path_str): if focus == entry or (path and focus == path.name) or (path_str and focus in path_str):
is_focus = True is_focus = True
break break
if is_focus or item.get("tier") == 3: if is_focus or item.get("tier") == 3 or item.get("force_full"):
sections.append("### `" + (entry or path_str) + "`\n\n" + sections.append("### `" + (entry or path_str) + "`\n\n" +
f"```{path.suffix.lstrip('.') if path and path.suffix else 'text'}\n{item.get('content', '')}\n```") f"```{path.suffix.lstrip('.') if path and path.suffix else 'text'}\n{item.get('content', '')}\n```")
else: else:

View File

@@ -63,6 +63,7 @@ _gemini_chat: Any = None
_gemini_cache: Any = None _gemini_cache: Any = None
_gemini_cache_md_hash: Optional[str] = None _gemini_cache_md_hash: Optional[str] = None
_gemini_cache_created_at: Optional[float] = None _gemini_cache_created_at: Optional[float] = None
_gemini_cached_file_paths: list[str] = []
# Gemini cache TTL in seconds. Caches are created with this TTL and # Gemini cache TTL in seconds. Caches are created with this TTL and
# proactively rebuilt at 90% of this value to avoid stale-reference errors. # proactively rebuilt at 90% of this value to avoid stale-reference errors.
@@ -343,16 +344,17 @@ def get_provider() -> str:
return _provider return _provider
def cleanup() -> None: def cleanup() -> None:
global _gemini_client, _gemini_cache global _gemini_client, _gemini_cache, _gemini_cached_file_paths
if _gemini_client and _gemini_cache: if _gemini_client and _gemini_cache:
try: try:
_gemini_client.caches.delete(name=_gemini_cache.name) _gemini_client.caches.delete(name=_gemini_cache.name)
except Exception: except Exception:
pass pass
_gemini_cached_file_paths = []
def reset_session() -> None: def reset_session() -> None:
global _gemini_client, _gemini_chat, _gemini_cache global _gemini_client, _gemini_chat, _gemini_cache
global _gemini_cache_md_hash, _gemini_cache_created_at global _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths
global _anthropic_client, _anthropic_history global _anthropic_client, _anthropic_history
global _deepseek_client, _deepseek_history global _deepseek_client, _deepseek_history
global _minimax_client, _minimax_history global _minimax_client, _minimax_history
@@ -368,6 +370,7 @@ def reset_session() -> None:
_gemini_cache = None _gemini_cache = None
_gemini_cache_md_hash = None _gemini_cache_md_hash = None
_gemini_cache_created_at = None _gemini_cache_created_at = None
_gemini_cached_file_paths = []
# Preserve binary_path if adapter exists # Preserve binary_path if adapter exists
old_path = _gemini_cli_adapter.binary_path if _gemini_cli_adapter else "gemini" old_path = _gemini_cli_adapter.binary_path if _gemini_cli_adapter else "gemini"
@@ -389,14 +392,14 @@ def reset_session() -> None:
def get_gemini_cache_stats() -> dict[str, Any]: def get_gemini_cache_stats() -> dict[str, Any]:
_ensure_gemini_client() _ensure_gemini_client()
if not _gemini_client: if not _gemini_client:
return {"cache_count": 0, "total_size_bytes": 0} return {"cache_count": 0, "total_size_bytes": 0, "cached_files": []}
caches_iterator = _gemini_client.caches.list() caches_iterator = _gemini_client.caches.list()
caches = list(caches_iterator) caches = list(caches_iterator)
total_size_bytes = sum(getattr(c, 'size_bytes', 0) for c in caches) total_size_bytes = sum(getattr(c, 'size_bytes', 0) for c in caches)
return { return {
"cache_count": len(caches), "cache_count": len(caches),
"total_size_bytes": total_size_bytes, "total_size_bytes": total_size_bytes,
"cached_files": _gemini_cached_file_paths,
} }
def list_models(provider: str) -> list[str]: def list_models(provider: str) -> list[str]:
@@ -803,7 +806,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
enable_tools: bool = True, enable_tools: bool = True,
stream_callback: Optional[Callable[[str], None]] = None, stream_callback: Optional[Callable[[str], None]] = None,
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths
try: try:
_ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir]) _ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir])
sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>" sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
@@ -820,6 +823,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
_gemini_chat = None _gemini_chat = None
_gemini_cache = None _gemini_cache = None
_gemini_cache_created_at = None _gemini_cache_created_at = None
_gemini_cached_file_paths = []
_append_comms("OUT", "request", {"message": "[CONTEXT CHANGED] Rebuilding cache and chat session..."}) _append_comms("OUT", "request", {"message": "[CONTEXT CHANGED] Rebuilding cache and chat session..."})
if _gemini_chat and _gemini_cache and _gemini_cache_created_at: if _gemini_chat and _gemini_cache and _gemini_cache_created_at:
elapsed = time.time() - _gemini_cache_created_at elapsed = time.time() - _gemini_cache_created_at
@@ -830,6 +834,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
_gemini_chat = None _gemini_chat = None
_gemini_cache = None _gemini_cache = None
_gemini_cache_created_at = None _gemini_cache_created_at = None
_gemini_cached_file_paths = []
_append_comms("OUT", "request", {"message": f"[CACHE TTL] Rebuilding cache (expired after {int(elapsed)}s)..."}) _append_comms("OUT", "request", {"message": f"[CACHE TTL] Rebuilding cache (expired after {int(elapsed)}s)..."})
if not _gemini_chat: if not _gemini_chat:
chat_config = types.GenerateContentConfig( chat_config = types.GenerateContentConfig(
@@ -860,6 +865,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
) )
) )
_gemini_cache_created_at = time.time() _gemini_cache_created_at = time.time()
_gemini_cached_file_paths = [str(item.get("path", "")) for item in (file_items or []) if item.get("path")]
chat_config = types.GenerateContentConfig( chat_config = types.GenerateContentConfig(
cached_content=_gemini_cache.name, cached_content=_gemini_cache.name,
temperature=_temperature, temperature=_temperature,
@@ -870,6 +876,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
except Exception as e: except Exception as e:
_gemini_cache = None _gemini_cache = None
_gemini_cache_created_at = None _gemini_cache_created_at = None
_gemini_cached_file_paths = []
_append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} \u2014 falling back to inline system_instruction"}) _append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} \u2014 falling back to inline system_instruction"})
kwargs: dict[str, Any] = {"model": _model, "config": chat_config} kwargs: dict[str, Any] = {"model": _model, "config": chat_config}
if old_history: if old_history:

View File

@@ -147,6 +147,7 @@ class AppController:
self._tool_log: List[Dict[str, Any]] = [] self._tool_log: List[Dict[str, Any]] = []
self._tool_stats: Dict[str, Dict[str, Any]] = {} # {tool_name: {"count": 0, "total_time_ms": 0.0, "failures": 0}} self._tool_stats: Dict[str, Dict[str, Any]] = {} # {tool_name: {"count": 0, "total_time_ms": 0.0, "failures": 0}}
self._cached_cache_stats: Dict[str, Any] = {} # Pre-computed cache stats for GUI self._cached_cache_stats: Dict[str, Any] = {} # Pre-computed cache stats for GUI
self._cached_files: List[str] = []
self._token_history: List[Dict[str, Any]] = [] # Token usage over time [{"time": t, "input": n, "output": n, "model": s}, ...] self._token_history: List[Dict[str, Any]] = [] # Token usage over time [{"time": t, "input": n, "output": n, "model": s}, ...]
self._session_start_time: float = time.time() # For calculating burn rate self._session_start_time: float = time.time() # For calculating burn rate
self._ticket_start_times: dict[str, float] = {} self._ticket_start_times: dict[str, float] = {}
@@ -702,10 +703,19 @@ class AppController:
self.project_paths = list(projects_cfg.get("paths", [])) self.project_paths = list(projects_cfg.get("paths", []))
self.active_project_path = projects_cfg.get("active", "") self.active_project_path = projects_cfg.get("active", "")
self._load_active_project() self._load_active_project()
self.files = list(self.project.get("files", {}).get("paths", [])) # Deserialize FileItems in files.paths
raw_paths = self.project.get("files", {}).get("paths", [])
self.files = []
for p in raw_paths:
if isinstance(p, models.FileItem):
self.files.append(p)
elif isinstance(p, dict):
self.files.append(models.FileItem.from_dict(p))
else:
self.files.append(models.FileItem(path=str(p)))
self.screenshots = list(self.project.get("screenshots", {}).get("paths", [])) self.screenshots = list(self.project.get("screenshots", {}).get("paths", []))
disc_sec = self.project.get("discussion", {}) disc_sec = self.project.get("discussion", {})
self.disc_roles = list(disc_sec.get("roles", ["User", "AI", "Vendor API", "System"])) self.disc_roles = list(disc_sec.get("roles", ["User", "AI", "Vendor API", "System", "Reasoning", "Context"]))
self.active_discussion = disc_sec.get("active", "main") self.active_discussion = disc_sec.get("active", "main")
disc_data = disc_sec.get("discussions", {}).get(self.active_discussion, {}) disc_data = disc_sec.get("discussions", {}).get(self.active_discussion, {})
with self._disc_entries_lock: with self._disc_entries_lock:
@@ -1804,6 +1814,9 @@ class AppController:
if k in usage: if k in usage:
usage[k] += u.get(k, 0) or 0 usage[k] += u.get(k, 0) or 0
self.session_usage = usage self.session_usage = usage
# Update cached files list
stats = ai_client.get_gemini_cache_stats()
self._cached_files = stats.get("cached_files", [])
def _refresh_api_metrics(self, payload: dict[str, Any], md_content: str | None = None) -> None: def _refresh_api_metrics(self, payload: dict[str, Any], md_content: str | None = None) -> None:
if "latency" in payload: if "latency" in payload:

View File

@@ -774,7 +774,14 @@ class App:
imgui.separator() imgui.separator()
if imgui.button("Inject", imgui.ImVec2(120, 0)): if imgui.button("Inject", imgui.ImVec2(120, 0)):
formatted = f"## File: {self._inject_file_path}\n```python\n{self._inject_preview}\n```\n" formatted = f"## File: {self._inject_file_path}\n```python\n{self._inject_preview}\n```\n"
self.ui_ai_input += formatted with self._disc_entries_lock:
self.disc_entries.append({
"role": "Context",
"content": formatted,
"collapsed": True,
"ts": project_manager.now_ts()
})
self._scroll_disc_to_bottom = True
imgui.close_current_popup() imgui.close_current_popup()
imgui.same_line() imgui.same_line()
if imgui.button("Cancel", imgui.ImVec2(120, 0)): if imgui.button("Cancel", imgui.ImVec2(120, 0)):
@@ -1075,25 +1082,52 @@ class App:
imgui.separator() imgui.separator()
imgui.text("Paths") imgui.text("Paths")
imgui.begin_child("f_paths", imgui.ImVec2(0, -40), True) imgui.begin_child("f_paths", imgui.ImVec2(0, -40), True)
for i, f in enumerate(self.files): if imgui.begin_table("files_table", 4, imgui.TableFlags_.resizable | imgui.TableFlags_.borders):
if imgui.button(f"x##f{i}"): imgui.table_setup_column("Actions", imgui.TableColumnFlags_.width_fixed, 40)
self.files.pop(i) imgui.table_setup_column("File Path", imgui.TableColumnFlags_.width_stretch)
break imgui.table_setup_column("Flags", imgui.TableColumnFlags_.width_fixed, 150)
imgui.same_line() imgui.table_setup_column("Cache", imgui.TableColumnFlags_.width_fixed, 40)
imgui.text(f) imgui.table_headers_row()
for i, f_item in enumerate(self.files):
imgui.table_next_row()
# Actions
imgui.table_set_column_index(0)
if imgui.button(f"x##f{i}"):
self.files.pop(i)
break
# File Path
imgui.table_set_column_index(1)
imgui.text(f_item.path if hasattr(f_item, "path") else str(f_item))
# Flags
imgui.table_set_column_index(2)
if hasattr(f_item, "auto_aggregate"):
changed_agg, f_item.auto_aggregate = imgui.checkbox(f"Agg##a{i}", f_item.auto_aggregate)
imgui.same_line()
changed_full, f_item.force_full = imgui.checkbox(f"Full##f{i}", f_item.force_full)
# Cache
imgui.table_set_column_index(3)
path = f_item.path if hasattr(f_item, "path") else str(f_item)
is_cached = any(path in c for c in getattr(self, "_cached_files", []))
if is_cached:
imgui.text_colored("", imgui.ImVec4(0, 1, 0, 1)) # Green dot
else:
imgui.text_disabled("")
imgui.end_table()
imgui.end_child() imgui.end_child()
if imgui.button("Add File(s)"): if imgui.button("Add File(s)"):
r = hide_tk_root() r = hide_tk_root()
paths = filedialog.askopenfilenames() paths = filedialog.askopenfilenames()
r.destroy() r.destroy()
for p in paths: for p in paths:
if p not in self.files: self.files.append(p) if p not in [f.path if hasattr(f, "path") else f for f in self.files]:
self.files.append(models.FileItem(path=p))
imgui.same_line() imgui.same_line()
if imgui.button("Add Wildcard"): if imgui.button("Add Wildcard"):
r = hide_tk_root() r = hide_tk_root()
d = filedialog.askdirectory() d = filedialog.askdirectory()
r.destroy() r.destroy()
if d: self.files.append(str(Path(d) / "**" / "*")) if d: self.files.append(models.FileItem(path=str(Path(d) / "**" / "*")))
def _render_screenshots_panel(self) -> None: def _render_screenshots_panel(self) -> None:
imgui.text("Base Dir") imgui.text("Base Dir")

View File

@@ -233,3 +233,24 @@ class TrackState:
discussion=parsed_discussion, discussion=parsed_discussion,
tasks=[Ticket.from_dict(t) for t in data.get("tasks", [])], tasks=[Ticket.from_dict(t) for t in data.get("tasks", [])],
) )
@dataclass
class FileItem:
path: str
auto_aggregate: bool = True
force_full: bool = False
def to_dict(self) -> Dict[str, Any]:
return {
"path": self.path,
"auto_aggregate": self.auto_aggregate,
"force_full": self.force_full,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FileItem":
return cls(
path=data["path"],
auto_aggregate=data.get("auto_aggregate", True),
force_full=data.get("force_full", False),
)

View File

@@ -126,7 +126,7 @@ def default_project(name: str = "unnamed") -> dict[str, Any]:
} }
}, },
"discussion": { "discussion": {
"roles": ["User", "AI", "Vendor API", "System", "Reasoning"], "roles": ["User", "AI", "Vendor API", "System", "Reasoning", "Context"],
"active": "main", "active": "main",
"discussions": {"main": default_discussion()}, "discussions": {"main": default_discussion()},
}, },
@@ -150,6 +150,10 @@ def load_project(path: Union[str, Path]) -> dict[str, Any]:
""" """
with open(path, "rb") as f: with open(path, "rb") as f:
proj = tomllib.load(f) proj = tomllib.load(f)
# Deserialise FileItems in files.paths
if "files" in proj and "paths" in proj["files"]:
from src import models
proj["files"]["paths"] = [models.FileItem.from_dict(p) if isinstance(p, dict) else p for p in proj["files"]["paths"]]
hist_path = get_history_path(path) hist_path = get_history_path(path)
if "discussion" in proj: if "discussion" in proj:
disc = proj.pop("discussion") disc = proj.pop("discussion")
@@ -184,6 +188,9 @@ def save_project(proj: dict[str, Any], path: Union[str, Path], disc_data: Option
If 'discussion' is present in proj, it is moved to the sibling history file. If 'discussion' is present in proj, it is moved to the sibling history file.
""" """
proj = clean_nones(proj) proj = clean_nones(proj)
# Serialise FileItems
if "files" in proj and "paths" in proj["files"]:
proj["files"]["paths"] = [p.to_dict() if hasattr(p, "to_dict") else p for p in proj["files"]["paths"]]
if "discussion" in proj: if "discussion" in proj:
if disc_data is None: if disc_data is None:
disc_data = proj["discussion"] disc_data = proj["discussion"]
@@ -206,7 +213,7 @@ def migrate_from_legacy_config(cfg: dict[str, Any]) -> dict[str, Any]:
if key in cfg: if key in cfg:
proj[key] = dict(cfg[key]) proj[key] = dict(cfg[key])
disc = cfg.get("discussion", {}) disc = cfg.get("discussion", {})
proj["discussion"]["roles"] = disc.get("roles", ["User", "AI", "Vendor API", "System"]) proj["discussion"]["roles"] = disc.get("roles", ["User", "AI", "Vendor API", "System", "Context"])
main_disc = proj["discussion"]["discussions"]["main"] main_disc = proj["discussion"]["discussions"]["main"]
main_disc["history"] = disc.get("history", []) main_disc["history"] = disc.get("history", [])
main_disc["last_updated"] = now_ts() main_disc["last_updated"] = now_ts()

View File

@@ -0,0 +1,62 @@
import pytest
from pathlib import Path
from src import aggregate
def test_auto_aggregate_skip(tmp_path):
# Create some test files
f1 = tmp_path / "file1.txt"
f1.write_text("content1")
f2 = tmp_path / "file2.txt"
f2.write_text("content2")
files = [
{"path": "file1.txt", "auto_aggregate": True},
{"path": "file2.txt", "auto_aggregate": False},
]
items = aggregate.build_file_items(tmp_path, files)
# Test _build_files_section_from_items
section = aggregate._build_files_section_from_items(items)
assert "file1.txt" in section
assert "file2.txt" not in section
# Test build_tier1_context
t1 = aggregate.build_tier1_context(items, tmp_path, [], [])
assert "file1.txt" in t1
assert "file2.txt" not in t1
# Test build_tier3_context
t3 = aggregate.build_tier3_context(items, tmp_path, [], [], [])
assert "file1.txt" in t3
assert "file2.txt" not in t3
def test_force_full(tmp_path):
# Create a python file that would normally be skeletonized in Tier 3
py_file = tmp_path / "script.py"
py_file.write_text("def hello():\n print('world')\n")
# Tier 3 normally skeletonizes non-focus python files
items = aggregate.build_file_items(tmp_path, [{"path": "script.py", "force_full": True}])
# Test build_tier3_context
t3 = aggregate.build_tier3_context(items, tmp_path, [], [], [])
assert "print('world')" in t3 # Full content present
# Compare with non-force_full
items2 = aggregate.build_file_items(tmp_path, [{"path": "script.py", "force_full": False}])
t3_2 = aggregate.build_tier3_context(items2, tmp_path, [], [], [])
assert "print('world')" not in t3_2 # Skeletonized
# Tier 1 normally summarizes non-core files
txt_file = tmp_path / "other.txt"
txt_file.write_text("line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10")
items3 = aggregate.build_file_items(tmp_path, [{"path": "other.txt", "force_full": True}])
t1 = aggregate.build_tier1_context(items3, tmp_path, [], [])
assert "line10" in t1 # Full content present
items4 = aggregate.build_file_items(tmp_path, [{"path": "other.txt", "force_full": False}])
t1_2 = aggregate.build_tier1_context(items4, tmp_path, [], [])
# Generic summary for .txt shows first 8 lines
assert "line10" not in t1_2

View File

@@ -0,0 +1,70 @@
import unittest
from unittest.mock import patch, MagicMock
from src import ai_client
import time
def test_gemini_cache_tracking() -> None:
# Setup
ai_client.reset_session()
ai_client.set_provider("gemini", "gemini-2.5-flash-lite")
file_items = [
{"path": "src/app.py", "content": "print('hello')", "mtime": 123.0},
{"path": "src/utils.py", "content": "def util(): pass", "mtime": 456.0}
]
# Mock credentials
with patch("src.ai_client._load_credentials") as mock_creds:
mock_creds.return_value = {"gemini": {"api_key": "fake-key"}}
# Mock genai.Client
with patch("google.genai.Client") as MockClient:
mock_client = MagicMock()
MockClient.return_value = mock_client
# Mock count_tokens to return enough tokens for caching (>= 2048)
mock_client.models.count_tokens.return_value = MagicMock(total_tokens=3000)
# Mock caches.create
mock_cache = MagicMock()
mock_cache.name = "cached_contents/abc"
mock_client.caches.create.return_value = mock_cache
# Mock chat creation and send_message
mock_chat = MagicMock()
mock_client.chats.create.return_value = mock_chat
mock_chat.send_message.return_value = MagicMock(
text="Response",
candidates=[MagicMock(finish_reason=MagicMock(name="STOP"))],
usage_metadata=MagicMock(prompt_token_count=100, candidates_token_count=50, total_token_count=150)
)
mock_chat._history = []
# Mock caches.list for stats
mock_client.caches.list.return_value = [MagicMock(size_bytes=5000)]
# Act
ai_client.send(
md_content="Some long context that triggers caching",
user_message="Hello",
file_items=file_items
)
# Assert
stats = ai_client.get_gemini_cache_stats()
assert stats["cached_files"] == ["src/app.py", "src/utils.py"]
# Test reset_session
ai_client.reset_session()
stats = ai_client.get_gemini_cache_stats()
assert stats["cached_files"] == []
def test_gemini_cache_tracking_cleanup() -> None:
ai_client._gemini_cached_file_paths = ["old.py"]
ai_client.cleanup()
assert ai_client._gemini_cached_file_paths == []
if __name__ == "__main__":
test_gemini_cache_tracking()
test_gemini_cache_tracking_cleanup()
print("All tests passed!")

View File

@@ -0,0 +1,39 @@
import pytest
from src.models import FileItem
def test_file_item_fields():
"""Test that FileItem exists and has correct default values."""
item = FileItem(path="src/models.py")
assert item.path == "src/models.py"
assert item.auto_aggregate is True
assert item.force_full is False
def test_file_item_to_dict():
"""Test that FileItem can be serialized to a dict."""
item = FileItem(path="test.py", auto_aggregate=False, force_full=True)
expected = {
"path": "test.py",
"auto_aggregate": False,
"force_full": True
}
assert item.to_dict() == expected
def test_file_item_from_dict():
"""Test that FileItem can be deserialized from a dict."""
data = {
"path": "test.py",
"auto_aggregate": False,
"force_full": True
}
item = FileItem.from_dict(data)
assert item.path == "test.py"
assert item.auto_aggregate is False
assert item.force_full is True
def test_file_item_from_dict_defaults():
"""Test that FileItem.from_dict handles missing fields."""
data = {"path": "test.py"}
item = FileItem.from_dict(data)
assert item.path == "test.py"
assert item.auto_aggregate is True
assert item.force_full is False

View File

@@ -0,0 +1,90 @@
import os
import unittest
import tempfile
from pathlib import Path
from src import project_manager
from src import models
from src.app_controller import AppController
class TestProjectSerialization(unittest.TestCase):
def setUp(self):
self.test_dir = tempfile.TemporaryDirectory()
self.project_path = Path(self.test_dir.name) / "test_project.toml"
def tearDown(self):
self.test_dir.cleanup()
def test_fileitem_roundtrip(self):
"""Verify that FileItem objects survive a save/load cycle."""
proj = project_manager.default_project("test")
file1 = models.FileItem(path="src/main.py", auto_aggregate=True, force_full=False)
file2 = models.FileItem(path="docs/readme.md", auto_aggregate=False, force_full=True)
proj["files"]["paths"] = [file1, file2]
# Save
project_manager.save_project(proj, self.project_path)
# Load
loaded_proj = project_manager.load_project(self.project_path)
paths = loaded_proj["files"]["paths"]
self.assertEqual(len(paths), 2)
self.assertIsInstance(paths[0], models.FileItem)
self.assertEqual(paths[0].path, "src/main.py")
self.assertTrue(paths[0].auto_aggregate)
self.assertFalse(paths[0].force_full)
self.assertIsInstance(paths[1], models.FileItem)
self.assertEqual(paths[1].path, "docs/readme.md")
self.assertFalse(paths[1].auto_aggregate)
self.assertTrue(paths[1].force_full)
def test_backward_compatibility_strings(self):
"""Verify that old-style string paths are converted to FileItem objects by AppController."""
# Create a project file manually with string paths
content = """
[project]
name = "legacy"
[files]
base_dir = "."
paths = ["file1.py", "file2.md"]
[discussion]
roles = ["User", "AI"]
"""
with open(self.project_path, "w") as f:
f.write(content)
# Load via project_manager (should load as strings)
proj = project_manager.load_project(self.project_path)
self.assertEqual(proj["files"]["paths"], ["file1.py", "file2.md"])
# Initialize AppController state logic
controller = AppController()
controller.project = proj
# Trigger deserialization (copied from init_state)
raw_paths = controller.project.get("files", {}).get("paths", [])
controller.files = []
for p in raw_paths:
if isinstance(p, models.FileItem):
controller.files.append(p)
elif isinstance(p, dict):
controller.files.append(models.FileItem.from_dict(p))
else:
controller.files.append(models.FileItem(path=str(p)))
self.assertEqual(len(controller.files), 2)
self.assertIsInstance(controller.files[0], models.FileItem)
self.assertEqual(controller.files[0].path, "file1.py")
self.assertIsInstance(controller.files[1], models.FileItem)
self.assertEqual(controller.files[1].path, "file2.md")
def test_default_roles_include_context(self):
"""Verify that 'Context' is in default project roles."""
proj = project_manager.default_project("test")
self.assertIn("Context", proj["discussion"]["roles"])
if __name__ == "__main__":
unittest.main()