# Reconstructed from a git patch whose line breaks were lost.
#   Commit:  799feb0f946851da53fd89677e0d4438464c4c30
#   From:    Ed_    Date: Tue, 5 May 2026 19:42:14 -0400
#   Subject: [PATCH] feat(parser): Implement C/C++ get_definition and get_signature
#   Files (original patch): src/file_cache.py | 180, tests/test_ast_parser.py | 101
#   Hunk:    src/file_cache.py @@ -374,6 +374,186 @@ class ASTParser:
#
# Hunk context before the insertion (tail of the preceding method, unchanged):
#     result = re.sub(r'\n\s*\n\s*\n+', '\n\n', result)
#     return result.strip() + "\n"
# Hunk context after the insertion (unchanged, continues outside this patch):
#     def get_code_outline(self, code: str, path: Optional[str] = None) -> str:
#
# get_definition / get_signature are methods of ASTParser (they take ``self``
# and call ``self.get_cached_tree``).  The underscore-prefixed helpers below
# are stateless and are meant to live at module level.

# Node types that are treated as a nameable definition.
_DEFINITION_TYPES = frozenset({
    "function_definition",
    "class_definition",
    "class_specifier",
    "struct_specifier",
    "namespace_definition",
    "template_declaration",
})

# Node types a template_declaration may wrap.  class_specifier and
# struct_specifier are included so that templated C++ classes/structs resolve.
_TEMPLATED_TYPES = frozenset({
    "function_definition",
    "class_definition",
    "class_specifier",
    "struct_specifier",
})

# Definitions whose name is looked up among their direct children.
_NAMED_SCOPE_TYPES = frozenset({
    "struct_specifier",
    "class_specifier",
    "class_definition",
    "namespace_definition",
})
_SCOPE_NAME_TYPES = frozenset({"type_identifier", "identifier", "namespace_identifier"})
_IDENTIFIER_TYPES = frozenset({"identifier", "field_identifier"})

# Children that hold the members of a class / struct / namespace.
_SCOPE_BODY_TYPES = frozenset({"field_declaration_list", "class_body", "declaration_list"})

# Containers the scoped search looks through without consuming a name part.
_TRANSPARENT_TYPES = frozenset({
    "module",
    "translation_unit",
    "namespace_definition",
    "declaration_list",
})

# Fallback search set used by get_signature for unqualified names.
_SIGNATURE_FALLBACK_TYPES = frozenset({"function_definition", "template_declaration"})


def _node_text(src: bytes, node: "tree_sitter.Node") -> str:
    """Return the source text covered by *node*.

    tree-sitter reports ``start_byte``/``end_byte`` as byte offsets, so the
    slice is taken on the encoded source rather than on the ``str``.
    """
    return src[node.start_byte:node.end_byte].decode("utf-8", errors="replace")


def _split_qualified_name(name: str) -> Optional[List[str]]:
    """Split ``A::b`` / ``A.b`` into parts; return None if any part is empty.

    Unnamed nodes resolve to the empty string, so an empty part would
    otherwise match an arbitrary anonymous definition.
    """
    parts = [part.strip() for part in re.split(r'::|\.', name)]
    return parts if all(parts) else None


def _templated_child(node: "tree_sitter.Node") -> Optional["tree_sitter.Node"]:
    """Return the first definition directly wrapped by a template_declaration."""
    for child in node.children:
        if child.type in _TEMPLATED_TYPES:
            return child
    return None


def _definition_name(src: bytes, node: "tree_sitter.Node") -> str:
    """Return the identifier of a definition node, or "" if none is found."""
    name_node = node.child_by_field_name("name")
    if name_node is not None:
        return _node_text(src, name_node)

    if node.type == "function_definition":
        # C/C++ functions keep their name inside a chain of declarators
        # (pointer, reference, function, qualified identifier, ...).
        decl = node.child_by_field_name("declarator")
        while decl is not None:
            if decl.type in _IDENTIFIER_TYPES:
                return _node_text(src, decl)
            inner = decl.child_by_field_name("declarator")
            if inner is None and decl.child_count > 0:
                # No nested declarator field (e.g. ``Class::method``): take
                # an identifier child, else keep descending via first child.
                for child in decl.children:
                    if child.type in _IDENTIFIER_TYPES:
                        return _node_text(src, child)
                decl = decl.children[0]
            else:
                decl = inner

    if node.type == "template_declaration":
        wrapped = _templated_child(node)
        if wrapped is not None:
            return _definition_name(src, wrapped)

    if node.type in _NAMED_SCOPE_TYPES:
        for child in node.children:
            if child.type in _SCOPE_NAME_TYPES:
                return _node_text(src, child)
    return ""


