feat(mma): Implement Deep AST-Driven Context Pruning and mark track complete
This commit is contained in:
@@ -9,6 +9,9 @@ from pathlib import Path
|
||||
from typing import Optional, Any, List, Tuple, Dict
|
||||
import tree_sitter
|
||||
import tree_sitter_python
|
||||
import re
|
||||
|
||||
_ast_cache: Dict[str, Tuple[float, tree_sitter.Tree]] = {}
|
||||
|
||||
class ASTParser:
|
||||
"""
|
||||
@@ -28,11 +31,38 @@ class ASTParser:
|
||||
"""Parse the given code and return the tree-sitter Tree."""
|
||||
return self.parser.parse(bytes(code, "utf8"))
|
||||
|
||||
def get_skeleton(self, code: str) -> str:
|
||||
def get_cached_tree(self, path: Optional[str], code: str) -> tree_sitter.Tree:
    """Return the parse tree for ``code``, reusing a cached tree when possible.

    The module-level ``_ast_cache`` maps ``path`` to ``(mtime, tree)``. A
    cached tree is reused only while the file's modification time on disk
    equals the one recorded when the tree was cached.

    Args:
        path: Filesystem path the code came from. When it is falsy, or the
            file cannot be stat'ed, ``code`` is parsed without touching the
            cache.
        code: Source text to parse on a cache miss.

    Returns:
        The cached tree, or a freshly parsed one.
    """
    if not path:
        return self.parse(code)

    try:
        mtime = Path(path).stat().st_mtime
    except (OSError, ValueError):
        # The modification time is the only thing a cached tree is validated
        # against. Without one (missing file, unreadable path) a cached entry
        # could be handed back for different source text, so skip the cache.
        return self.parse(code)

    cached = _ast_cache.get(path)
    if cached is not None and cached[0] == mtime:
        return cached[1]

    tree = self.parse(code)

    # Drop a stale entry for this path first: refreshing an existing key must
    # not evict an unrelated entry, and re-inserting makes it the newest one.
    _ast_cache.pop(path, None)
    while len(_ast_cache) >= 10:
        # FIFO eviction: dicts preserve insertion order, so the first key is
        # the oldest entry.
        del _ast_cache[next(iter(_ast_cache))]

    _ast_cache[path] = (mtime, tree)
    return tree
|
||||
|
||||
def get_skeleton(self, code: str, path: Optional[str] = None) -> str:
|
||||
"""
|
||||
Returns a skeleton of a Python file (preserving docstrings, stripping function bodies).
|
||||
"""
|
||||
tree = self.parse(code)
|
||||
tree = self.get_cached_tree(path, code)
|
||||
edits: List[Tuple[int, int, str]] = []
|
||||
|
||||
def is_docstring(node: tree_sitter.Node) -> bool:
|
||||
@@ -70,13 +100,13 @@ class ASTParser:
|
||||
code_bytes[start:end] = bytes(replacement, "utf8")
|
||||
return code_bytes.decode("utf8")
|
||||
|
||||
def get_curated_view(self, code: str) -> str:
|
||||
def get_curated_view(self, code: str, path: Optional[str] = None) -> str:
|
||||
"""
|
||||
Returns a curated skeleton of a Python file.
|
||||
Preserves function bodies if they have @core_logic decorator or # [HOT] comment.
|
||||
Otherwise strips bodies but preserves docstrings.
|
||||
"""
|
||||
tree = self.parse(code)
|
||||
tree = self.get_cached_tree(path, code)
|
||||
edits: List[Tuple[int, int, str]] = []
|
||||
|
||||
def is_docstring(node: tree_sitter.Node) -> bool:
|
||||
@@ -141,6 +171,167 @@ class ASTParser:
|
||||
code_bytes[start:end] = bytes(replacement, "utf8")
|
||||
return code_bytes.decode("utf8")
|
||||
|
||||
def get_targeted_view(self, code: str, function_names: List[str], path: Optional[str] = None) -> str:
    """
    Returns a targeted view of the code including only the specified functions
    and their dependencies up to depth 2.

    Selected functions (and the functions they call, followed for two levels)
    are kept with their bodies replaced by ``...``; a leading docstring is
    preserved. Every other function, class without a selected method, and
    top-level statement is removed. Imports and top-level comments are kept.

    Args:
        code: Source text of the Python file.
        function_names: Names to keep. Each entry is matched against the
            qualified name (``Class.method``) first and, failing that,
            against the bare function name.
        path: Optional file path, forwarded to ``get_cached_tree``.

    Returns:
        The pruned source with runs of blank lines collapsed, ending in a
        single newline.
    """
    tree = self.get_cached_tree(path, code)

    # tree-sitter reports byte offsets into the UTF-8 encoding; slicing the
    # str directly would pick the wrong characters after any non-ASCII text.
    source = code.encode("utf8")

    def text_of(node) -> str:
        """Return the source text covered by ``node``."""
        return source[node.start_byte:node.end_byte].decode("utf8")

    def qualified_name(node, scope) -> str:
        """Return the dotted name of a function/class node within ``scope``."""
        name_node = node.child_by_field_name("name")
        name = text_of(name_node) if name_node else ""
        return f"{scope}.{name}" if scope else name

    # Qualified name -> function node, plus bare name -> qualified names so
    # call sites (which only carry a bare name) resolve without a full scan.
    all_functions = {}
    by_short_name = {}

    def collect_functions(node, class_name=None):
        """Register every function definition below ``node``."""
        if node.type == "function_definition":
            name_node = node.child_by_field_name("name")
            if name_node:
                func_name = text_of(name_node)
                full_name = f"{class_name}.{func_name}" if class_name else func_name
                all_functions[full_name] = node
                by_short_name.setdefault(func_name, set()).add(full_name)
        elif node.type == "class_definition":
            name_node = node.child_by_field_name("name")
            if name_node:
                full_cname = qualified_name(node, class_name)
                body = node.child_by_field_name("body")
                if body:
                    collect_functions(body, full_cname)
                return
        for child in node.children:
            collect_functions(child, class_name)

    collect_functions(tree.root_node)

    def get_calls(node):
        """Return the bare names of everything called below ``node``."""
        calls = set()

        def walk_calls(n):
            if n.type == "call":
                func_node = n.child_by_field_name("function")
                if func_node:
                    if func_node.type == "identifier":
                        calls.add(text_of(func_node))
                    elif func_node.type == "attribute":
                        # obj.method(...) -> only the attribute name is used.
                        attr_node = func_node.child_by_field_name("attribute")
                        if attr_node:
                            calls.add(text_of(attr_node))
            for child in n.children:
                walk_calls(child)

        walk_calls(node)
        return calls

    # Resolve the requested names to qualified names.
    to_include = set()
    for target in function_names:
        if target in all_functions:
            to_include.add(target)
        else:
            to_include.update(by_short_name.get(target, ()))

    # Follow calls breadth-first for two levels.
    all_found = set(to_include)
    current_layer = set(to_include)
    for _ in range(2):
        next_layer = set()
        for name in current_layer:
            for call in get_calls(all_functions[name]):
                for callee in by_short_name.get(call, ()):
                    if callee not in all_found:
                        next_layer.add(callee)
                        all_found.add(callee)
        current_layer = next_layer
        if not current_layer:
            break

    edits = []

    def delete(node):
        """Schedule removal of ``node``'s whole source span."""
        edits.append((node.start_byte, node.end_byte, ""))

    def is_docstring(n) -> bool:
        """Return True if ``n`` is an expression statement starting with a string."""
        if n.type == "expression_statement" and n.child_count > 0:
            if n.children[0].type == "string":
                return True
        return False

    def check_for_targeted(node, parent_class=None) -> bool:
        """Return True if ``node`` is, or contains, a function that is kept."""
        if node.type == "function_definition":
            return qualified_name(node, parent_class) in all_found
        if node.type == "class_definition":
            full_cname = qualified_name(node, parent_class)
            body = node.child_by_field_name("body")
            if body:
                return any(check_for_targeted(child, full_cname) for child in body.children)
            return False
        return any(check_for_targeted(child, parent_class) for child in node.children)

    def strip_body(func_node):
        """Schedule replacement of a kept function's body with ``...``."""
        body = func_node.child_by_field_name("body")
        if not (body and body.type == "block"):
            return
        indent = " " * body.start_point.column
        first_stmt = next((c for c in body.children if c.type != "comment"), None)
        if first_stmt and is_docstring(first_stmt):
            # Keep the docstring, replace everything after it.
            if body.end_byte > first_stmt.end_byte:
                edits.append((first_stmt.end_byte, body.end_byte, f"\n{indent}..."))
        else:
            edits.append((body.start_byte, body.end_byte, "..."))

    def walk_edits(node, parent_class=None):
        """Walk the tree and record the edits that prune it."""
        kind = node.type

        if kind == "decorated_definition":
            # Decide on the wrapped definition, but remove decorators and
            # definition together so no decorator is left dangling and a
            # decorated top-level target is not dropped as a plain statement.
            inner = node.child_by_field_name("definition")
            if inner is not None and check_for_targeted(inner, parent_class):
                walk_edits(inner, parent_class)
            else:
                delete(node)
            return

        if kind == "function_definition":
            if qualified_name(node, parent_class) in all_found:
                strip_body(node)
            else:
                delete(node)
            return

        if kind == "class_definition":
            if check_for_targeted(node, parent_class):
                full_cname = qualified_name(node, parent_class)
                body = node.child_by_field_name("body")
                if body:
                    for child in body.children:
                        walk_edits(child, full_cname)
            else:
                delete(node)
            return

        if kind in ("import_statement", "import_from_statement"):
            return

        if kind == "module":
            for child in node.children:
                walk_edits(child, parent_class)
        elif node.parent and node.parent.type == "module":
            # Any other top-level statement is dropped; comments stay.
            if kind not in ("comment",):
                delete(node)
        else:
            for child in node.children:
                walk_edits(child, parent_class)

    walk_edits(tree.root_node)

    # Apply from the end of the file backwards so earlier offsets stay valid.
    edits.sort(key=lambda x: x[0], reverse=True)
    code_bytes = bytearray(source)
    for start, end, replacement in edits:
        code_bytes[start:end] = bytes(replacement, "utf8")

    result = code_bytes.decode("utf8")
    result = re.sub(r'\n\s*\n\s*\n+', '\n\n', result)
    return result.strip() + "\n"
|
||||
|
||||
def reset_client() -> None:
|
||||
pass
|
||||
|
||||
|
||||
Reference in New Issue
Block a user