feat(mcp): Finalize C/C++ AST tools with robust testing and bug fixes
This commit is contained in:
+164
-209
@@ -94,6 +94,39 @@ class ASTParser:
|
||||
_ast_cache[path] = (mtime, tree)
|
||||
return tree
|
||||
|
||||
def _get_name(self, node: tree_sitter.Node, code: str) -> str:
    """Return the declared name of ``node``, or ``""`` when none is found.

    Lookup order:

    1. The node's ``name`` field, when it has one.
    2. ``function_definition`` / ``field_declaration``: the first
       identifier-like descendant, trying ``declarator`` fields before the
       remaining children and never descending into a
       ``compound_statement`` (so identifiers in a body are ignored).
    3. ``template_declaration``: the name of the first wrapped function,
       class, struct or field declaration.
    4. Class / struct / namespace nodes: the first direct child that is a
       type, plain or namespace identifier.

    Args:
        node: The syntax node to name.
        code: Source text the tree was built from. Node offsets are byte
            offsets; this assumes the tree was parsed from the UTF-8
            encoding of ``code`` -- TODO confirm against the parse call.

    Returns:
        The name as it appears in the source, or an empty string.
    """
    id_types = ("identifier", "field_identifier", "qualified_identifier", "destructor_name")
    templated_types = (
        "function_definition",
        "class_definition",
        "class_specifier",
        "struct_specifier",
        "field_declaration",
    )
    scope_types = ("struct_specifier", "class_specifier", "class_definition", "namespace_definition")
    scope_name_types = ("type_identifier", "identifier", "namespace_identifier")

    # start_byte/end_byte index bytes, not characters. For pure-ASCII text the
    # two coincide, so the str can be sliced directly; otherwise slice the
    # encoded form (built lazily, at most once per call).
    ascii_only = code.isascii()
    encoded = None

    def text_of(n: tree_sitter.Node) -> str:
        nonlocal encoded
        if ascii_only:
            return code[n.start_byte:n.end_byte]
        if encoded is None:
            encoded = code.encode("utf-8")
        return encoded[n.start_byte:n.end_byte].decode("utf-8", errors="replace")

    def find_identifier(n: tree_sitter.Node) -> str:
        if n.type in id_types:
            return text_of(n)
        # Prefer the 'declarator' field: it leads to the declared name rather
        # than to identifiers that merely appear in the return type.
        declarator = n.child_by_field_name("declarator")
        if declarator is not None:
            found = find_identifier(declarator)
            if found:
                return found
        # Fall back to every child, but never look inside a body.
        for child in n.children:
            if child.type == "compound_statement":
                continue
            found = find_identifier(child)
            if found:
                return found
        return ""

    def resolve(n: tree_sitter.Node) -> str:
        name_node = n.child_by_field_name("name")
        if name_node is not None:
            return text_of(name_node)

        if n.type in ("function_definition", "field_declaration"):
            return find_identifier(n)

        if n.type == "template_declaration":
            # The template wrapper has no name of its own; use the wrapped
            # declaration's name.
            for child in n.children:
                if child.type in templated_types:
                    return resolve(child)

        if n.type in scope_types:
            for child in n.children:
                if child.type in scope_name_types:
                    return text_of(child)
        return ""

    return resolve(node)
