feat(mcp): Finalize C/C++ AST tools with robust testing and bug fixes

This commit is contained in:
2026-05-05 20:08:51 -04:00
parent 584e8e526e
commit 992e206769
10 changed files with 468 additions and 230 deletions
+164 -209
View File
@@ -94,6 +94,39 @@ class ASTParser:
_ast_cache[path] = (mtime, tree)
return tree
def _get_name(self, node: tree_sitter.Node, code: str) -> str:
    """Return the declared name of an AST node, or "" if none is found.

    Lookup order:
      1. The node's ``name`` field, when the grammar provides one.
      2. For ``function_definition`` / ``field_declaration``: the first
         identifier-like node reached by following ``declarator`` fields,
         then by scanning the remaining children (function bodies skipped).
      3. For ``template_declaration``: the name of the first wrapped
         function/class/struct/field declaration.
      4. For struct/class/namespace nodes: the first identifier child.

    Args:
        node: The tree-sitter node to name.
        code: The source text the tree was parsed from.

    Returns:
        The name as text, or an empty string when nothing matches.
    """
    def text_of(n: tree_sitter.Node) -> str:
        # tree-sitter reports start_byte/end_byte as offsets into the encoded
        # source, not into the str. Indexing the str directly is only valid
        # for pure-ASCII input; otherwise slice the bytes and decode.
        # NOTE(review): assumes the tree was parsed from code.encode("utf8").
        if code.isascii():
            return code[n.start_byte:n.end_byte]
        raw = code.encode("utf8")[n.start_byte:n.end_byte]
        return raw.decode("utf8", errors="replace")

    name_node = node.child_by_field_name("name")
    if name_node:
        return text_of(name_node)

    if node.type in ("function_definition", "field_declaration"):
        def find_id(n: tree_sitter.Node) -> str:
            if n.type in ("identifier", "field_identifier", "qualified_identifier", "destructor_name"):
                return text_of(n)
            # Try field name 'declarator' first so the declared name wins
            # over identifiers that appear elsewhere in the declaration.
            d = n.child_by_field_name("declarator")
            if d:
                res = find_id(d)
                if res:
                    return res
            # Fallback to all children
            for child in n.children:
                if child.type == "compound_statement":
                    continue  # Don't look in body
                res = find_id(child)
                if res:
                    return res
            return ""
        return find_id(node)

    if node.type == "template_declaration":
        for child in node.children:
            if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "field_declaration"):
                return self._get_name(child, code)

    if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
        for child in node.children:
            if child.type in ("type_identifier", "identifier", "namespace_identifier"):
                return text_of(child)

    return ""
