From 8642d894dfcce391f060ac7bdb37f745caad7fdc Mon Sep 17 00:00:00 2001 From: Ed_ Date: Tue, 5 May 2026 19:44:40 -0400 Subject: [PATCH] feat(parser): Implement C/C++ update_definition --- src/file_cache.py | 86 ++++++++++++++++++++++++++++++++++++++++ tests/test_ast_parser.py | 20 ++++++++++ 2 files changed, 106 insertions(+) diff --git a/src/file_cache.py b/src/file_cache.py index 924fefa..43b6d2b 100644 --- a/src/file_cache.py +++ b/src/file_cache.py @@ -614,6 +614,92 @@ class ASTParser: walk(tree.root_node) return "\n".join(output) + def update_definition(self, code: str, name: str, new_content: str, path: Optional[str] = None) -> str: + """ + Surgically replace the definition of a class or function by name. + """ + tree = self.get_cached_tree(path, code) + + def get_name(node: tree_sitter.Node) -> str: + name_node = node.child_by_field_name("name") + if name_node: + return code[name_node.start_byte:name_node.end_byte] + + if node.type == "function_definition": + decl = node.child_by_field_name("declarator") + while decl: + if decl.type in ("identifier", "field_identifier"): + return code[decl.start_byte:decl.end_byte] + next_decl = decl.child_by_field_name("declarator") + if not next_decl and decl.child_count > 0: + for child in decl.children: + if child.type in ("identifier", "field_identifier"): + return code[child.start_byte:child.end_byte] + decl = decl.children[0] + else: + decl = next_decl + + if node.type == "template_declaration": + for child in node.children: + if child.type in ("function_definition", "class_definition"): + return get_name(child) + + if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"): + for child in node.children: + if child.type in ("type_identifier", "identifier", "namespace_identifier"): + return code[child.start_byte:child.end_byte] + return "" + + parts = re.split(r'::|\.', name) + + def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]: + if not target_parts: + return None + target = target_parts[0] + for child in node.children: + if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"): + if get_name(child) == target: + if len(target_parts) == 1: + return child + body = child.child_by_field_name("body") + if not body and child.type == "template_declaration": + for sub in child.children: + if sub.type in ("function_definition", "class_definition"): + body = sub.child_by_field_name("body") + break + if body: + found = walk(body, target_parts[1:]) + if found: return found + for sub in child.children: + if sub.type in ("field_declaration_list", "class_body", "declaration_list"): + found = walk(sub, target_parts[1:]) + if found: return found + # Recurse for top-level or namespaces + if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"): + for child in node.children: + if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"): + found = walk(child, target_parts) + if found: return found + return None + + found_node = walk(tree.root_node, parts) + if not found_node and len(parts) == 1: + def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]: + if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"): + if get_name(node) == target: + return node + for child in node.children: + res = deep_search(child, target) + if res: return res + return None + found_node = deep_search(tree.root_node, parts[0]) + + if found_node: + code_bytes = bytearray(code, "utf8") + code_bytes[found_node.start_byte:found_node.end_byte] = bytes(new_content, "utf8") + return code_bytes.decode("utf8") + return f"ERROR: definition '{name}' not found" + def reset_client() -> None: pass diff --git a/tests/test_ast_parser.py b/tests/test_ast_parser.py index c85671f..7f8af49 100644 --- a/tests/test_ast_parser.py +++ b/tests/test_ast_parser.py @@ -305,3 +305,23 @@ public: assert 'void normalMethod()' in sig2 assert '{' not in sig2 +def test_ast_parser_update_definition_cpp() -> None: + """Verify update_definition for C++ including scoped methods.""" + parser = ASTParser(language="cpp") + code = """ +class MyClass { +public: + void myMethod() { + int x = 1; + } +}; +""" + new_method = """ void myMethod() { + int y = 2; + }""" + updated = parser.update_definition(code, "MyClass::myMethod", new_method) + assert 'void myMethod() {' in updated + assert 'int y = 2;' in updated + assert 'int x = 1;' not in updated + assert 'class MyClass {' in updated +