From b4f8633bd658f395af5134f5fcd0c38dd6ebf20e Mon Sep 17 00:00:00 2001 From: Ed_ Date: Sun, 10 May 2026 13:28:15 -0400 Subject: [PATCH] feat(context): Interactive AST Tree Masking with per-symbol toggles --- src/aggregate.py | 29 ++++++++++- src/gui_2.py | 92 +++++++++++++++++++++++++++++++++++ src/models.py | 3 ++ tests/test_file_item_model.py | 6 ++- 4 files changed, 127 insertions(+), 3 deletions(-) diff --git a/src/aggregate.py b/src/aggregate.py index 7a8f664..5f61645 100644 --- a/src/aggregate.py +++ b/src/aggregate.py @@ -123,6 +123,7 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[ force_full = entry_raw.get("force_full", False) ast_signatures = entry_raw.get("ast_signatures", False) ast_definitions = entry_raw.get("ast_definitions", False) + ast_mask = entry_raw.get("ast_mask", {}) elif hasattr(entry_raw, "path"): entry = entry_raw.path tier = getattr(entry_raw, "tier", None) @@ -130,6 +131,7 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[ force_full = getattr(entry_raw, "force_full", False) ast_signatures = getattr(entry_raw, "ast_signatures", False) ast_definitions = getattr(entry_raw, "ast_definitions", False) + ast_mask = getattr(entry_raw, "ast_mask", {}) else: entry = entry_raw tier = None @@ -137,11 +139,12 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[ force_full = False ast_signatures = False ast_definitions = False + ast_mask = {} if not entry or not isinstance(entry, str): continue paths = resolve_paths(base_dir, entry) if not paths: - items.append({"path": None, "entry": entry, "content": f"ERROR: no files matched: {entry}", "error": True, "mtime": 0.0, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full, "ast_signatures": ast_signatures, "ast_definitions": ast_definitions}) + items.append({"path": None, "entry": entry, "content": f"ERROR: no files matched: {entry}", "error": True, "mtime": 0.0, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full, "ast_signatures": ast_signatures, "ast_definitions": ast_definitions, "ast_mask": ast_mask}) continue for path in paths: try: @@ -156,7 +159,7 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[ content = f"ERROR: {e}" mtime = 0.0 error = True - items.append({"path": path, "entry": entry, "content": content, "error": error, "mtime": mtime, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full, "ast_signatures": ast_signatures, "ast_definitions": ast_definitions}) + items.append({"path": path, "entry": entry, "content": content, "error": error, "mtime": mtime, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full, "ast_signatures": ast_signatures, "ast_definitions": ast_definitions, "ast_mask": ast_mask}) return items @@ -272,6 +275,7 @@ def build_tier3_context(file_items: list[dict[str, Any]], screenshot_base_dir: P force_full = item.get("force_full") ast_signatures = item.get("ast_signatures", False) ast_definitions = item.get("ast_definitions", False) + ast_mask = item.get("ast_mask", {}) content = item.get("content", "") is_focus = entry in focus_set or (name and name in focus_set) or (path_str and path_str in focus_set) if not is_focus and path_str: @@ -284,6 +288,27 @@ def build_tier3_context(file_items: list[dict[str, Any]], screenshot_base_dir: P suffix = path.suffix.lstrip(".") if path and path.suffix else "text" sections.append(f"### `{original}`\n\n```{suffix}\n{content}\n```") elif path: + if ast_mask and not item.get("error"): + mask_sections = [] + from src import mcp_client + for symbol, mode in ast_mask.items(): + if mode == "hide": + continue + res = "" + if path.suffix == ".py": + res = mcp_client.py_get_definition(str(path), symbol) if mode == "def" else mcp_client.py_get_signature(str(path), symbol) + elif path.suffix in [".c", ".h", ".cpp", ".hpp", ".cxx", ".cc"]: + is_cpp = any(ext in path.suffix for ext in [".cpp", ".hpp", ".cxx", ".cc"]) + if mode == "def": + res = mcp_client.ts_cpp_get_definition(str(path), symbol) if is_cpp else mcp_client.ts_c_get_definition(str(path), symbol) + else: + res = mcp_client.ts_cpp_get_signature(str(path), symbol) if is_cpp else mcp_client.ts_c_get_signature(str(path), symbol) + if res: + mask_sections.append(res) + if mask_sections: + suffix = path.suffix.lstrip(".") if path.suffix else "text" + sections.append(f"### `{original}` (Masked)\n\n```{suffix}\n" + "\n\n".join(mask_sections) + "\n```") + continue if path.suffix in ['.c', '.h', '.cpp', '.hpp', '.cxx', '.cc'] and not item.get("error"): from src import mcp_client if ast_definitions: diff --git a/src/gui_2.py b/src/gui_2.py index a89d6db..33d8dea 100644 --- a/src/gui_2.py +++ b/src/gui_2.py @@ -239,6 +239,9 @@ class App: self.shader_uniforms = {'crt': 1.0, 'scanline': 0.5, 'bloom': 0.8} self.ui_new_context_preset_name = "" self._focus_md_cache: dict[str, str] = {} + self.ui_inspecting_ast_file = None + self._cached_ast_nodes = [] + self._cached_ast_file_path = '' """UI-level wrapper for approving a pending tool execution ask.""" self._handle_approve_ask() @@ -1337,6 +1340,9 @@ class App: if imgui.button("Cancel", imgui.ImVec2(120, 0)): imgui.close_current_popup() imgui.end_popup() + + self._render_ast_inspector_modal() + except Exception as e: print(f"ERROR in _gui_func: {e}") traceback.print_exc() @@ -1469,6 +1475,86 @@ class App: imgui.close_current_popup() imgui.end_popup() + def _render_ast_inspector_modal(self) -> None: + expanded, opened = imgui.begin_popup_modal('AST Inspector', True, imgui.WindowFlags_.always_auto_resize) + if expanded: + if self.ui_inspecting_ast_file is None: + imgui.close_current_popup() + else: + f_item = self.ui_inspecting_ast_file + f_path = f_item.path if hasattr(f_item, "path") else str(f_item) + + if f_path != self._cached_ast_file_path: + outline = "" + try: + if f_path.lower().endswith('.py'): + outline = mcp_client.py_get_code_outline(f_path) + elif f_path.lower().endswith(('.c', '.h')): + outline = mcp_client.ts_c_get_code_outline(f_path) + else: + outline = mcp_client.ts_cpp_get_code_outline(f_path) + except Exception as e: + outline = f"Error fetching outline: {e}" + + self._cached_ast_nodes = [] + import re + pattern = re.compile(r'^(\s*)\[(.*?)\] (.*?) \(Lines \d+-\d+\)') + stack = [] # (indent, name) + for line in outline.splitlines(): + m = pattern.match(line) + if m: + indent_str, kind, name = m.groups() + indent = len(indent_str) + while stack and stack[-1][0] >= indent: + stack.pop() + stack.append((indent, name)) + full_path = '::'.join([s[1] for s in stack]) + self._cached_ast_nodes.append({ + 'indent': indent, + 'kind': kind, + 'name': name, + 'full_path': full_path + }) + self._cached_ast_file_path = f_path + + imgui.text(f"Inspecting AST: {f_path}") + imgui.separator() + + imgui.begin_child("ast_tree_scroll", imgui.ImVec2(800, 600), True) + if not self._cached_ast_nodes: + imgui.text("No AST nodes found or error fetching outline.") + else: + for node in self._cached_ast_nodes: + indent = node['indent'] + kind = node['kind'] + name = node['name'] + full_path = node['full_path'] + + imgui.dummy(imgui.ImVec2(indent * 10, 0)) + imgui.same_line() + imgui.text(f"[{kind}] {name}") + imgui.same_line(imgui.get_window_width() - 200) + + current_mode = f_item.ast_mask.get(full_path, 'hide') + + imgui.push_id(full_path) + if imgui.radio_button("Def", current_mode == 'def'): + f_item.ast_mask[full_path] = 'def' + imgui.same_line() + if imgui.radio_button("Sig", current_mode == 'sig'): + f_item.ast_mask[full_path] = 'sig' + imgui.same_line() + if imgui.radio_button("Hide", current_mode == 'hide'): + f_item.ast_mask[full_path] = 'hide' + imgui.pop_id() + imgui.end_child() + + imgui.separator() + if imgui.button("Close", imgui.ImVec2(120, 0)): + self.ui_inspecting_ast_file = None + imgui.close_current_popup() + imgui.end_popup() + def _render_save_workspace_profile_modal(self) -> None: if self._show_save_workspace_profile_modal: imgui.open_popup("Save Workspace Profile") @@ -2619,6 +2705,12 @@ class App: imgui.same_line() imgui.text(f_path) + if f_path.lower().endswith(('.c', '.cpp', '.h', '.hpp', '.cxx', '.cc')): + imgui.same_line() + if imgui.button(f"[Inspect]##{i}"): + self.ui_inspecting_ast_file = f_item + imgui.open_popup('AST Inspector') + imgui.table_set_column_index(1) if hasattr(f_item, "auto_aggregate"): changed_agg, f_item.auto_aggregate = imgui.checkbox(f"Agg##cc{i}", f_item.auto_aggregate) diff --git a/src/models.py b/src/models.py index 95cef81..9a7ac0b 100644 --- a/src/models.py +++ b/src/models.py @@ -497,6 +497,7 @@ class FileItem: force_full: bool = False ast_signatures: bool = False ast_definitions: bool = False + ast_mask: dict[str, str] = field(default_factory=dict) injected_at: Optional[float] = None def to_dict(self) -> Dict[str, Any]: @@ -509,6 +510,7 @@ class FileItem: "force_full": self.force_full, "ast_signatures": self.ast_signatures, "ast_definitions": self.ast_definitions, + "ast_mask": self.ast_mask, "injected_at": self.injected_at, } @@ -523,6 +525,7 @@ class FileItem: force_full=data.get("force_full", False), ast_signatures=data.get("ast_signatures", False), ast_definitions=data.get("ast_definitions", False), + ast_mask=data.get("ast_mask", {}), injected_at=data.get("injected_at"), ) diff --git a/tests/test_file_item_model.py b/tests/test_file_item_model.py index 2e6537c..b8f089b 100644 --- a/tests/test_file_item_model.py +++ b/tests/test_file_item_model.py @@ -7,6 +7,7 @@ def test_file_item_fields(): assert item.path == "src/models.py" assert item.auto_aggregate is True assert item.force_full is False + assert item.ast_mask == {} assert item.injected_at is None def test_file_item_to_dict(): @@ -18,6 +19,7 @@ def test_file_item_to_dict(): "force_full": True, "ast_signatures": False, "ast_definitions": False, + "ast_mask": {}, "injected_at": None } assert item.to_dict() == expected @@ -35,6 +37,7 @@ def test_file_item_from_dict(): assert item.auto_aggregate is False assert item.force_full is True assert item.injected_at == 123.456 + assert item.ast_mask == {} def test_file_item_from_dict_defaults(): """Test that FileItem.from_dict handles missing fields.""" @@ -43,4 +46,5 @@ def test_file_item_from_dict_defaults(): assert item.path == "test.py" assert item.auto_aggregate is True assert item.force_full is False - assert item.injected_at is None \ No newline at end of file + assert item.ast_mask == {} + assert item.injected_at is None