feat(parser): Implement C/C++ update_definition
This commit is contained in:
@@ -614,6 +614,92 @@ class ASTParser:
|
||||
walk(tree.root_node)
|
||||
return "\n".join(output)
|
||||
|
||||
def update_definition(self, code: str, name: str, new_content: str, path: Optional[str] = None) -> str:
|
||||
"""
|
||||
Surgically replace the definition of a class or function by name.
|
||||
"""
|
||||
tree = self.get_cached_tree(path, code)
|
||||
|
||||
def get_name(node: tree_sitter.Node) -> str:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
return code[name_node.start_byte:name_node.end_byte]
|
||||
|
||||
if node.type == "function_definition":
|
||||
decl = node.child_by_field_name("declarator")
|
||||
while decl:
|
||||
if decl.type in ("identifier", "field_identifier"):
|
||||
return code[decl.start_byte:decl.end_byte]
|
||||
next_decl = decl.child_by_field_name("declarator")
|
||||
if not next_decl and decl.child_count > 0:
|
||||
for child in decl.children:
|
||||
if child.type in ("identifier", "field_identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
decl = decl.children[0]
|
||||
else:
|
||||
decl = next_decl
|
||||
|
||||
if node.type == "template_declaration":
|
||||
for child in node.children:
|
||||
if child.type in ("function_definition", "class_definition"):
|
||||
return get_name(child)
|
||||
|
||||
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
|
||||
for child in node.children:
|
||||
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
return ""
|
||||
|
||||
parts = re.split(r'::|\.', name)
|
||||
|
||||
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
|
||||
if not target_parts:
|
||||
return None
|
||||
target = target_parts[0]
|
||||
for child in node.children:
|
||||
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if get_name(child) == target:
|
||||
if len(target_parts) == 1:
|
||||
return child
|
||||
body = child.child_by_field_name("body")
|
||||
if not body and child.type == "template_declaration":
|
||||
for sub in child.children:
|
||||
if sub.type in ("function_definition", "class_definition"):
|
||||
body = sub.child_by_field_name("body")
|
||||
break
|
||||
if body:
|
||||
found = walk(body, target_parts[1:])
|
||||
if found: return found
|
||||
for sub in child.children:
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||
found = walk(sub, target_parts[1:])
|
||||
if found: return found
|
||||
# Recurse for top-level or namespaces
|
||||
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
|
||||
for child in node.children:
|
||||
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
found = walk(child, target_parts)
|
||||
if found: return found
|
||||
return None
|
||||
|
||||
found_node = walk(tree.root_node, parts)
|
||||
if not found_node and len(parts) == 1:
|
||||
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if get_name(node) == target:
|
||||
return node
|
||||
for child in node.children:
|
||||
res = deep_search(child, target)
|
||||
if res: return res
|
||||
return None
|
||||
found_node = deep_search(tree.root_node, parts[0])
|
||||
|
||||
if found_node:
|
||||
code_bytes = bytearray(code, "utf8")
|
||||
code_bytes[found_node.start_byte:found_node.end_byte] = bytes(new_content, "utf8")
|
||||
return code_bytes.decode("utf8")
|
||||
return f"ERROR: definition '{name}' not found"
|
||||
|
||||
def reset_client() -> None:
|
||||
pass
|
||||
|
||||
|
||||
@@ -305,3 +305,23 @@ public:
|
||||
assert 'void normalMethod()' in sig2
|
||||
assert '{' not in sig2
|
||||
|
||||
def test_ast_parser_update_definition_cpp() -> None:
|
||||
"""Verify update_definition for C++ including scoped methods."""
|
||||
parser = ASTParser(language="cpp")
|
||||
code = """
|
||||
class MyClass {
|
||||
public:
|
||||
void myMethod() {
|
||||
int x = 1;
|
||||
}
|
||||
};
|
||||
"""
|
||||
new_method = """ void myMethod() {
|
||||
int y = 2;
|
||||
}"""
|
||||
updated = parser.update_definition(code, "MyClass::myMethod", new_method)
|
||||
assert 'void myMethod() {' in updated
|
||||
assert 'int y = 2;' in updated
|
||||
assert 'int x = 1;' not in updated
|
||||
assert 'class MyClass {' in updated
|
||||
|
||||
|
||||
Reference in New Issue
Block a user