def _find_scoped(
    src: bytes, node: "tree_sitter.Node", parts: List[str]
) -> Optional["tree_sitter.Node"]:
    """Resolve a qualified name (one part per scope) starting at *node*."""
    if not parts:
        return None
    target, rest = parts[0], parts[1:]

    for child in node.children:
        if child.type not in _DEFINITION_TYPES:
            continue
        if _definition_name(src, child) != target:
            continue
        if not rest:
            return child

        body = child.child_by_field_name("body")
        if body is None and child.type == "template_declaration":
            wrapped = _templated_child(child)
            if wrapped is not None:
                body = wrapped.child_by_field_name("body")
        if body is not None:
            found = _find_scoped(src, body, rest)
            if found is not None:
                return found
        for sub in child.children:
            # Skip the body that was just searched to avoid doing it twice.
            if sub.type in _SCOPE_BODY_TYPES and sub != body:
                found = _find_scoped(src, sub, rest)
                if found is not None:
                    return found

    # Look through wrappers at file / namespace level that are not
    # themselves definitions, without consuming a name part.
    if node.type in _TRANSPARENT_TYPES:
        for child in node.children:
            if child.type not in _DEFINITION_TYPES:
                found = _find_scoped(src, child, parts)
                if found is not None:
                    return found
    return None


def _find_anywhere(
    src: bytes, root: "tree_sitter.Node", target: str, node_types: frozenset
) -> Optional["tree_sitter.Node"]:
    """Return the first node (pre-order) of *node_types* named *target*.

    Iterative so that deeply nested syntax trees cannot exhaust the
    interpreter's recursion limit.
    """
    pending = [root]
    while pending:
        node = pending.pop()
        if node.type in node_types and _definition_name(src, node) == target:
            return node
        pending.extend(reversed(node.children))
    return None


def get_definition(self, code: str, name: str, path: Optional[str] = None) -> str:
    """
    Returns the full source code for a specific definition by name.
    Supports 'ClassName::method' or 'method' for C++.

    Args:
        code: Source text to search.
        name: Definition name; scopes are separated by '::' or '.'.
        path: Optional path forwarded to ``self.get_cached_tree``.

    Returns:
        The text of the matching definition (including a wrapping
        ``template<...>`` clause), or
        ``"ERROR: definition '<name>' not found"``.
    """
    parts = _split_qualified_name(name)
    if parts is None:
        return f"ERROR: definition '{name}' not found"

    tree = self.get_cached_tree(path, code)
    # NOTE(review): assumes get_cached_tree parses the UTF-8 encoding of
    # `code`, so node byte offsets index into these bytes - confirm.
    src = code.encode("utf-8")

    found_node = _find_scoped(src, tree.root_node, parts)
    if found_node is None and len(parts) == 1:
        # Unqualified name: fall back to the first match at any depth.
        found_node = _find_anywhere(src, tree.root_node, parts[0], _DEFINITION_TYPES)

    if found_node is None:
        return f"ERROR: definition '{name}' not found"
    return _node_text(src, found_node)


def get_signature(self, code: str, name: str, path: Optional[str] = None) -> str:
    """
    Returns only the signature part of a function or method.
    For C/C++, this is the code from the start of the definition until the block start '{'.

    Args:
        code: Source text to search.
        name: Definition name; scopes are separated by '::' or '.'.
        path: Optional path forwarded to ``self.get_cached_tree``.

    Returns:
        The stripped text from the start of the definition (including a
        wrapping ``template<...>`` clause) up to its body, the whole
        definition when it has no body, or
        ``"ERROR: signature for '<name>' not found"``.
    """
    parts = _split_qualified_name(name)
    if parts is None:
        return f"ERROR: signature for '{name}' not found"

    tree = self.get_cached_tree(path, code)
    # NOTE(review): assumes get_cached_tree parses the UTF-8 encoding of
    # `code`, so node byte offsets index into these bytes - confirm.
    src = code.encode("utf-8")

    found_node = _find_scoped(src, tree.root_node, parts)
    if found_node is None and len(parts) == 1:
        found_node = _find_anywhere(
            src, tree.root_node, parts[0], _SIGNATURE_FALLBACK_TYPES
        )

    if found_node is None:
        return f"ERROR: signature for '{name}' not found"

    # For templates the body belongs to the wrapped definition, while the
    # signature still starts at the `template` keyword.
    target_node = found_node
    if found_node.type == "template_declaration":
        wrapped = _templated_child(found_node)
        if wrapped is not None:
            target_node = wrapped

    body = target_node.child_by_field_name("body")
    end_byte = body.start_byte if body is not None else found_node.end_byte
    return src[found_node.start_byte:end_byte].decode("utf-8", errors="replace").strip()
