feat(mcp): Validate C++ tools against real-world gencpp components and improve enum support

This commit is contained in:
2026-05-05 20:40:21 -04:00
parent a809a6e213
commit 904dabe6a1
12 changed files with 6643 additions and 53 deletions
+58 -51
View File
@@ -94,15 +94,15 @@ class ASTParser:
_ast_cache[path] = (mtime, tree)
return tree
def _get_name(self, node: tree_sitter.Node, code: str) -> str:
def _get_name(self, node: tree_sitter.Node, code_bytes: bytes) -> str:
name_node = node.child_by_field_name("name")
if name_node:
return code[name_node.start_byte:name_node.end_byte]
return code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace")
if node.type in ("function_definition", "field_declaration"):
def find_id(n: tree_sitter.Node) -> str:
if n.type in ("identifier", "field_identifier", "qualified_identifier", "destructor_name"):
return code[n.start_byte:n.end_byte]
return code_bytes[n.start_byte:n.end_byte].decode("utf8", errors="replace")
# Try field name 'declarator' first
d = n.child_by_field_name("declarator")
if d:
@@ -118,19 +118,20 @@ class ASTParser:
if node.type == "template_declaration":
for child in node.children:
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "field_declaration"):
return self._get_name(child, code)
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "field_declaration"):
return self._get_name(child, code_bytes)
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
if node.type in ("struct_specifier", "class_specifier", "class_definition", "enum_specifier", "enum_definition", "namespace_definition"):
for child in node.children:
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
return code[child.start_byte:child.end_byte]
return code_bytes[child.start_byte:child.end_byte].decode("utf8", errors="replace")
return ""
def get_skeleton(self, code: str, path: Optional[str] = None) -> str:
"""
Returns a skeleton of a Python file (preserving docstrings, stripping function bodies).
"""
code_bytes = code.encode("utf8")
tree = self.get_cached_tree(path, code)
edits: List[Tuple[int, int, str]] = []
@@ -170,10 +171,10 @@ class ASTParser:
walk(tree.root_node)
# Apply edits in reverse to maintain byte offsets
edits.sort(key=lambda x: x[0], reverse=True)
code_bytes = bytearray(code, "utf8")
code_bytearray = bytearray(code_bytes)
for start, end, replacement in edits:
code_bytes[start:end] = bytes(replacement, "utf8")
return code_bytes.decode("utf8")
code_bytearray[start:end] = bytes(replacement, "utf8")
return code_bytearray.decode("utf8")
def get_curated_view(self, code: str, path: Optional[str] = None) -> str:
"""
@@ -181,6 +182,7 @@ class ASTParser:
Preserves function bodies if they have @core_logic decorator or # [HOT] comment.
Otherwise strips bodies but preserves docstrings.
"""
code_bytes = code.encode("utf8")
tree = self.get_cached_tree(path, code)
edits: List[Tuple[int, int, str]] = []
@@ -197,7 +199,7 @@ class ASTParser:
for child in parent.children:
if child.type == "decorator":
# decorator -> ( '@', identifier ) or ( '@', call )
if "@core_logic" in code[child.start_byte:child.end_byte]:
if b"@core_logic" in code_bytes[child.start_byte:child.end_byte]:
return True
return False
@@ -207,8 +209,8 @@ class ASTParser:
while stack:
curr = stack.pop()
if curr.type == "comment":
comment_text = code[curr.start_byte:curr.end_byte]
if "[HOT]" in comment_text:
comment_bytes = code_bytes[curr.start_byte:curr.end_byte]
if b"[HOT]" in comment_bytes:
return True
for child in curr.children:
stack.append(child)
@@ -241,16 +243,17 @@ class ASTParser:
walk(tree.root_node)
# Apply edits in reverse to maintain byte offsets
edits.sort(key=lambda x: x[0], reverse=True)
code_bytes = bytearray(code, "utf8")
code_bytearray = bytearray(code_bytes)
for start, end, replacement in edits:
code_bytes[start:end] = bytes(replacement, "utf8")
return code_bytes.decode("utf8")
code_bytearray[start:end] = bytes(replacement, "utf8")
return code_bytearray.decode("utf8")
def get_targeted_view(self, code: str, function_names: List[str], path: Optional[str] = None) -> str:
"""
Returns a targeted view of the code including only the specified functions
and their dependencies up to depth 2.
"""
code_bytes = code.encode("utf8")
tree = self.get_cached_tree(path, code)
all_functions = {}
@@ -258,13 +261,13 @@ class ASTParser:
if node.type == "function_definition":
name_node = node.child_by_field_name("name")
if name_node:
func_name = code[name_node.start_byte:name_node.end_byte]
func_name = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace")
full_name = f"{class_name}.{func_name}" if class_name else func_name
all_functions[full_name] = node
elif node.type == "class_definition":
name_node = node.child_by_field_name("name")
if name_node:
cname = code[name_node.start_byte:name_node.end_byte]
cname = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace")
full_cname = f"{class_name}.{cname}" if class_name else cname
body = node.child_by_field_name("body")
if body:
@@ -282,11 +285,11 @@ class ASTParser:
func_node = n.child_by_field_name("function")
if func_node:
if func_node.type == "identifier":
calls.add(code[func_node.start_byte:func_node.end_byte])
calls.add(code_bytes[func_node.start_byte:func_node.end_byte].decode("utf8", errors="replace"))
elif func_node.type == "attribute":
attr_node = func_node.child_by_field_name("attribute")
if attr_node:
calls.add(code[attr_node.start_byte:attr_node.end_byte])
calls.add(code_bytes[attr_node.start_byte:attr_node.end_byte].decode("utf8", errors="replace"))
for child in n.children:
walk_calls(child)
walk_calls(node)
@@ -329,12 +332,12 @@ class ASTParser:
def check_for_targeted(node, parent_class=None):
if node.type == "function_definition":
name_node = node.child_by_field_name("name")
fname = code[name_node.start_byte:name_node.end_byte] if name_node else ""
fname = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace") if name_node else ""
fullname = f"{parent_class}.{fname}" if parent_class else fname
return fullname in all_found
if node.type == "class_definition":
name_node = node.child_by_field_name("name")
cname = code[name_node.start_byte:name_node.end_byte] if name_node else ""
cname = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace") if name_node else ""
full_cname = f"{parent_class}.{cname}" if parent_class else cname
body = node.child_by_field_name("body")
if body:
@@ -350,7 +353,7 @@ class ASTParser:
def walk_edits(node, parent_class=None):
if node.type == "function_definition":
name_node = node.child_by_field_name("name")
fname = code[name_node.start_byte:name_node.end_byte] if name_node else ""
fname = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace") if name_node else ""
fullname = f"{parent_class}.{fname}" if parent_class else fname
if fullname in all_found:
body = node.child_by_field_name("body")
@@ -376,7 +379,7 @@ class ASTParser:
if node.type == "class_definition":
if check_for_targeted(node, parent_class):
name_node = node.child_by_field_name("name")
cname = code[name_node.start_byte:name_node.end_byte] if name_node else ""
cname = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace") if name_node else ""
full_cname = f"{parent_class}.{cname}" if parent_class else cname
body = node.child_by_field_name("body")
if body:
@@ -400,10 +403,10 @@ class ASTParser:
walk_edits(tree.root_node)
edits.sort(key=lambda x: x[0], reverse=True)
code_bytes = bytearray(code, "utf8")
code_bytearray = bytearray(code_bytes)
for start, end, replacement in edits:
code_bytes[start:end] = bytes(replacement, "utf8")
result = code_bytes.decode("utf8")
code_bytearray[start:end] = bytes(replacement, "utf8")
result = code_bytearray.decode("utf8")
result = re.sub(r'\n\s*\n\s*\n+', '\n\n', result)
return result.strip() + "\n"
@@ -412,6 +415,7 @@ class ASTParser:
Returns the full source code for a specific definition by name.
Supports 'ClassName::method' or 'method' for C++.
"""
code_bytes = code.encode("utf8")
tree = self.get_cached_tree(path, code)
parts = re.split(r'::|\.', name)
@@ -429,9 +433,9 @@ class ASTParser:
check_node = sub
break
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration", "field_declaration")
if is_interesting:
node_name = self._get_name(check_node, code)
node_name = self._get_name(check_node, code_bytes)
if node_name == target:
if len(target_parts) == 1:
return check_node if child.type != "field_declaration" else child
@@ -442,14 +446,14 @@ class ASTParser:
body = check_node.child_by_field_name("body")
if not body and check_node.type == "template_declaration":
for sub in check_node.children:
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, next_parts)
if found: return found
for sub in check_node.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
if sub.type in ("field_declaration_list", "class_body", "declaration_list", "enum_body"):
found = walk(sub, next_parts)
if found: return found
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
@@ -458,8 +462,8 @@ class ASTParser:
return None
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if self._get_name(node, code) == target:
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration"):
if self._get_name(node, code_bytes) == target:
return node
for child in node.children:
res = deep_search(child, target)
@@ -471,7 +475,7 @@ class ASTParser:
found_node = deep_search(tree.root_node, name)
if found_node:
return code[found_node.start_byte:found_node.end_byte]
return code_bytes[found_node.start_byte:found_node.end_byte].decode("utf8", errors="replace")
return f"ERROR: definition '{name}' not found"
def get_signature(self, code: str, name: str, path: Optional[str] = None) -> str:
@@ -479,6 +483,7 @@ class ASTParser:
Returns only the signature part of a function or method.
For C/C++, this is the code from the start of the definition until the block start '{'.
"""
code_bytes = code.encode("utf8")
tree = self.get_cached_tree(path, code)
parts = re.split(r'::|\.', name)
@@ -496,9 +501,9 @@ class ASTParser:
check_node = sub
break
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration", "field_declaration")
if is_interesting:
node_name = self._get_name(check_node, code)
node_name = self._get_name(check_node, code_bytes)
if node_name == target:
if len(target_parts) == 1:
return check_node if child.type != "field_declaration" else child
@@ -509,14 +514,14 @@ class ASTParser:
body = check_node.child_by_field_name("body")
if not body and check_node.type == "template_declaration":
for sub in check_node.children:
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, next_parts)
if found: return found
for sub in check_node.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
if sub.type in ("field_declaration_list", "class_body", "declaration_list", "enum_body"):
found = walk(sub, next_parts)
if found: return found
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
@@ -526,7 +531,7 @@ class ASTParser:
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "template_declaration"):
if self._get_name(node, code) == target:
if self._get_name(node, code_bytes) == target:
return node
for child in node.children:
res = deep_search(child, target)
@@ -547,8 +552,8 @@ class ASTParser:
body = target_node.child_by_field_name("body")
if body:
return code[found_node.start_byte:body.start_byte].strip()
return code[found_node.start_byte:found_node.end_byte].strip()
return code_bytes[found_node.start_byte:body.start_byte].decode("utf8", errors="replace").strip()
return code_bytes[found_node.start_byte:found_node.end_byte].decode("utf8", errors="replace").strip()
return f"ERROR: signature for '{name}' not found"
@@ -556,6 +561,7 @@ class ASTParser:
"""
Returns a hierarchical outline of the code (classes, structs, functions, methods).
"""
code_bytes = code.encode("utf8")
tree = self.get_cached_tree(path, code)
output = []
@@ -570,7 +576,7 @@ class ASTParser:
label = "[Method]" if indent > 0 else "[Func]"
if label:
name = self._get_name(node, code)
name = self._get_name(node, code_bytes)
if name:
start = node.start_point.row + 1
end = node.end_point.row + 1
@@ -591,6 +597,7 @@ class ASTParser:
"""
Surgically replace the definition of a class or function by name.
"""
code_bytes = code.encode("utf8")
tree = self.get_cached_tree(path, code)
parts = re.split(r'::|\.', name)
@@ -608,9 +615,9 @@ class ASTParser:
check_node = sub
break
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration", "field_declaration")
if is_interesting:
node_name = self._get_name(check_node, code)
node_name = self._get_name(check_node, code_bytes)
if node_name == target:
if len(target_parts) == 1:
return check_node if child.type != "field_declaration" else child
@@ -621,14 +628,14 @@ class ASTParser:
body = check_node.child_by_field_name("body")
if not body and check_node.type == "template_declaration":
for sub in check_node.children:
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition"):
body = sub.child_by_field_name("body")
break
if body:
found = walk(body, next_parts)
if found: return found
for sub in check_node.children:
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
if sub.type in ("field_declaration_list", "class_body", "declaration_list", "enum_body"):
found = walk(sub, next_parts)
if found: return found
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
@@ -637,8 +644,8 @@ class ASTParser:
return None
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
if self._get_name(node, code) == target:
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration"):
if self._get_name(node, code_bytes) == target:
return node
for child in node.children:
res = deep_search(child, target)
@@ -650,9 +657,9 @@ class ASTParser:
found_node = deep_search(tree.root_node, name)
if found_node:
code_bytes = bytearray(code, "utf8")
code_bytes[found_node.start_byte:found_node.end_byte] = bytes(new_content, "utf8")
return code_bytes.decode("utf8")
code_bytearray = bytearray(code_bytes)
code_bytearray[found_node.start_byte:found_node.end_byte] = bytes(new_content, "utf8")
return code_bytearray.decode("utf8")
return f"ERROR: definition '{name}' not found"
def reset_client() -> None: