diff --git a/src/file_cache.py b/src/file_cache.py index 5b5918a..2554610 100644 --- a/src/file_cache.py +++ b/src/file_cache.py @@ -108,27 +108,63 @@ class ASTParser: if node.type in ("function_definition", "field_declaration", "declaration"): def find_id(n: tree_sitter.Node) -> str: - # In C++, prefer function_declarator or operator_name + # In C++, prefer function_declarator or operator_name FIRST - look for it explicitly if self.language_name in ("cpp", "c"): - # Look for declarator field or child - decl = n.child_by_field_name("declarator") - if decl: - # If it's a function_declarator, the name is inside its 'declarator' or 'name' - if decl.type == "function_declarator": - nested_decl = decl.child_by_field_name("declarator") - if nested_decl: - return find_id(nested_decl) - elif decl.type == "pointer_declarator": - return find_id(decl) - elif decl.type in ("identifier", "field_identifier", "qualified_identifier", "operator_name", "destructor_name"): - return code_bytes[decl.start_byte:decl.end_byte].decode("utf8", errors="replace") + # First, check if this node itself is a simple identifier (method name) + if n.type in ("identifier", "field_identifier", "destructor_name"): + return code_bytes[n.start_byte:n.end_byte].decode("utf8", errors="replace") + # Handle reference_declarator which wraps function_declarator for return types like T& + if n.type in ("reference_declarator", "pointer_declarator"): + for child in n.children: + if child.type == "function_declarator": + nested_decl = child.child_by_field_name("declarator") + if nested_decl: + return find_id(nested_decl) + for fc in child.children: + if fc.type in ("field_identifier", "identifier", "operator_name", "destructor_name"): + return code_bytes[fc.start_byte:fc.end_byte].decode("utf8", errors="replace") if n.type in ("identifier", "field_identifier", "qualified_identifier", "operator_name", "destructor_name"): return code_bytes[n.start_byte:n.end_byte].decode("utf8", errors="replace") - + + # AVOID qualified_identifier in C++ - it's used in type expressions + if self.language_name in ("cpp", "c"): + if n.type in ("type_identifier", "primitive_type", "builtin_type", "qualified_identifier", "type_parameter", "template_type"): + return "" + + # AVOID parameter_list in C++ - it contains parameter names that can confuse identifier finding + if n.type in ("parameter_list", "parameter_declaration"): + return "" + + # For function_definition, check direct function_declarator child + if n.type == "function_definition": + for child in n.children: + if child.type == "function_declarator": + nested_decl = child.child_by_field_name("declarator") + if nested_decl: + return find_id(nested_decl) + for fc in child.children: + if fc.type in ("field_identifier", "identifier", "operator_name", "destructor_name"): + return code_bytes[fc.start_byte:fc.end_byte].decode("utf8", errors="replace") + + # For field_declarations with complex return types like T& or T* + # we need to look inside reference_declarator/pointer_declarator children + if n.type in ("field_declaration", "function_definition"): + for child in n.children: + if child.type in ("reference_declarator", "pointer_declarator"): + for subchild in child.children: + if subchild.type == "function_declarator": + nested_decl = subchild.child_by_field_name("declarator") + if nested_decl: + return find_id(nested_decl) + for fc in subchild.children: + if fc.type in ("field_identifier", "identifier", "operator_name", "destructor_name"): + return code_bytes[fc.start_byte:fc.end_byte].decode("utf8", errors="replace") + # Fallback to children, but avoid bodies and types for child in n.children: - if child.type in ("compound_statement", "field_declaration_list", "class_body", "declaration_list", "enum_body", "type_identifier", "primitive_type"): continue + if child.type in ("compound_statement", "field_declaration_list", "class_body", "declaration_list", "enum_body", "type_identifier", "primitive_type", "builtin_type", "namespace_identifier", "qualified_identifier", "reference_declarator", "pointer_declarator"): + continue res = find_id(child) if res: return res return "" @@ -473,6 +509,9 @@ class ASTParser: if node_name == target: if len(target_parts) == 1: match = check_node if child.type != "field_declaration" else child + # template_declaration should always be returned as-is (no body field but contains the definition) + if match.type == "template_declaration": + return match if match.child_by_field_name("body"): return match if not best_match: @@ -507,7 +546,7 @@ class ASTParser: def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]: best = None - if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration", "declaration"): + if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration", "declaration", "field_declaration"): if self._get_name(node, code_bytes) == target: if node.child_by_field_name("body"): return node @@ -522,7 +561,8 @@ class ASTParser: return best found_node = walk(tree.root_node, parts) - if not found_node or not found_node.child_by_field_name("body"): + # template_declaration doesn't have body field but is valid as-is + if not found_node or (not found_node.child_by_field_name("body") and found_node.type != "template_declaration"): alt = deep_search(tree.root_node, name) if alt: if not found_node or alt.child_by_field_name("body"): @@ -565,6 +605,9 @@ class ASTParser: if node_name == target: if len(target_parts) == 1: match = check_node if child.type != "field_declaration" else child + # template_declaration should always be returned as-is (no body field but contains the definition) + if match.type == "template_declaration": + return match if match.child_by_field_name("body"): return match if not best_match: @@ -614,7 +657,8 @@ class ASTParser: return best found_node = walk(tree.root_node, parts) - if not found_node or not found_node.child_by_field_name("body"): + # template_declaration doesn't have body field but is valid as-is + if not found_node or (not found_node.child_by_field_name("body") and found_node.type != "template_declaration"): alt = deep_search(tree.root_node, name) if alt: if not found_node or alt.child_by_field_name("body"): @@ -708,6 +752,9 @@ class ASTParser: if node_name == target: if len(target_parts) == 1: match = check_node if child.type != "field_declaration" else child + # template_declaration should always be returned as-is (no body field but contains the definition) + if match.type == "template_declaration": + return match if match.child_by_field_name("body"): return match if not best_match: @@ -742,7 +789,7 @@ class ASTParser: def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]: best = None - if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration", "declaration"): + if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration", "declaration", "field_declaration"): if self._get_name(node, code_bytes) == target: if node.child_by_field_name("body"): return node @@ -757,7 +804,8 @@ class ASTParser: return best found_node = walk(tree.root_node, parts) - if not found_node or not found_node.child_by_field_name("body"): + # template_declaration doesn't have body field but is valid as-is + if not found_node or (not found_node.child_by_field_name("body") and found_node.type != "template_declaration"): alt = deep_search(tree.root_node, name) if alt: if not found_node or alt.child_by_field_name("body"):