From 7a609cae6977dccb4b0f789673edf352d1bb1307 Mon Sep 17 00:00:00 2001 From: Ed_ Date: Thu, 26 Feb 2026 19:47:33 -0500 Subject: [PATCH] feat(mma): Implement ASTParser in file_cache.py and refactor mcp_client.py --- file_cache.py | 144 +++++++++++++++++++++++++++++++++++++++ mcp_client.py | 53 +------------- tests/test_ast_parser.py | 105 ++++++++++++++++++++++++++++ 3 files changed, 252 insertions(+), 50 deletions(-) create mode 100644 tests/test_ast_parser.py diff --git a/file_cache.py b/file_cache.py index 4927b46..c862358 100644 --- a/file_cache.py +++ b/file_cache.py @@ -7,6 +7,150 @@ This file is kept so that any stale imports do not break. from pathlib import Path from typing import Optional +import tree_sitter +import tree_sitter_python + + +class ASTParser: + """ + Parser for extracting AST-based views of source code. + Currently supports Python. + """ + def __init__(self, language: str): + if language != "python": + raise ValueError(f"Language '{language}' not supported yet.") + self.language_name = language + # Load the tree-sitter language grammar + self.language = tree_sitter.Language(tree_sitter_python.language()) + self.parser = tree_sitter.Parser(self.language) + + def parse(self, code: str) -> tree_sitter.Tree: + """Parse the given code and return the tree-sitter Tree.""" + return self.parser.parse(bytes(code, "utf8")) + + def get_skeleton(self, code: str) -> str: + """ + Returns a skeleton of a Python file (preserving docstrings, stripping function bodies). + """ + tree = self.parse(code) + edits = [] + + def is_docstring(node): + if node.type == "expression_statement" and node.child_count > 0: + if node.children[0].type == "string": + return True + return False + + def walk(node): + if node.type == "function_definition": + body = node.child_by_field_name("body") + if body and body.type == "block": + indent = " " * body.start_point.column + first_stmt = None + for child in body.children: + if child.type != "comment": + first_stmt = child + break + + if first_stmt and is_docstring(first_stmt): + start_byte = first_stmt.end_byte + end_byte = body.end_byte + if end_byte > start_byte: + edits.append((start_byte, end_byte, f"\n{indent}...")) + else: + start_byte = body.start_byte + end_byte = body.end_byte + edits.append((start_byte, end_byte, "...")) + + for child in node.children: + walk(child) + + walk(tree.root_node) + + # Apply edits in reverse to maintain byte offsets + edits.sort(key=lambda x: x[0], reverse=True) + code_bytes = bytearray(code, "utf8") + for start, end, replacement in edits: + code_bytes[start:end] = bytes(replacement, "utf8") + + return code_bytes.decode("utf8") + + def get_curated_view(self, code: str) -> str: + """ + Returns a curated skeleton of a Python file. + Preserves function bodies if they have @core_logic decorator or # [HOT] comment. + Otherwise strips bodies but preserves docstrings. + """ + tree = self.parse(code) + edits = [] + + def is_docstring(node): + if node.type == "expression_statement" and node.child_count > 0: + if node.children[0].type == "string": + return True + return False + + def has_core_logic_decorator(node): + # Check if parent is decorated_definition + parent = node.parent + if parent and parent.type == "decorated_definition": + for child in parent.children: + if child.type == "decorator": + # decorator -> ( '@', identifier ) or ( '@', call ) + if "@core_logic" in code[child.start_byte:child.end_byte]: + return True + return False + + def has_hot_comment(func_node): + # Check all descendants of the function_definition for a [HOT] comment + stack = [func_node] + while stack: + curr = stack.pop() + if curr.type == "comment": + comment_text = code[curr.start_byte:curr.end_byte] + if "[HOT]" in comment_text: + return True + for child in curr.children: + stack.append(child) + return False + + def walk(node): + if node.type == "function_definition": + body = node.child_by_field_name("body") + if body and body.type == "block": + # Check if we should preserve it + preserve = has_core_logic_decorator(node) or has_hot_comment(node) + + if not preserve: + indent = " " * body.start_point.column + first_stmt = None + for child in body.children: + if child.type != "comment": + first_stmt = child + break + + if first_stmt and is_docstring(first_stmt): + start_byte = first_stmt.end_byte + end_byte = body.end_byte + if end_byte > start_byte: + edits.append((start_byte, end_byte, f"\n{indent}...")) + else: + start_byte = body.start_byte + end_byte = body.end_byte + edits.append((start_byte, end_byte, "...")) + + for child in node.children: + walk(child) + + walk(tree.root_node) + + # Apply edits in reverse to maintain byte offsets + edits.sort(key=lambda x: x[0], reverse=True) + code_bytes = bytearray(code, "utf8") + for start, end, replacement in edits: + code_bytes[start:end] = bytes(replacement, "utf8") + + return code_bytes.decode("utf8") def reset_client(): diff --git a/mcp_client.py b/mcp_client.py index 5a06214..61284e7 100644 --- a/mcp_client.py +++ b/mcp_client.py @@ -242,57 +242,10 @@ def get_python_skeleton(path: str) -> str: return f"ERROR: not a python file: {path}" try: - # Use mma_exec's generator if possible, or a local simplified version - # For now, we will use a dedicated script or just inline logic here. - # Given we have tree-sitter already installed in the env... - import tree_sitter - import tree_sitter_python - + from file_cache import ASTParser code = p.read_text(encoding="utf-8") - PY_LANGUAGE = tree_sitter.Language(tree_sitter_python.language()) - parser = tree_sitter.Parser(PY_LANGUAGE) - tree = parser.parse(bytes(code, "utf8")) - - edits = [] - - def is_docstring(node): - if node.type == "expression_statement" and node.child_count > 0: - if node.children[0].type == "string": - return True - return False - - def walk(node): - if node.type == "function_definition": - body = node.child_by_field_name("body") - if body and body.type == "block": - indent = " " * body.start_point.column - first_stmt = None - for child in body.children: - if child.type != "comment": - first_stmt = child - break - - if first_stmt and is_docstring(first_stmt): - start_byte = first_stmt.end_byte - end_byte = body.end_byte - if end_byte > start_byte: - edits.append((start_byte, end_byte, f"\\n{indent}...")) - else: - start_byte = body.start_byte - end_byte = body.end_byte - edits.append((start_byte, end_byte, "...")) - - for child in node.children: - walk(child) - - walk(tree.root_node) - - edits.sort(key=lambda x: x[0], reverse=True) - code_bytes = bytearray(code, "utf8") - for start, end, replacement in edits: - code_bytes[start:end] = bytes(replacement, "utf8") - - return code_bytes.decode("utf8") + parser = ASTParser("python") + return parser.get_skeleton(code) except Exception as e: return f"ERROR generating skeleton for '{path}': {e}" diff --git a/tests/test_ast_parser.py b/tests/test_ast_parser.py new file mode 100644 index 0000000..d70ad8c --- /dev/null +++ b/tests/test_ast_parser.py @@ -0,0 +1,105 @@ +import pytest +import tree_sitter +from file_cache import ASTParser + +def test_ast_parser_initialization(): + """Verify that ASTParser can be initialized with a language string.""" + parser = ASTParser("python") + assert parser.language_name == "python" + +def test_ast_parser_parse(): + """Verify that the parse method returns a tree_sitter.Tree.""" + parser = ASTParser("python") + code = """def example_func(): + return 42""" + tree = parser.parse(code) + assert isinstance(tree, tree_sitter.Tree) + # Basic check that it parsed something + assert tree.root_node.type == "module" + +def test_ast_parser_get_skeleton_python(): + """Verify that get_skeleton replaces function bodies with '...' while preserving docstrings.""" + parser = ASTParser("python") + code = ''' +def complex_function(a, b): + """ + This is a docstring. + It should be preserved. + """ + result = a + b + if result > 0: + return result + return 0 + +class MyClass: + def method_without_docstring(self): + print("doing something") + return None +''' + skeleton = parser.get_skeleton(code) + + # Check that signatures are preserved + assert "def complex_function(a, b):" in skeleton + assert "class MyClass:" in skeleton + assert "def method_without_docstring(self):" in skeleton + + # Check that docstring is preserved + assert '"""' in skeleton + assert "This is a docstring." in skeleton + assert "It should be preserved." in skeleton + + # Check that bodies are replaced with '...' + assert "..." in skeleton + assert "result = a + b" not in skeleton + assert "return result" not in skeleton + assert 'print("doing something")' not in skeleton + +def test_ast_parser_invalid_language(): + """Verify handling of unsupported or invalid languages.""" + # This might raise an error or return a default, depending on implementation + # For now, we expect it to either fail gracefully or raise an exception we can catch + with pytest.raises(Exception): + ASTParser("not-a-language") + +def test_ast_parser_get_curated_view(): + """Verify that get_curated_view preserves function bodies with @core_logic or # [HOT].""" + parser = ASTParser("python") + code = ''' +@core_logic +def core_func(): + """Core logic doc.""" + print("this should be preserved") + return True + +def hot_func(): + # [HOT] + print("this should also be preserved") + return 42 + +def normal_func(): + """Normal doc.""" + print("this should be stripped") + return None + +class MyClass: + @core_logic + def core_method(self, x): + print("method preserved", x) +''' + curated = parser.get_curated_view(code) + + # Check that core_func is preserved + assert 'print("this should be preserved")' in curated + assert 'return True' in curated + + # Check that hot_func is preserved + assert '# [HOT]' in curated + assert 'print("this should also be preserved")' in curated + + # Check that normal_func is stripped but docstring is preserved + assert '"""Normal doc."""' in curated + assert 'print("this should be stripped")' not in curated + assert '...' in curated + + # Check that core_method is preserved + assert 'print("method preserved", x)' in curated