def get_skeleton(self, code: str, path: Optional[str] = None) -> str:
"""
Returns a skeleton of a Python file (preserving docstrings, stripping function bodies).
@@ -381,36 +414,6 @@ class ASTParser:
"""
tree = self.get_cached_tree(path, code)
def get_name(node: tree_sitter.Node) -> str:
name_node = node.child_by_field_name("name")
if name_node:
return code[name_node.start_byte:name_node.end_byte]
if node.type == "function_definition":
decl = node.child_by_field_name("declarator")
while decl:
if decl.type in ("identifier", "field_identifier"):
return code[decl.start_byte:decl.end_byte]
next_decl = decl.child_by_field_name("declarator")
if not next_decl and decl.child_count > 0:
for child in decl.children:
if child.type in ("identifier", "field_identifier"):
return code[child.start_byte:child.end_byte]
decl = decl.children[0]
else:
decl = next_decl
if node.type == "template_declaration":
for child in node.children:
if child.type in ("function_definition", "class_definition"):
return get_name(child)
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
for child in node.children:
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
return code[child.start_byte:child.end_byte]
return ""
parts = re.split(r'::|\.', name)
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
@@ -418,42 +421,54 @@ class ASTParser:
return None
target = target_parts[0]
for child in node.children:
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if get_name(child) == target:
# If it's a field_declaration, it might wrap a class/struct/enum definition
check_node = child
if child.type == "field_declaration":
for sub in child.children:
if sub.type in ("class_specifier", "struct_specifier", "enum_specifier"):
check_node = sub
break
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
if is_interesting:
node_name = self._get_name(check_node, code)
if node_name == target:
if len(target_parts) == 1:
return child
body = child.child_by_field_name("body")
if not body and child.type == "template_declaration":
for sub in child.children:
if sub.type in ("function_definition", "class_definition"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, target_parts[1:])
if found: return found
for sub in child.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
found = walk(sub, target_parts[1:])
if found: return found
# Recurse for top-level or namespaces
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
for child in node.children:
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
found = walk(child, target_parts)
return check_node if child.type != "field_declaration" else child
next_parts = target_parts[1:]
else:
next_parts = target_parts
body = check_node.child_by_field_name("body")
if not body and check_node.type == "template_declaration":
for sub in check_node.children:
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, next_parts)
if found: return found
for sub in check_node.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
found = walk(sub, next_parts)
if found: return found
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
found = walk(child, target_parts)
if found: return found
return None
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if self._get_name(node, code) == target:
return node
for child in node.children:
res = deep_search(child, target)
if res: return res
return None
found_node = walk(tree.root_node, parts)
if not found_node and len(parts) == 1:
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if get_name(node) == target:
return node
for child in node.children:
res = deep_search(child, target)
if res: return res
return None
found_node = deep_search(tree.root_node, parts[0])
if not found_node:
found_node = deep_search(tree.root_node, name)
if found_node:
return code[found_node.start_byte:found_node.end_byte]
@@ -466,36 +481,6 @@ class ASTParser:
"""
tree = self.get_cached_tree(path, code)
def get_name(node: tree_sitter.Node) -> str:
name_node = node.child_by_field_name("name")
if name_node:
return code[name_node.start_byte:name_node.end_byte]
if node.type == "function_definition":
decl = node.child_by_field_name("declarator")
while decl:
if decl.type in ("identifier", "field_identifier"):
return code[decl.start_byte:decl.end_byte]
next_decl = decl.child_by_field_name("declarator")
if not next_decl and decl.child_count > 0:
for child in decl.children:
if child.type in ("identifier", "field_identifier"):
return code[child.start_byte:child.end_byte]
decl = decl.children[0]
else:
decl = next_decl
if node.type == "template_declaration":
for child in node.children:
if child.type in ("function_definition", "class_definition"):
return get_name(child)
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
for child in node.children:
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
return code[child.start_byte:child.end_byte]
return ""
parts = re.split(r'::|\.', name)
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
@@ -503,47 +488,60 @@ class ASTParser:
return None
target = target_parts[0]
for child in node.children:
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if get_name(child) == target:
# If it's a field_declaration, it might wrap a class/struct/enum definition
check_node = child
if child.type == "field_declaration":
for sub in child.children:
if sub.type in ("class_specifier", "struct_specifier", "enum_specifier"):
check_node = sub
break
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
if is_interesting:
node_name = self._get_name(check_node, code)
if node_name == target:
if len(target_parts) == 1:
return child
body = child.child_by_field_name("body")
if not body and child.type == "template_declaration":
for sub in child.children:
if sub.type in ("function_definition", "class_definition"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, target_parts[1:])
if found: return found
for sub in child.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
found = walk(sub, target_parts[1:])
if found: return found
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
for child in node.children:
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
found = walk(child, target_parts)
return check_node if child.type != "field_declaration" else child
next_parts = target_parts[1:]
else:
next_parts = target_parts
body = check_node.child_by_field_name("body")
if not body and check_node.type == "template_declaration":
for sub in check_node.children:
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, next_parts)
if found: return found
for sub in check_node.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
found = walk(sub, next_parts)
if found: return found
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
found = walk(child, target_parts)
if found: return found
return None
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "template_declaration"):
if self._get_name(node, code) == target:
return node
for child in node.children:
res = deep_search(child, target)
if res: return res
return None
found_node = walk(tree.root_node, parts)
if not found_node and len(parts) == 1:
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "template_declaration"):
if get_name(node) == target:
return node
for child in node.children:
res = deep_search(child, target)
if res: return res
return None
found_node = deep_search(tree.root_node, parts[0])
if not found_node:
found_node = deep_search(tree.root_node, name)
if found_node:
target_node = found_node
if found_node.type == "template_declaration":
for child in found_node.children:
if child.type in ("function_definition", "class_definition"):
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
target_node = child
break
@@ -561,31 +559,6 @@ class ASTParser:
tree = self.get_cached_tree(path, code)
output = []
def get_name(node: tree_sitter.Node) -> str:
name_node = node.child_by_field_name("name")
if name_node:
return code[name_node.start_byte:name_node.end_byte]
if node.type == "function_definition":
decl = node.child_by_field_name("declarator")
while decl:
if decl.type in ("identifier", "field_identifier"):
return code[decl.start_byte:decl.end_byte]
next_decl = decl.child_by_field_name("declarator")
if not next_decl and decl.child_count > 0:
for child in decl.children:
if child.type in ("identifier", "field_identifier"):
return code[child.start_byte:child.end_byte]
decl = decl.children[0]
else:
decl = next_decl
if node.type in ("struct_specifier", "class_specifier"):
for child in node.children:
if child.type in ("type_identifier", "identifier"):
return code[child.start_byte:child.end_byte]
return ""
def walk(node: tree_sitter.Node, indent: int = 0) -> None:
ntype = node.type
label = ""
@@ -597,7 +570,7 @@ class ASTParser:
label = "[Method]" if indent > 0 else "[Func]"
if label:
name = get_name(node)
name = self._get_name(node, code)
if name:
start = node.start_point.row + 1
end = node.end_point.row + 1
@@ -620,36 +593,6 @@ class ASTParser:
"""
tree = self.get_cached_tree(path, code)
def get_name(node: tree_sitter.Node) -> str:
name_node = node.child_by_field_name("name")
if name_node:
return code[name_node.start_byte:name_node.end_byte]
if node.type == "function_definition":
decl = node.child_by_field_name("declarator")
while decl:
if decl.type in ("identifier", "field_identifier"):
return code[decl.start_byte:decl.end_byte]
next_decl = decl.child_by_field_name("declarator")
if not next_decl and decl.child_count > 0:
for child in decl.children:
if child.type in ("identifier", "field_identifier"):
return code[child.start_byte:child.end_byte]
decl = decl.children[0]
else:
decl = next_decl
if node.type == "template_declaration":
for child in node.children:
if child.type in ("function_definition", "class_definition"):
return get_name(child)
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
for child in node.children:
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
return code[child.start_byte:child.end_byte]
return ""
parts = re.split(r'::|\.', name)
def walk(node: tree_sitter.Node, target_parts: List[str]) -> Optional[tree_sitter.Node]:
@@ -657,42 +600,54 @@ class ASTParser:
return None
target = target_parts[0]
for child in node.children:
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if get_name(child) == target:
# If it's a field_declaration, it might wrap a class/struct/enum definition
check_node = child
if child.type == "field_declaration":
for sub in child.children:
if sub.type in ("class_specifier", "struct_specifier", "enum_specifier"):
check_node = sub
break
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
if is_interesting:
node_name = self._get_name(check_node, code)
if node_name == target:
if len(target_parts) == 1:
return child
body = child.child_by_field_name("body")
if not body and child.type == "template_declaration":
for sub in child.children:
if sub.type in ("function_definition", "class_definition"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, target_parts[1:])
if found: return found
for sub in child.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
found = walk(sub, target_parts[1:])
if found: return found
# Recurse for top-level or namespaces
if node.type in ("module", "translation_unit", "namespace_definition", "declaration_list"):
for child in node.children:
if child.type not in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
found = walk(child, target_parts)
return check_node if child.type != "field_declaration" else child
next_parts = target_parts[1:]
else:
next_parts = target_parts
body = check_node.child_by_field_name("body")
if not body and check_node.type == "template_declaration":
for sub in check_node.children:
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, next_parts)
if found: return found
for sub in check_node.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
found = walk(sub, next_parts)
if found: return found
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
found = walk(child, target_parts)
if found: return found
return None
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if self._get_name(node, code) == target:
return node
for child in node.children:
res = deep_search(child, target)
if res: return res
return None
found_node = walk(tree.root_node, parts)
if not found_node and len(parts) == 1:
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if get_name(node) == target:
return node
for child in node.children:
res = deep_search(child, target)
if res: return res
return None
found_node = deep_search(tree.root_node, parts[0])
if not found_node:
found_node = deep_search(tree.root_node, name)
if found_node:
code_bytes = bytearray(code, "utf8")