|
||||
|
||||
def get_skeleton(self, code: str, path: Optional[str] = None) -> str:
|
||||
"""
|
||||
Returns a skeleton of a Python file (preserving docstrings, stripping function bodies).
|
||||
@@ -381,36 +414,6 @@ class ASTParser:
|
||||
"""
|
||||
tree = self.get_cached_tree(path, code)
|
||||
|
||||
def get_name(node: tree_sitter.Node) -> str:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
return code[name_node.start_byte:name_node.end_byte]
|
||||
|
||||
if node.type == "function_definition":
|
||||
decl = node.child_by_field_name("declarator")
|
||||
while decl:
|
||||
if decl.type in ("identifier", "field_identifier"):
|
||||
return code[decl.start_byte:decl.end_byte]
|
||||
next_decl = decl.child_by_field_name("declarator")
|
||||
if not next_decl and decl.child_count > 0:
|
||||
for child in decl.children:
|
||||
if child.type in ("identifier", "field_identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
decl = decl.children[0]
|
||||
else:
|
||||
decl = next_decl
|
||||
|
||||
if node.type == "template_declaration":
|
||||
for child in node.children:
|
||||
if child.type in ("function_definition", "class_definition"):
|
||||
return get_name(child)
|
||||
|
||||
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
|
||||
for child in node.children:
|
||||
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
return ""
|
||||
|
||||
parts = re.split(r'::|\.', name)
|
||||
|
||||
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
|
||||
@@ -418,42 +421,54 @@ class ASTParser:
|
||||
return None
|
||||
target = target_parts[0]
|
||||
for child in node.children:
|
||||
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if get_name(child) == target:
|
||||
# If it's a field_declaration, it might wrap a class/struct/enum definition
|
||||
check_node = child
|
||||
if child.type == "field_declaration":
|
||||
for sub in child.children:
|
||||
if sub.type in ("class_specifier", "struct_specifier", "enum_specifier"):
|
||||
check_node = sub
|
||||
break
|
||||
|
||||
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
|
||||
if is_interesting:
|
||||
node_name = self._get_name(check_node, code)
|
||||
if node_name == target:
|
||||
if len(target_parts) == 1:
|
||||
return child
|
||||
body = child.child_by_field_name("body")
|
||||
if not body and child.type == "template_declaration":
|
||||
for sub in child.children:
|
||||
if sub.type in ("function_definition", "class_definition"):
|
||||
body = sub.child_by_field_name("body")
|
||||
break
|
||||
if body:
|
||||
found = walk(body, target_parts[1:])
|
||||
if found: return found
|
||||
for sub in child.children:
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||
found = walk(sub, target_parts[1:])
|
||||
if found: return found
|
||||
# Recurse for top-level or namespaces
|
||||
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
|
||||
for child in node.children:
|
||||
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
found = walk(child, target_parts)
|
||||
return check_node if child.type != "field_declaration" else child
|
||||
next_parts = target_parts[1:]
|
||||
else:
|
||||
next_parts = target_parts
|
||||
|
||||
body = check_node.child_by_field_name("body")
|
||||
if not body and check_node.type == "template_declaration":
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
|
||||
body = sub.child_by_field_name("body")
|
||||
break
|
||||
if body:
|
||||
found = walk(body, next_parts)
|
||||
if found: return found
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||
found = walk(sub, next_parts)
|
||||
if found: return found
|
||||
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
|
||||
found = walk(child, target_parts)
|
||||
if found: return found
|
||||
return None
|
||||
|
||||
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if self._get_name(node, code) == target:
|
||||
return node
|
||||
for child in node.children:
|
||||
res = deep_search(child, target)
|
||||
if res: return res
|
||||
return None
|
||||
|
||||
found_node = walk(tree.root_node, parts)
|
||||
if not found_node and len(parts) == 1:
|
||||
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if get_name(node) == target:
|
||||
return node
|
||||
for child in node.children:
|
||||
res = deep_search(child, target)
|
||||
if res: return res
|
||||
return None
|
||||
found_node = deep_search(tree.root_node, parts[0])
|
||||
if not found_node:
|
||||
found_node = deep_search(tree.root_node, name)
|
||||
|
||||
if found_node:
|
||||
return code[found_node.start_byte:found_node.end_byte]
|
||||
@@ -466,36 +481,6 @@ class ASTParser:
|
||||
"""
|
||||
tree = self.get_cached_tree(path, code)
|
||||
|
||||
def get_name(node: tree_sitter.Node) -> str:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
return code[name_node.start_byte:name_node.end_byte]
|
||||
|
||||
if node.type == "function_definition":
|
||||
decl = node.child_by_field_name("declarator")
|
||||
while decl:
|
||||
if decl.type in ("identifier", "field_identifier"):
|
||||
return code[decl.start_byte:decl.end_byte]
|
||||
next_decl = decl.child_by_field_name("declarator")
|
||||
if not next_decl and decl.child_count > 0:
|
||||
for child in decl.children:
|
||||
if child.type in ("identifier", "field_identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
decl = decl.children[0]
|
||||
else:
|
||||
decl = next_decl
|
||||
|
||||
if node.type == "template_declaration":
|
||||
for child in node.children:
|
||||
if child.type in ("function_definition", "class_definition"):
|
||||
return get_name(child)
|
||||
|
||||
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
|
||||
for child in node.children:
|
||||
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
return ""
|
||||
|
||||
parts = re.split(r'::|\.', name)
|
||||
|
||||
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
|
||||
@@ -503,47 +488,60 @@ class ASTParser:
|
||||
return None
|
||||
target = target_parts[0]
|
||||
for child in node.children:
|
||||
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if get_name(child) == target:
|
||||
# If it's a field_declaration, it might wrap a class/struct/enum definition
|
||||
check_node = child
|
||||
if child.type == "field_declaration":
|
||||
for sub in child.children:
|
||||
if sub.type in ("class_specifier", "struct_specifier", "enum_specifier"):
|
||||
check_node = sub
|
||||
break
|
||||
|
||||
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
|
||||
if is_interesting:
|
||||
node_name = self._get_name(check_node, code)
|
||||
if node_name == target:
|
||||
if len(target_parts) == 1:
|
||||
return child
|
||||
body = child.child_by_field_name("body")
|
||||
if not body and child.type == "template_declaration":
|
||||
for sub in child.children:
|
||||
if sub.type in ("function_definition", "class_definition"):
|
||||
body = sub.child_by_field_name("body")
|
||||
break
|
||||
if body:
|
||||
found = walk(body, target_parts[1:])
|
||||
if found: return found
|
||||
for sub in child.children:
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||
found = walk(sub, target_parts[1:])
|
||||
if found: return found
|
||||
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
|
||||
for child in node.children:
|
||||
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
found = walk(child, target_parts)
|
||||
return check_node if child.type != "field_declaration" else child
|
||||
next_parts = target_parts[1:]
|
||||
else:
|
||||
next_parts = target_parts
|
||||
|
||||
body = check_node.child_by_field_name("body")
|
||||
if not body and check_node.type == "template_declaration":
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
|
||||
body = sub.child_by_field_name("body")
|
||||
break
|
||||
if body:
|
||||
found = walk(body, next_parts)
|
||||
if found: return found
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||
found = walk(sub, next_parts)
|
||||
if found: return found
|
||||
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
|
||||
found = walk(child, target_parts)
|
||||
if found: return found
|
||||
return None
|
||||
|
||||
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||
if node.type in ("function_definition", "template_declaration"):
|
||||
if self._get_name(node, code) == target:
|
||||
return node
|
||||
for child in node.children:
|
||||
res = deep_search(child, target)
|
||||
if res: return res
|
||||
return None
|
||||
|
||||
found_node = walk(tree.root_node, parts)
|
||||
if not found_node and len(parts) == 1:
|
||||
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||
if node.type in ("function_definition", "template_declaration"):
|
||||
if get_name(node) == target:
|
||||
return node
|
||||
for child in node.children:
|
||||
res = deep_search(child, target)
|
||||
if res: return res
|
||||
return None
|
||||
found_node = deep_search(tree.root_node, parts[0])
|
||||
if not found_node:
|
||||
found_node = deep_search(tree.root_node, name)
|
||||
|
||||
if found_node:
|
||||
target_node = found_node
|
||||
if found_node.type == "template_declaration":
|
||||
for child in found_node.children:
|
||||
if child.type in ("function_definition", "class_definition"):
|
||||
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
|
||||
target_node = child
|
||||
break
|
||||
|
||||
@@ -561,31 +559,6 @@ class ASTParser:
|
||||
tree = self.get_cached_tree(path, code)
|
||||
output = []
|
||||
|
||||
def get_name(node: tree_sitter.Node) -> str:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
return code[name_node.start_byte:name_node.end_byte]
|
||||
|
||||
if node.type == "function_definition":
|
||||
decl = node.child_by_field_name("declarator")
|
||||
while decl:
|
||||
if decl.type in ("identifier", "field_identifier"):
|
||||
return code[decl.start_byte:decl.end_byte]
|
||||
next_decl = decl.child_by_field_name("declarator")
|
||||
if not next_decl and decl.child_count > 0:
|
||||
for child in decl.children:
|
||||
if child.type in ("identifier", "field_identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
decl = decl.children[0]
|
||||
else:
|
||||
decl = next_decl
|
||||
|
||||
if node.type in ("struct_specifier", "class_specifier"):
|
||||
for child in node.children:
|
||||
if child.type in ("type_identifier", "identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
return ""
|
||||
|
||||
def walk(node: tree_sitter.Node, indent: int = 0) -> None:
|
||||
ntype = node.type
|
||||
label = ""
|
||||
@@ -597,7 +570,7 @@ class ASTParser:
|
||||
label = "[Method]" if indent > 0 else "[Func]"
|
||||
|
||||
if label:
|
||||
name = get_name(node)
|
||||
name = self._get_name(node, code)
|
||||
if name:
|
||||
start = node.start_point.row + 1
|
||||
end = node.end_point.row + 1
|
||||
@@ -620,36 +593,6 @@ class ASTParser:
|
||||
"""
|
||||
tree = self.get_cached_tree(path, code)
|
||||
|
||||
def get_name(node: tree_sitter.Node) -> str:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
return code[name_node.start_byte:name_node.end_byte]
|
||||
|
||||
if node.type == "function_definition":
|
||||
decl = node.child_by_field_name("declarator")
|
||||
while decl:
|
||||
if decl.type in ("identifier", "field_identifier"):
|
||||
return code[decl.start_byte:decl.end_byte]
|
||||
next_decl = decl.child_by_field_name("declarator")
|
||||
if not next_decl and decl.child_count > 0:
|
||||
for child in decl.children:
|
||||
if child.type in ("identifier", "field_identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
decl = decl.children[0]
|
||||
else:
|
||||
decl = next_decl
|
||||
|
||||
if node.type == "template_declaration":
|
||||
for child in node.children:
|
||||
if child.type in ("function_definition", "class_definition"):
|
||||
return get_name(child)
|
||||
|
||||
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
|
||||
for child in node.children:
|
||||
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
return ""
|
||||
|
||||
parts = re.split(r'::|\.', name)
|
||||
|
||||
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
|
||||
@@ -657,42 +600,54 @@ class ASTParser:
|
||||
return None
|
||||
target = target_parts[0]
|
||||
for child in node.children:
|
||||
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if get_name(child) == target:
|
||||
# If it's a field_declaration, it might wrap a class/struct/enum definition
|
||||
check_node = child
|
||||
if child.type == "field_declaration":
|
||||
for sub in child.children:
|
||||
if sub.type in ("class_specifier", "struct_specifier", "enum_specifier"):
|
||||
check_node = sub
|
||||
break
|
||||
|
||||
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
|
||||
if is_interesting:
|
||||
node_name = self._get_name(check_node, code)
|
||||
if node_name == target:
|
||||
if len(target_parts) == 1:
|
||||
return child
|
||||
body = child.child_by_field_name("body")
|
||||
if not body and child.type == "template_declaration":
|
||||
for sub in child.children:
|
||||
if sub.type in ("function_definition", "class_definition"):
|
||||
body = sub.child_by_field_name("body")
|
||||
break
|
||||
if body:
|
||||
found = walk(body, target_parts[1:])
|
||||
if found: return found
|
||||
for sub in child.children:
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||
found = walk(sub, target_parts[1:])
|
||||
if found: return found
|
||||
# Recurse for top-level or namespaces
|
||||
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
|
||||
for child in node.children:
|
||||
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
found = walk(child, target_parts)
|
||||
return check_node if child.type != "field_declaration" else child
|
||||
next_parts = target_parts[1:]
|
||||
else:
|
||||
next_parts = target_parts
|
||||
|
||||
body = check_node.child_by_field_name("body")
|
||||
if not body and check_node.type == "template_declaration":
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
|
||||
body = sub.child_by_field_name("body")
|
||||
break
|
||||
if body:
|
||||
found = walk(body, next_parts)
|
||||
if found: return found
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||
found = walk(sub, next_parts)
|
||||
if found: return found
|
||||
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
|
||||
found = walk(child, target_parts)
|
||||
if found: return found
|
||||
return None
|
||||
|
||||
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if self._get_name(node, code) == target:
|
||||
return node
|
||||
for child in node.children:
|
||||
res = deep_search(child, target)
|
||||
if res: return res
|
||||
return None
|
||||
|
||||
found_node = walk(tree.root_node, parts)
|
||||
if not found_node and len(parts) == 1:
|
||||
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if get_name(node) == target:
|
||||
return node
|
||||
for child in node.children:
|
||||
res = deep_search(child, target)
|
||||
if res: return res
|
||||
return None
|
||||
found_node = deep_search(tree.root_node, parts[0])
|
||||
if not found_node:
|
||||
found_node = deep_search(tree.root_node, name)
|
||||
|
||||
if found_node:
|
||||
code_bytes = bytearray(code, "utf8")
|
||||
|
||||
Reference in New Issue
Block a user