feat(parser): Implement C/C++ get_definition and get_signature

This commit is contained in:
2026-05-05 19:42:14 -04:00
parent b8460107b9
commit 799feb0f94
2 changed files with 281 additions and 0 deletions
+180
View File
@@ -374,6 +374,186 @@ class ASTParser:
result = re.sub(r'\n\s*\n\s*\n+', '\n\n', result)
return result.strip() + "\n"
def get_definition(self, code: str, name: str, path: Optional[str] = None) -> str:
"""
Returns the full source code for a specific definition by name.
Supports 'ClassName::method' or 'method' for C++.
"""
tree = self.get_cached_tree(path, code)
def get_name(node: tree_sitter.Node) -> str:
name_node = node.child_by_field_name("name")
if name_node:
return code[name_node.start_byte:name_node.end_byte]
if node.type == "function_definition":
decl = node.child_by_field_name("declarator")
while decl:
if decl.type in ("identifier", "field_identifier"):
return code[decl.start_byte:decl.end_byte]
next_decl = decl.child_by_field_name("declarator")
if not next_decl and decl.child_count > 0:
for child in decl.children:
if child.type in ("identifier", "field_identifier"):
return code[child.start_byte:child.end_byte]
decl = decl.children[0]
else:
decl = next_decl
if node.type == "template_declaration":
for child in node.children:
if child.type in ("function_definition", "class_definition"):
return get_name(child)
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
for child in node.children:
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
return code[child.start_byte:child.end_byte]
return ""
parts = re.split(r'::|\.', name)
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
if not target_parts:
return None
target = target_parts[0]
for child in node.children:
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if get_name(child) == target:
if len(target_parts) == 1:
return child
body = child.child_by_field_name("body")
if not body and child.type == "template_declaration":
for sub in child.children:
if sub.type in ("function_definition", "class_definition"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, target_parts[1:])
if found: return found
for sub in child.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
found = walk(sub, target_parts[1:])
if found: return found
# Recurse for top-level or namespaces
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
for child in node.children:
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
found = walk(child, target_parts)
if found: return found
return None
found_node = walk(tree.root_node, parts)
if not found_node and len(parts) == 1:
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if get_name(node) == target:
return node
for child in node.children:
res = deep_search(child, target)
if res: return res
return None
found_node = deep_search(tree.root_node, parts[0])
if found_node:
return code[found_node.start_byte:found_node.end_byte]
return f"ERROR: definition '{name}' not found"
def get_signature(self, code: str, name: str, path: Optional[str] = None) -> str:
"""
Returns only the signature part of a function or method.
For C/C++, this is the code from the start of the definition until the block start '{'.
"""
tree = self.get_cached_tree(path, code)
def get_name(node: tree_sitter.Node) -> str:
name_node = node.child_by_field_name("name")
if name_node:
return code[name_node.start_byte:name_node.end_byte]
if node.type == "function_definition":
decl = node.child_by_field_name("declarator")
while decl:
if decl.type in ("identifier", "field_identifier"):
return code[decl.start_byte:decl.end_byte]
next_decl = decl.child_by_field_name("declarator")
if not next_decl and decl.child_count > 0:
for child in decl.children:
if child.type in ("identifier", "field_identifier"):
return code[child.start_byte:child.end_byte]
decl = decl.children[0]
else:
decl = next_decl
if node.type == "template_declaration":
for child in node.children:
if child.type in ("function_definition", "class_definition"):
return get_name(child)
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
for child in node.children:
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
return code[child.start_byte:child.end_byte]
return ""
parts = re.split(r'::|\.', name)
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
if not target_parts:
return None
target = target_parts[0]
for child in node.children:
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if get_name(child) == target:
if len(target_parts) == 1:
return child
body = child.child_by_field_name("body")
if not body and child.type == "template_declaration":
for sub in child.children:
if sub.type in ("function_definition", "class_definition"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, target_parts[1:])
if found: return found
for sub in child.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
found = walk(sub, target_parts[1:])
if found: return found
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
for child in node.children:
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
found = walk(child, target_parts)
if found: return found
return None
found_node = walk(tree.root_node, parts)
if not found_node and len(parts) == 1:
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "template_declaration"):
if get_name(node) == target:
return node
for child in node.children:
res = deep_search(child, target)
if res: return res
return None
found_node = deep_search(tree.root_node, parts[0])
if found_node:
target_node = found_node
if found_node.type == "template_declaration":
for child in found_node.children:
if child.type in ("function_definition", "class_definition"):
target_node = child
break
body = target_node.child_by_field_name("body")
if body:
return code[found_node.start_byte:body.start_byte].strip()
return code[found_node.start_byte:found_node.end_byte].strip()
return f"ERROR: signature for '{name}' not found"
def get_code_outline(self, code: str, path: Optional[str] = None) -> str:
"""
Returns a hierarchical outline of the code (classes, structs, functions, methods).