feat(context): Interactive AST Tree Masking with per-symbol toggles
This commit is contained in:
+27
-2
@@ -123,6 +123,7 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[
|
|||||||
force_full = entry_raw.get("force_full", False)
|
force_full = entry_raw.get("force_full", False)
|
||||||
ast_signatures = entry_raw.get("ast_signatures", False)
|
ast_signatures = entry_raw.get("ast_signatures", False)
|
||||||
ast_definitions = entry_raw.get("ast_definitions", False)
|
ast_definitions = entry_raw.get("ast_definitions", False)
|
||||||
|
ast_mask = entry_raw.get("ast_mask", {})
|
||||||
elif hasattr(entry_raw, "path"):
|
elif hasattr(entry_raw, "path"):
|
||||||
entry = entry_raw.path
|
entry = entry_raw.path
|
||||||
tier = getattr(entry_raw, "tier", None)
|
tier = getattr(entry_raw, "tier", None)
|
||||||
@@ -130,6 +131,7 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[
|
|||||||
force_full = getattr(entry_raw, "force_full", False)
|
force_full = getattr(entry_raw, "force_full", False)
|
||||||
ast_signatures = getattr(entry_raw, "ast_signatures", False)
|
ast_signatures = getattr(entry_raw, "ast_signatures", False)
|
||||||
ast_definitions = getattr(entry_raw, "ast_definitions", False)
|
ast_definitions = getattr(entry_raw, "ast_definitions", False)
|
||||||
|
ast_mask = getattr(entry_raw, "ast_mask", {})
|
||||||
else:
|
else:
|
||||||
entry = entry_raw
|
entry = entry_raw
|
||||||
tier = None
|
tier = None
|
||||||
@@ -137,11 +139,12 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[
|
|||||||
force_full = False
|
force_full = False
|
||||||
ast_signatures = False
|
ast_signatures = False
|
||||||
ast_definitions = False
|
ast_definitions = False
|
||||||
|
ast_mask = {}
|
||||||
if not entry or not isinstance(entry, str):
|
if not entry or not isinstance(entry, str):
|
||||||
continue
|
continue
|
||||||
paths = resolve_paths(base_dir, entry)
|
paths = resolve_paths(base_dir, entry)
|
||||||
if not paths:
|
if not paths:
|
||||||
items.append({"path": None, "entry": entry, "content": f"ERROR: no files matched: {entry}", "error": True, "mtime": 0.0, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full, "ast_signatures": ast_signatures, "ast_definitions": ast_definitions})
|
items.append({"path": None, "entry": entry, "content": f"ERROR: no files matched: {entry}", "error": True, "mtime": 0.0, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full, "ast_signatures": ast_signatures, "ast_definitions": ast_definitions, "ast_mask": ast_mask})
|
||||||
continue
|
continue
|
||||||
for path in paths:
|
for path in paths:
|
||||||
try:
|
try:
|
||||||
@@ -156,7 +159,7 @@ def build_file_items(base_dir: Path, files: list[str | dict[str, Any]]) -> list[
|
|||||||
content = f"ERROR: {e}"
|
content = f"ERROR: {e}"
|
||||||
mtime = 0.0
|
mtime = 0.0
|
||||||
error = True
|
error = True
|
||||||
items.append({"path": path, "entry": entry, "content": content, "error": error, "mtime": mtime, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full, "ast_signatures": ast_signatures, "ast_definitions": ast_definitions})
|
items.append({"path": path, "entry": entry, "content": content, "error": error, "mtime": mtime, "tier": tier, "auto_aggregate": auto_aggregate, "force_full": force_full, "ast_signatures": ast_signatures, "ast_definitions": ast_definitions, "ast_mask": ast_mask})
|
||||||
return items
|
return items
|
||||||
|
|
||||||
|
|
||||||
@@ -272,6 +275,7 @@ def build_tier3_context(file_items: list[dict[str, Any]], screenshot_base_dir: P
|
|||||||
force_full = item.get("force_full")
|
force_full = item.get("force_full")
|
||||||
ast_signatures = item.get("ast_signatures", False)
|
ast_signatures = item.get("ast_signatures", False)
|
||||||
ast_definitions = item.get("ast_definitions", False)
|
ast_definitions = item.get("ast_definitions", False)
|
||||||
|
ast_mask = item.get("ast_mask", {})
|
||||||
content = item.get("content", "")
|
content = item.get("content", "")
|
||||||
is_focus = entry in focus_set or (name and name in focus_set) or (path_str and path_str in focus_set)
|
is_focus = entry in focus_set or (name and name in focus_set) or (path_str and path_str in focus_set)
|
||||||
if not is_focus and path_str:
|
if not is_focus and path_str:
|
||||||
@@ -284,6 +288,27 @@ def build_tier3_context(file_items: list[dict[str, Any]], screenshot_base_dir: P
|
|||||||
suffix = path.suffix.lstrip(".") if path and path.suffix else "text"
|
suffix = path.suffix.lstrip(".") if path and path.suffix else "text"
|
||||||
sections.append(f"### `{original}`\n\n```{suffix}\n{content}\n```")
|
sections.append(f"### `{original}`\n\n```{suffix}\n{content}\n```")
|
||||||
elif path:
|
elif path:
|
||||||
|
if ast_mask and not item.get("error"):
|
||||||
|
mask_sections = []
|
||||||
|
from src import mcp_client
|
||||||
|
for symbol, mode in ast_mask.items():
|
||||||
|
if mode == "hide":
|
||||||
|
continue
|
||||||
|
res = ""
|
||||||
|
if path.suffix == ".py":
|
||||||
|
res = mcp_client.py_get_definition(str(path), symbol) if mode == "def" else mcp_client.py_get_signature(str(path), symbol)
|
||||||
|
elif path.suffix in [".c", ".h", ".cpp", ".hpp", ".cxx", ".cc"]:
|
||||||
|
is_cpp = any(ext in path.suffix for ext in [".cpp", ".hpp", ".cxx", ".cc"])
|
||||||
|
if mode == "def":
|
||||||
|
res = mcp_client.ts_cpp_get_definition(str(path), symbol) if is_cpp else mcp_client.ts_c_get_definition(str(path), symbol)
|
||||||
|
else:
|
||||||
|
res = mcp_client.ts_cpp_get_signature(str(path), symbol) if is_cpp else mcp_client.ts_c_get_signature(str(path), symbol)
|
||||||
|
if res:
|
||||||
|
mask_sections.append(res)
|
||||||
|
if mask_sections:
|
||||||
|
suffix = path.suffix.lstrip(".") if path.suffix else "text"
|
||||||
|
sections.append(f"### `{original}` (Masked)\n\n```{suffix}\n" + "\n\n".join(mask_sections) + "\n```")
|
||||||
|
continue
|
||||||
if path.suffix in ['.c', '.h', '.cpp', '.hpp', '.cxx', '.cc'] and not item.get("error"):
|
if path.suffix in ['.c', '.h', '.cpp', '.hpp', '.cxx', '.cc'] and not item.get("error"):
|
||||||
from src import mcp_client
|
from src import mcp_client
|
||||||
if ast_definitions:
|
if ast_definitions:
|
||||||
|
|||||||
@@ -239,6 +239,9 @@ class App:
|
|||||||
self.shader_uniforms = {'crt': 1.0, 'scanline': 0.5, 'bloom': 0.8}
|
self.shader_uniforms = {'crt': 1.0, 'scanline': 0.5, 'bloom': 0.8}
|
||||||
self.ui_new_context_preset_name = ""
|
self.ui_new_context_preset_name = ""
|
||||||
self._focus_md_cache: dict[str, str] = {}
|
self._focus_md_cache: dict[str, str] = {}
|
||||||
|
self.ui_inspecting_ast_file = None
|
||||||
|
self._cached_ast_nodes = []
|
||||||
|
self._cached_ast_file_path = ''
|
||||||
"""UI-level wrapper for approving a pending tool execution ask."""
|
"""UI-level wrapper for approving a pending tool execution ask."""
|
||||||
self._handle_approve_ask()
|
self._handle_approve_ask()
|
||||||
|
|
||||||
@@ -1337,6 +1340,9 @@ class App:
|
|||||||
if imgui.button("Cancel", imgui.ImVec2(120, 0)):
|
if imgui.button("Cancel", imgui.ImVec2(120, 0)):
|
||||||
imgui.close_current_popup()
|
imgui.close_current_popup()
|
||||||
imgui.end_popup()
|
imgui.end_popup()
|
||||||
|
|
||||||
|
self._render_ast_inspector_modal()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR in _gui_func: {e}")
|
print(f"ERROR in _gui_func: {e}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@@ -1469,6 +1475,86 @@ class App:
|
|||||||
imgui.close_current_popup()
|
imgui.close_current_popup()
|
||||||
imgui.end_popup()
|
imgui.end_popup()
|
||||||
|
|
||||||
|
def _render_ast_inspector_modal(self) -> None:
|
||||||
|
expanded, opened = imgui.begin_popup_modal('AST Inspector', True, imgui.WindowFlags_.always_auto_resize)
|
||||||
|
if expanded:
|
||||||
|
if self.ui_inspecting_ast_file is None:
|
||||||
|
imgui.close_current_popup()
|
||||||
|
else:
|
||||||
|
f_item = self.ui_inspecting_ast_file
|
||||||
|
f_path = f_item.path if hasattr(f_item, "path") else str(f_item)
|
||||||
|
|
||||||
|
if f_path != self._cached_ast_file_path:
|
||||||
|
outline = ""
|
||||||
|
try:
|
||||||
|
if f_path.lower().endswith('.py'):
|
||||||
|
outline = mcp_client.py_get_code_outline(f_path)
|
||||||
|
elif f_path.lower().endswith(('.c', '.h')):
|
||||||
|
outline = mcp_client.ts_c_get_code_outline(f_path)
|
||||||
|
else:
|
||||||
|
outline = mcp_client.ts_cpp_get_code_outline(f_path)
|
||||||
|
except Exception as e:
|
||||||
|
outline = f"Error fetching outline: {e}"
|
||||||
|
|
||||||
|
self._cached_ast_nodes = []
|
||||||
|
import re
|
||||||
|
pattern = re.compile(r'^(\s*)\[(.*?)\] (.*?) \(Lines \d+-\d+\)')
|
||||||
|
stack = [] # (indent, name)
|
||||||
|
for line in outline.splitlines():
|
||||||
|
m = pattern.match(line)
|
||||||
|
if m:
|
||||||
|
indent_str, kind, name = m.groups()
|
||||||
|
indent = len(indent_str)
|
||||||
|
while stack and stack[-1][0] >= indent:
|
||||||
|
stack.pop()
|
||||||
|
stack.append((indent, name))
|
||||||
|
full_path = '::'.join([s[1] for s in stack])
|
||||||
|
self._cached_ast_nodes.append({
|
||||||
|
'indent': indent,
|
||||||
|
'kind': kind,
|
||||||
|
'name': name,
|
||||||
|
'full_path': full_path
|
||||||
|
})
|
||||||
|
self._cached_ast_file_path = f_path
|
||||||
|
|
||||||
|
imgui.text(f"Inspecting AST: {f_path}")
|
||||||
|
imgui.separator()
|
||||||
|
|
||||||
|
imgui.begin_child("ast_tree_scroll", imgui.ImVec2(800, 600), True)
|
||||||
|
if not self._cached_ast_nodes:
|
||||||
|
imgui.text("No AST nodes found or error fetching outline.")
|
||||||
|
else:
|
||||||
|
for node in self._cached_ast_nodes:
|
||||||
|
indent = node['indent']
|
||||||
|
kind = node['kind']
|
||||||
|
name = node['name']
|
||||||
|
full_path = node['full_path']
|
||||||
|
|
||||||
|
imgui.dummy(imgui.ImVec2(indent * 10, 0))
|
||||||
|
imgui.same_line()
|
||||||
|
imgui.text(f"[{kind}] {name}")
|
||||||
|
imgui.same_line(imgui.get_window_width() - 200)
|
||||||
|
|
||||||
|
current_mode = f_item.ast_mask.get(full_path, 'hide')
|
||||||
|
|
||||||
|
imgui.push_id(full_path)
|
||||||
|
if imgui.radio_button("Def", current_mode == 'def'):
|
||||||
|
f_item.ast_mask[full_path] = 'def'
|
||||||
|
imgui.same_line()
|
||||||
|
if imgui.radio_button("Sig", current_mode == 'sig'):
|
||||||
|
f_item.ast_mask[full_path] = 'sig'
|
||||||
|
imgui.same_line()
|
||||||
|
if imgui.radio_button("Hide", current_mode == 'hide'):
|
||||||
|
f_item.ast_mask[full_path] = 'hide'
|
||||||
|
imgui.pop_id()
|
||||||
|
imgui.end_child()
|
||||||
|
|
||||||
|
imgui.separator()
|
||||||
|
if imgui.button("Close", imgui.ImVec2(120, 0)):
|
||||||
|
self.ui_inspecting_ast_file = None
|
||||||
|
imgui.close_current_popup()
|
||||||
|
imgui.end_popup()
|
||||||
|
|
||||||
def _render_save_workspace_profile_modal(self) -> None:
|
def _render_save_workspace_profile_modal(self) -> None:
|
||||||
if self._show_save_workspace_profile_modal:
|
if self._show_save_workspace_profile_modal:
|
||||||
imgui.open_popup("Save Workspace Profile")
|
imgui.open_popup("Save Workspace Profile")
|
||||||
@@ -2619,6 +2705,12 @@ class App:
|
|||||||
imgui.same_line()
|
imgui.same_line()
|
||||||
imgui.text(f_path)
|
imgui.text(f_path)
|
||||||
|
|
||||||
|
if f_path.lower().endswith(('.c', '.cpp', '.h', '.hpp', '.cxx', '.cc')):
|
||||||
|
imgui.same_line()
|
||||||
|
if imgui.button(f"[Inspect]##{i}"):
|
||||||
|
self.ui_inspecting_ast_file = f_item
|
||||||
|
imgui.open_popup('AST Inspector')
|
||||||
|
|
||||||
imgui.table_set_column_index(1)
|
imgui.table_set_column_index(1)
|
||||||
if hasattr(f_item, "auto_aggregate"):
|
if hasattr(f_item, "auto_aggregate"):
|
||||||
changed_agg, f_item.auto_aggregate = imgui.checkbox(f"Agg##cc{i}", f_item.auto_aggregate)
|
changed_agg, f_item.auto_aggregate = imgui.checkbox(f"Agg##cc{i}", f_item.auto_aggregate)
|
||||||
|
|||||||
@@ -497,6 +497,7 @@ class FileItem:
|
|||||||
force_full: bool = False
|
force_full: bool = False
|
||||||
ast_signatures: bool = False
|
ast_signatures: bool = False
|
||||||
ast_definitions: bool = False
|
ast_definitions: bool = False
|
||||||
|
ast_mask: dict[str, str] = field(default_factory=dict)
|
||||||
injected_at: Optional[float] = None
|
injected_at: Optional[float] = None
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
@@ -509,6 +510,7 @@ class FileItem:
|
|||||||
"force_full": self.force_full,
|
"force_full": self.force_full,
|
||||||
"ast_signatures": self.ast_signatures,
|
"ast_signatures": self.ast_signatures,
|
||||||
"ast_definitions": self.ast_definitions,
|
"ast_definitions": self.ast_definitions,
|
||||||
|
"ast_mask": self.ast_mask,
|
||||||
"injected_at": self.injected_at,
|
"injected_at": self.injected_at,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -523,6 +525,7 @@ class FileItem:
|
|||||||
force_full=data.get("force_full", False),
|
force_full=data.get("force_full", False),
|
||||||
ast_signatures=data.get("ast_signatures", False),
|
ast_signatures=data.get("ast_signatures", False),
|
||||||
ast_definitions=data.get("ast_definitions", False),
|
ast_definitions=data.get("ast_definitions", False),
|
||||||
|
ast_mask=data.get("ast_mask", {}),
|
||||||
injected_at=data.get("injected_at"),
|
injected_at=data.get("injected_at"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ def test_file_item_fields():
|
|||||||
assert item.path == "src/models.py"
|
assert item.path == "src/models.py"
|
||||||
assert item.auto_aggregate is True
|
assert item.auto_aggregate is True
|
||||||
assert item.force_full is False
|
assert item.force_full is False
|
||||||
|
assert item.ast_mask == {}
|
||||||
assert item.injected_at is None
|
assert item.injected_at is None
|
||||||
|
|
||||||
def test_file_item_to_dict():
|
def test_file_item_to_dict():
|
||||||
@@ -18,6 +19,7 @@ def test_file_item_to_dict():
|
|||||||
"force_full": True,
|
"force_full": True,
|
||||||
"ast_signatures": False,
|
"ast_signatures": False,
|
||||||
"ast_definitions": False,
|
"ast_definitions": False,
|
||||||
|
"ast_mask": {},
|
||||||
"injected_at": None
|
"injected_at": None
|
||||||
}
|
}
|
||||||
assert item.to_dict() == expected
|
assert item.to_dict() == expected
|
||||||
@@ -35,6 +37,7 @@ def test_file_item_from_dict():
|
|||||||
assert item.auto_aggregate is False
|
assert item.auto_aggregate is False
|
||||||
assert item.force_full is True
|
assert item.force_full is True
|
||||||
assert item.injected_at == 123.456
|
assert item.injected_at == 123.456
|
||||||
|
assert item.ast_mask == {}
|
||||||
|
|
||||||
def test_file_item_from_dict_defaults():
|
def test_file_item_from_dict_defaults():
|
||||||
"""Test that FileItem.from_dict handles missing fields."""
|
"""Test that FileItem.from_dict handles missing fields."""
|
||||||
@@ -43,4 +46,5 @@ def test_file_item_from_dict_defaults():
|
|||||||
assert item.path == "test.py"
|
assert item.path == "test.py"
|
||||||
assert item.auto_aggregate is True
|
assert item.auto_aggregate is True
|
||||||
assert item.force_full is False
|
assert item.force_full is False
|
||||||
assert item.injected_at is None
|
assert item.ast_mask == {}
|
||||||
|
assert item.injected_at is None
|
||||||
|
|||||||
Reference in New Issue
Block a user