# Reconstructed from the second file of the same git patch.
#   diff --git a/tests/test_ast_parser.py b/tests/test_ast_parser.py
#   index 70ceb83..c85671f 100644
#   Hunk: @@ -204,3 +204,104 @@ public:
#
# Hunk context before the insertion (tail of the preceding test, unchanged):
#     assert '[Class] MyClass (Lines 2-6)' in outline
#     assert ' [Method] myMethod (Lines 4-5)' in outline
#
# ASTParser is imported earlier in tests/test_ast_parser.py (outside this hunk).


def test_ast_parser_get_definition_c() -> None:
    """Verify get_definition for C."""
    parser = ASTParser(language="c")
    code = """
void my_func() {
    printf("hello\\n");
}

struct MyStruct {
    int x;
};
"""
    def1 = parser.get_definition(code, "my_func")
    assert 'void my_func() {' in def1
    assert 'printf("hello\\n");' in def1

    def2 = parser.get_definition(code, "MyStruct")
    assert 'struct MyStruct {' in def2
    assert 'int x;' in def2


def test_ast_parser_get_definition_cpp() -> None:
    """Verify get_definition for C++ including scoped methods."""
    parser = ASTParser(language="cpp")
    code = """
class MyClass {
public:
    void myMethod() {
        int x = 1;
    }
};

namespace MyNamespace {
    void nsFunc() {}
}
"""
    # Scoped lookup
    def1 = parser.get_definition(code, "MyClass::myMethod")
    assert 'void myMethod() {' in def1
    assert 'int x = 1;' in def1

    # Just name lookup
    def2 = parser.get_definition(code, "myMethod")
    assert 'void myMethod() {' in def2

    # Namespace lookup
    def3 = parser.get_definition(code, "MyNamespace::nsFunc")
    assert 'void nsFunc() {}' in def3


def test_ast_parser_get_definition_cpp_template() -> None:
    """Verify get_definition for C++ templates."""
    parser = ASTParser(language="cpp")
    # A template needs a parameter list to be valid C++ and to parse as a
    # template_declaration.
    code = """
template <typename T>
void myTemplateFunc(T x) {
}
"""
    def1 = parser.get_definition(code, "myTemplateFunc")
    assert 'template <typename T>' in def1
    assert 'void myTemplateFunc(T x) {' in def1


def test_ast_parser_get_signature_c() -> None:
    """Verify get_signature for C."""
    parser = ASTParser(language="c")
    code = """
void my_func(int a,
             char* b) {
    printf("hello\\n");
}
"""
    sig = parser.get_signature(code, "my_func")
    assert 'void my_func(int a,' in sig
    assert 'char* b)' in sig
    assert '{' not in sig
    assert 'printf' not in sig


def test_ast_parser_get_signature_cpp() -> None:
    """Verify get_signature for C++ templates and methods."""
    parser = ASTParser(language="cpp")
    code = """
class MyClass {
public:
    template <typename T>
    T myTemplateMethod(T x) {
        return x;
    }

    void normalMethod() {
    }
};
"""
    # Template method: the signature keeps the template clause.
    sig1 = parser.get_signature(code, "MyClass::myTemplateMethod")
    assert 'template <typename T>' in sig1
    assert 'T myTemplateMethod(T x)' in sig1
    assert '{' not in sig1

    # Normal method
    sig2 = parser.get_signature(code, "MyClass::normalMethod")
    assert 'void normalMethod()' in sig2
    assert '{' not in sig2