feat(mma): Implement ASTParser in file_cache.py and refactor mcp_client.py
This commit is contained in:
144
file_cache.py
144
file_cache.py
@@ -7,6 +7,150 @@ This file is kept so that any stale imports do not break.
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
import tree_sitter
|
||||||
|
import tree_sitter_python
|
||||||
|
|
||||||
|
|
||||||
|
class ASTParser:
    """
    Parser for extracting AST-based views of source code.

    Currently supports Python only.

    tree-sitter reports node positions as byte offsets into the UTF-8
    encoding of the source, so every slice in this class is taken from the
    encoded bytes and never from the ``str`` (the two diverge as soon as the
    file contains a non-ASCII character).
    """

    def __init__(self, language: str):
        """
        Create a parser for ``language``.

        Raises:
            ValueError: if ``language`` is anything other than ``"python"``.
        """
        if language != "python":
            raise ValueError(f"Language '{language}' not supported yet.")
        self.language_name = language
        # Load the tree-sitter language grammar
        self.language = tree_sitter.Language(tree_sitter_python.language())
        self.parser = tree_sitter.Parser(self.language)

    def parse(self, code: str) -> "tree_sitter.Tree":
        """Parse the given code and return the tree-sitter Tree."""
        return self.parser.parse(bytes(code, "utf8"))

    def get_skeleton(self, code: str) -> str:
        """
        Return a skeleton of a Python file.

        Every function body is replaced with ``...``; a leading docstring is
        kept and only the statements after it are replaced. Functions nested
        inside a stripped body disappear together with that body.
        """
        source = code.encode("utf8")
        tree = self.parse(code)
        edits = self._collect_edits(tree.root_node, source)
        return self._apply_edits(source, edits)

    def get_curated_view(self, code: str) -> str:
        """
        Return a curated skeleton of a Python file.

        Preserves function bodies if they have a ``@core_logic`` decorator or
        contain a ``# [HOT]`` comment. Otherwise strips bodies but preserves
        docstrings. Functions nested inside a stripped body are removed with
        it, whatever their own markers say.
        """
        source = code.encode("utf8")
        tree = self.parse(code)

        def keep_body(func_node) -> bool:
            return (
                self._has_core_logic_decorator(func_node, source)
                or self._has_hot_comment(func_node, source)
            )

        edits = self._collect_edits(tree.root_node, source, keep_body)
        return self._apply_edits(source, edits)

    @staticmethod
    def _node_text(node, source: bytes) -> str:
        """Return the source text covered by ``node`` (byte-accurate)."""
        return source[node.start_byte:node.end_byte].decode("utf8", errors="replace")

    @staticmethod
    def _is_docstring(node) -> bool:
        """Return True if ``node`` is an expression statement starting with a string."""
        return (
            node.type == "expression_statement"
            and node.child_count > 0
            and node.children[0].type == "string"
        )

    @staticmethod
    def _body_indent(body, source: bytes) -> str:
        """Return the indentation to put in front of a replacement ``...``."""
        line_start = source.rfind(b"\n", 0, body.start_byte) + 1
        prefix = source[line_start:body.start_byte]
        if not prefix.strip():
            # Reuse the file's own indentation so tabs stay tabs.
            return prefix.decode("utf8")
        # Body shares a line with other code (e.g. ``def f(): return 1``).
        return " " * body.start_point.column

    @classmethod
    def _body_edit(cls, body, source: bytes):
        """
        Build the ``(start_byte, end_byte, replacement)`` edit that strips ``body``.

        Returns ``None`` when the body consists of a docstring only and there
        is nothing left to strip.
        """
        first_stmt = next(
            (child for child in body.children if child.type != "comment"), None
        )
        if first_stmt is not None and cls._is_docstring(first_stmt):
            if body.end_byte > first_stmt.end_byte:
                indent = cls._body_indent(body, source)
                return (first_stmt.end_byte, body.end_byte, f"\n{indent}...")
            return None
        return (body.start_byte, body.end_byte, "...")

    @classmethod
    def _collect_edits(cls, root, source: bytes, keep_body=None) -> list:
        """
        Walk the tree below ``root`` and collect body-stripping edits.

        ``keep_body`` is an optional predicate taking a ``function_definition``
        node; when it returns True the function body is left untouched (its
        nested functions are still examined individually).
        """
        edits = []
        # Explicit stack instead of recursion: deeply nested sources must not
        # run into the interpreter's recursion limit.
        stack = [root]
        while stack:
            node = stack.pop()
            if node.type == "function_definition":
                body = node.child_by_field_name("body")
                if body is not None and body.type == "block":
                    if keep_body is None or not keep_body(node):
                        edit = cls._body_edit(body, source)
                        if edit is not None:
                            edits.append(edit)
                        # The body is discarded, nested defs included. Walking
                        # into it would create edits that overlap this one and
                        # invalidate its byte offsets.
                        continue
            stack.extend(node.children)
        return edits

    @classmethod
    def _has_core_logic_decorator(cls, func_node, source: bytes) -> bool:
        """Return True if ``func_node`` carries a decorator containing ``@core_logic``."""
        parent = func_node.parent
        if parent is None or parent.type != "decorated_definition":
            return False
        # decorator -> ( '@', identifier ) or ( '@', call )
        return any(
            child.type == "decorator"
            and "@core_logic" in cls._node_text(child, source)
            for child in parent.children
        )

    @classmethod
    def _has_hot_comment(cls, func_node, source: bytes) -> bool:
        """Return True if any comment below ``func_node`` contains ``[HOT]``."""
        stack = [func_node]
        while stack:
            current = stack.pop()
            if current.type == "comment" and "[HOT]" in cls._node_text(current, source):
                return True
            stack.extend(current.children)
        return False

    @staticmethod
    def _apply_edits(source: bytes, edits) -> str:
        """
        Apply ``(start_byte, end_byte, replacement)`` edits to ``source``.

        Offsets refer to the original ``source``. The output is assembled
        front to back, so no offset ever needs adjusting. An edit that starts
        inside a span already replaced is ignored.
        """
        pieces = []
        cursor = 0
        for start, end, replacement in sorted(edits, key=lambda edit: edit[0]):
            if start < cursor:
                continue
            pieces.append(source[cursor:start])
            pieces.append(replacement.encode("utf8"))
            cursor = end
        pieces.append(source[cursor:])
        return b"".join(pieces).decode("utf8")
