feat(parser): Implement C/C++ get_definition and get_signature
This commit is contained in:
@@ -374,6 +374,186 @@ class ASTParser:
|
|||||||
result = re.sub(r'\n\s*\n\s*\n+', '\n\n', result)
|
result = re.sub(r'\n\s*\n\s*\n+', '\n\n', result)
|
||||||
return result.strip() + "\n"
|
return result.strip() + "\n"
|
||||||
|
|
||||||
|
def get_definition(self, code: str, name: str, path: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
Returns the full source code for a specific definition by name.
|
||||||
|
Supports 'ClassName::method' or 'method' for C++.
|
||||||
|
"""
|
||||||
|
tree = self.get_cached_tree(path, code)
|
||||||
|
|
||||||
|
def get_name(node: tree_sitter.Node) -> str:
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node:
|
||||||
|
return code[name_node.start_byte:name_node.end_byte]
|
||||||
|
|
||||||
|
if node.type == "function_definition":
|
||||||
|
decl = node.child_by_field_name("declarator")
|
||||||
|
while decl:
|
||||||
|
if decl.type in ("identifier", "field_identifier"):
|
||||||
|
return code[decl.start_byte:decl.end_byte]
|
||||||
|
next_decl = decl.child_by_field_name("declarator")
|
||||||
|
if not next_decl and decl.child_count > 0:
|
||||||
|
for child in decl.children:
|
||||||
|
if child.type in ("identifier", "field_identifier"):
|
||||||
|
return code[child.start_byte:child.end_byte]
|
||||||
|
decl = decl.children[0]
|
||||||
|
else:
|
||||||
|
decl = next_decl
|
||||||
|
|
||||||
|
if node.type == "template_declaration":
|
||||||
|
for child in node.children:
|
||||||
|
if child.type in ("function_definition", "class_definition"):
|
||||||
|
return get_name(child)
|
||||||
|
|
||||||
|
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
|
||||||
|
for child in node.children:
|
||||||
|
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
|
||||||
|
return code[child.start_byte:child.end_byte]
|
||||||
|
return ""
|
||||||
|
|
||||||
|
parts = re.split(r'::|\.', name)
|
||||||
|
|
||||||
|
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
|
||||||
|
if not target_parts:
|
||||||
|
return None
|
||||||
|
target = target_parts[0]
|
||||||
|
for child in node.children:
|
||||||
|
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||||
|
if get_name(child) == target:
|
||||||
|
if len(target_parts) == 1:
|
||||||
|
return child
|
||||||
|
body = child.child_by_field_name("body")
|
||||||
|
if not body and child.type == "template_declaration":
|
||||||
|
for sub in child.children:
|
||||||
|
if sub.type in ("function_definition", "class_definition"):
|
||||||
|
body = sub.child_by_field_name("body")
|
||||||
|
break
|
||||||
|
if body:
|
||||||
|
found = walk(body, target_parts[1:])
|
||||||
|
if found: return found
|
||||||
|
for sub in child.children:
|
||||||
|
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||||
|
found = walk(sub, target_parts[1:])
|
||||||
|
if found: return found
|
||||||
|
# Recurse for top-level or namespaces
|
||||||
|
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
|
||||||
|
for child in node.children:
|
||||||
|
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||||
|
found = walk(child, target_parts)
|
||||||
|
if found: return found
|
||||||
|
return None
|
||||||
|
|
||||||
|
found_node = walk(tree.root_node, parts)
|
||||||
|
if not found_node and len(parts) == 1:
|
||||||
|
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||||
|
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||||
|
if get_name(node) == target:
|
||||||
|
return node
|
||||||
|
for child in node.children:
|
||||||
|
res = deep_search(child, target)
|
||||||
|
if res: return res
|
||||||
|
return None
|
||||||
|
found_node = deep_search(tree.root_node, parts[0])
|
||||||
|
|
||||||
|
if found_node:
|
||||||
|
return code[found_node.start_byte:found_node.end_byte]
|
||||||
|
return f"ERROR: definition '{name}' not found"
|
||||||
|
|
||||||
|
def get_signature(self, code: str, name: str, path: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
Returns only the signature part of a function or method.
|
||||||
|
For C/C++, this is the code from the start of the definition until the block start '{'.
|
||||||
|
"""
|
||||||
|
tree = self.get_cached_tree(path, code)
|
||||||
|
|
||||||
|
def get_name(node: tree_sitter.Node) -> str:
|
||||||
|
name_node = node.child_by_field_name("name")
|
||||||
|
if name_node:
|
||||||
|
return code[name_node.start_byte:name_node.end_byte]
|
||||||
|
|
||||||
|
if node.type == "function_definition":
|
||||||
|
decl = node.child_by_field_name("declarator")
|
||||||
|
while decl:
|
||||||
|
if decl.type in ("identifier", "field_identifier"):
|
||||||
|
return code[decl.start_byte:decl.end_byte]
|
||||||
|
next_decl = decl.child_by_field_name("declarator")
|
||||||
|
if not next_decl and decl.child_count > 0:
|
||||||
|
for child in decl.children:
|
||||||
|
if child.type in ("identifier", "field_identifier"):
|
||||||
|
return code[child.start_byte:child.end_byte]
|
||||||
|
decl = decl.children[0]
|
||||||
|
else:
|
||||||
|
decl = next_decl
|
||||||
|
|
||||||
|
if node.type == "template_declaration":
|
||||||
|
for child in node.children:
|
||||||
|
if child.type in ("function_definition", "class_definition"):
|
||||||
|
return get_name(child)
|
||||||
|
|
||||||
|
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
|
||||||
|
for child in node.children:
|
||||||
|
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
|
||||||
|
return code[child.start_byte:child.end_byte]
|
||||||
|
return ""
|
||||||
|
|
||||||
|
parts = re.split(r'::|\.', name)
|
||||||
|
|
||||||
|
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
|
||||||
|
if not target_parts:
|
||||||
|
return None
|
||||||
|
target = target_parts[0]
|
||||||
|
for child in node.children:
|
||||||
|
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||||
|
if get_name(child) == target:
|
||||||
|
if len(target_parts) == 1:
|
||||||
|
return child
|
||||||
|
body = child.child_by_field_name("body")
|
||||||
|
if not body and child.type == "template_declaration":
|
||||||
|
for sub in child.children:
|
||||||
|
if sub.type in ("function_definition", "class_definition"):
|
||||||
|
body = sub.child_by_field_name("body")
|
||||||
|
break
|
||||||
|
if body:
|
||||||
|
found = walk(body, target_parts[1:])
|
||||||
|
if found: return found
|
||||||
|
for sub in child.children:
|
||||||
|
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||||
|
found = walk(sub, target_parts[1:])
|
||||||
|
if found: return found
|
||||||
|
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
|
||||||
|
for child in node.children:
|
||||||
|
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||||
|
found = walk(child, target_parts)
|
||||||
|
if found: return found
|
||||||
|
return None
|
||||||
|
|
||||||
|
found_node = walk(tree.root_node, parts)
|
||||||
|
if not found_node and len(parts) == 1:
|
||||||
|
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||||
|
if node.type in ("function_definition", "template_declaration"):
|
||||||
|
if get_name(node) == target:
|
||||||
|
return node
|
||||||
|
for child in node.children:
|
||||||
|
res = deep_search(child, target)
|
||||||
|
if res: return res
|
||||||
|
return None
|
||||||
|
found_node = deep_search(tree.root_node, parts[0])
|
||||||
|
|
||||||
|
if found_node:
|
||||||
|
target_node = found_node
|
||||||
|
if found_node.type == "template_declaration":
|
||||||
|
for child in found_node.children:
|
||||||
|
if child.type in ("function_definition", "class_definition"):
|
||||||
|
target_node = child
|
||||||
|
break
|
||||||
|
|
||||||
|
body = target_node.child_by_field_name("body")
|
||||||
|
if body:
|
||||||
|
return code[found_node.start_byte:body.start_byte].strip()
|
||||||
|
return code[found_node.start_byte:found_node.end_byte].strip()
|
||||||
|
|
||||||
|
return f"ERROR: signature for '{name}' not found"
|
||||||
|
|
||||||
def get_code_outline(self, code: str, path: Optional[str] = None) -> str:
|
def get_code_outline(self, code: str, path: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Returns a hierarchical outline of the code (classes, structs, functions, methods).
|
Returns a hierarchical outline of the code (classes, structs, functions, methods).
|
||||||
|
|||||||
@@ -204,3 +204,104 @@ public:
|
|||||||
assert '[Class] MyClass (Lines 2-6)' in outline
|
assert '[Class] MyClass (Lines 2-6)' in outline
|
||||||
assert ' [Method] myMethod (Lines 4-5)' in outline
|
assert ' [Method] myMethod (Lines 4-5)' in outline
|
||||||
|
|
||||||
|
def test_ast_parser_get_definition_c() -> None:
|
||||||
|
"""Verify get_definition for C."""
|
||||||
|
parser = ASTParser(language="c")
|
||||||
|
code = """
|
||||||
|
void my_func() {
|
||||||
|
printf("hello\\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MyStruct {
|
||||||
|
int x;
|
||||||
|
};
|
||||||
|
"""
|
||||||
|
def1 = parser.get_definition(code, "my_func")
|
||||||
|
assert 'void my_func() {' in def1
|
||||||
|
assert 'printf("hello\\n");' in def1
|
||||||
|
|
||||||
|
def2 = parser.get_definition(code, "MyStruct")
|
||||||
|
assert 'struct MyStruct {' in def2
|
||||||
|
assert 'int x;' in def2
|
||||||
|
|
||||||
|
def test_ast_parser_get_definition_cpp() -> None:
|
||||||
|
"""Verify get_definition for C++ including scoped methods."""
|
||||||
|
parser = ASTParser(language="cpp")
|
||||||
|
code = """
|
||||||
|
class MyClass {
|
||||||
|
public:
|
||||||
|
void myMethod() {
|
||||||
|
int x = 1;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace MyNamespace {
|
||||||
|
void nsFunc() {}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Scoped lookup
|
||||||
|
def1 = parser.get_definition(code, "MyClass::myMethod")
|
||||||
|
assert 'void myMethod() {' in def1
|
||||||
|
assert 'int x = 1;' in def1
|
||||||
|
|
||||||
|
# Just name lookup
|
||||||
|
def2 = parser.get_definition(code, "myMethod")
|
||||||
|
assert 'void myMethod() {' in def2
|
||||||
|
|
||||||
|
# Namespace lookup
|
||||||
|
def3 = parser.get_definition(code, "MyNamespace::nsFunc")
|
||||||
|
assert 'void nsFunc() {}' in def3
|
||||||
|
|
||||||
|
def test_ast_parser_get_definition_cpp_template() -> None:
|
||||||
|
"""Verify get_definition for C++ templates."""
|
||||||
|
parser = ASTParser(language="cpp")
|
||||||
|
code = """
|
||||||
|
template <typename T>
|
||||||
|
void myTemplateFunc(T x) {
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
def1 = parser.get_definition(code, "myTemplateFunc")
|
||||||
|
assert 'template <typename T>' in def1
|
||||||
|
assert 'void myTemplateFunc(T x) {' in def1
|
||||||
|
|
||||||
|
def test_ast_parser_get_signature_c() -> None:
|
||||||
|
"""Verify get_signature for C."""
|
||||||
|
parser = ASTParser(language="c")
|
||||||
|
code = """
|
||||||
|
void my_func(int a,
|
||||||
|
char* b) {
|
||||||
|
printf("hello\\n");
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
sig = parser.get_signature(code, "my_func")
|
||||||
|
assert 'void my_func(int a,' in sig
|
||||||
|
assert 'char* b)' in sig
|
||||||
|
assert '{' not in sig
|
||||||
|
assert 'printf' not in sig
|
||||||
|
|
||||||
|
def test_ast_parser_get_signature_cpp() -> None:
|
||||||
|
"""Verify get_signature for C++ templates and methods."""
|
||||||
|
parser = ASTParser(language="cpp")
|
||||||
|
code = """
|
||||||
|
class MyClass {
|
||||||
|
public:
|
||||||
|
template <typename T>
|
||||||
|
T myTemplateMethod(T x) {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
void normalMethod() {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
"""
|
||||||
|
# Template method
|
||||||
|
sig1 = parser.get_signature(code, "MyClass::myTemplateMethod")
|
||||||
|
assert 'template <typename T>' in sig1
|
||||||
|
assert 'T myTemplateMethod(T x)' in sig1
|
||||||
|
assert '{' not in sig1
|
||||||
|
|
||||||
|
# Normal method
|
||||||
|
sig2 = parser.get_signature(code, "MyClass::normalMethod")
|
||||||
|
assert 'void normalMethod()' in sig2
|
||||||
|
assert '{' not in sig2
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user