|
||||||
|
|
||||||
|
|
||||||
def reset_client():
|
def reset_client():
|
||||||
|
|||||||
@@ -242,57 +242,10 @@ def get_python_skeleton(path: str) -> str:
|
|||||||
return f"ERROR: not a python file: {path}"
|
return f"ERROR: not a python file: {path}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use mma_exec's generator if possible, or a local simplified version
|
from file_cache import ASTParser
|
||||||
# For now, we will use a dedicated script or just inline logic here.
|
|
||||||
# Given we have tree-sitter already installed in the env...
|
|
||||||
import tree_sitter
|
|
||||||
import tree_sitter_python
|
|
||||||
|
|
||||||
code = p.read_text(encoding="utf-8")
|
code = p.read_text(encoding="utf-8")
|
||||||
PY_LANGUAGE = tree_sitter.Language(tree_sitter_python.language())
|
parser = ASTParser("python")
|
||||||
parser = tree_sitter.Parser(PY_LANGUAGE)
|
return parser.get_skeleton(code)
|
||||||
tree = parser.parse(bytes(code, "utf8"))
|
|
||||||
|
|
||||||
edits = []
|
|
||||||
|
|
||||||
def is_docstring(node):
|
|
||||||
if node.type == "expression_statement" and node.child_count > 0:
|
|
||||||
if node.children[0].type == "string":
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def walk(node):
|
|
||||||
if node.type == "function_definition":
|
|
||||||
body = node.child_by_field_name("body")
|
|
||||||
if body and body.type == "block":
|
|
||||||
indent = " " * body.start_point.column
|
|
||||||
first_stmt = None
|
|
||||||
for child in body.children:
|
|
||||||
if child.type != "comment":
|
|
||||||
first_stmt = child
|
|
||||||
break
|
|
||||||
|
|
||||||
if first_stmt and is_docstring(first_stmt):
|
|
||||||
start_byte = first_stmt.end_byte
|
|
||||||
end_byte = body.end_byte
|
|
||||||
if end_byte > start_byte:
|
|
||||||
edits.append((start_byte, end_byte, f"\\n{indent}..."))
|
|
||||||
else:
|
|
||||||
start_byte = body.start_byte
|
|
||||||
end_byte = body.end_byte
|
|
||||||
edits.append((start_byte, end_byte, "..."))
|
|
||||||
|
|
||||||
for child in node.children:
|
|
||||||
walk(child)
|
|
||||||
|
|
||||||
walk(tree.root_node)
|
|
||||||
|
|
||||||
edits.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
code_bytes = bytearray(code, "utf8")
|
|
||||||
for start, end, replacement in edits:
|
|
||||||
code_bytes[start:end] = bytes(replacement, "utf8")
|
|
||||||
|
|
||||||
return code_bytes.decode("utf8")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"ERROR generating skeleton for '{path}': {e}"
|
return f"ERROR generating skeleton for '{path}': {e}"
|
||||||
|
|
||||||
|
|||||||
105
tests/test_ast_parser.py
Normal file
105
tests/test_ast_parser.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
import pytest
|
||||||
|
import tree_sitter
|
||||||
|
from file_cache import ASTParser
|
||||||
|
|
||||||
|
def test_ast_parser_initialization():
    """Constructing ASTParser with "python" records that language name."""
    assert ASTParser("python").language_name == "python"
|
||||||
|
|
||||||
|
def test_ast_parser_parse():
    """parse() should hand back a tree_sitter.Tree rooted at a module node."""
    source = """def example_func():
    return 42"""
    parsed = ASTParser("python").parse(source)
    assert isinstance(parsed, tree_sitter.Tree)
    # Sanity check: the root of any parsed Python file is a module node.
    root_type = parsed.root_node.type
    assert root_type == "module"
|
||||||
|
|
||||||
|
def test_ast_parser_get_skeleton_python():
    """get_skeleton keeps signatures and docstrings and swaps bodies for '...'."""
    source = '''
def complex_function(a, b):
    """
    This is a docstring.
    It should be preserved.
    """
    result = a + b
    if result > 0:
        return result
    return 0

class MyClass:
    def method_without_docstring(self):
        print("doing something")
        return None
'''
    skeleton = ASTParser("python").get_skeleton(source)

    # Signatures, the docstring text and the '...' placeholder must survive.
    expected_fragments = (
        "def complex_function(a, b):",
        "class MyClass:",
        "def method_without_docstring(self):",
        '"""',
        "This is a docstring.",
        "It should be preserved.",
        "...",
    )
    for fragment in expected_fragments:
        assert fragment in skeleton

    # Statements from the original bodies must be gone.
    removed_fragments = (
        "result = a + b",
        "return result",
        'print("doing something")',
    )
    for fragment in removed_fragments:
        assert fragment not in skeleton
|
||||||
|
|
||||||
|
def test_ast_parser_invalid_language():
    """Verify that an unsupported language is rejected with ValueError."""
    # ASTParser.__init__ raises ValueError naming the rejected language; pin
    # that exact exception type so an unrelated failure (e.g. an ImportError
    # or AttributeError) cannot make this test pass by accident.
    with pytest.raises(ValueError, match="not-a-language"):
        ASTParser("not-a-language")
|
||||||
|
|
||||||
|
def test_ast_parser_get_curated_view():
    """get_curated_view keeps bodies marked with @core_logic or # [HOT] and strips the rest."""
    source = '''
@core_logic
def core_func():
    """Core logic doc."""
    print("this should be preserved")
    return True

def hot_func():
    # [HOT]
    print("this should also be preserved")
    return 42

def normal_func():
    """Normal doc."""
    print("this should be stripped")
    return None

class MyClass:
    @core_logic
    def core_method(self, x):
        print("method preserved", x)
'''
    curated = ASTParser("python").get_curated_view(source)

    kept_fragments = (
        # core_func: decorated with @core_logic, body stays
        'print("this should be preserved")',
        'return True',
        # hot_func: carries a [HOT] comment, body stays
        '# [HOT]',
        'print("this should also be preserved")',
        # normal_func: only the docstring and the placeholder remain
        '"""Normal doc."""',
        '...',
        # MyClass.core_method: decorated method, body stays
        'print("method preserved", x)',
    )
    for fragment in kept_fragments:
        assert fragment in curated

    # The unmarked function's statements must have been removed.
    assert 'print("this should be stripped")' not in curated
|
||||||
Reference in New Issue
Block a user