feat(mcp): Validate C++ tools against real-world gencpp components and improve enum support
This commit is contained in:
+58
-51
@@ -94,15 +94,15 @@ class ASTParser:
|
||||
_ast_cache[path] = (mtime, tree)
|
||||
return tree
|
||||
|
||||
def _get_name(self, node: tree_sitter.Node, code: str) -> str:
|
||||
def _get_name(self, node: tree_sitter.Node, code_bytes: bytes) -> str:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
return code[name_node.start_byte:name_node.end_byte]
|
||||
return code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace")
|
||||
|
||||
if node.type in ("function_definition", "field_declaration"):
|
||||
def find_id(n: tree_sitter.Node) -> str:
|
||||
if n.type in ("identifier", "field_identifier", "qualified_identifier", "destructor_name"):
|
||||
return code[n.start_byte:n.end_byte]
|
||||
return code_bytes[n.start_byte:n.end_byte].decode("utf8", errors="replace")
|
||||
# Try field name 'declarator' first
|
||||
d = n.child_by_field_name("declarator")
|
||||
if d:
|
||||
@@ -118,19 +118,20 @@ class ASTParser:
|
||||
|
||||
if node.type == "template_declaration":
|
||||
for child in node.children:
|
||||
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "field_declaration"):
|
||||
return self._get_name(child, code)
|
||||
if child.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "field_declaration"):
|
||||
return self._get_name(child, code_bytes)
|
||||
|
||||
if node.type in ("struct_specifier", "class_specifier", "class_definition", "namespace_definition"):
|
||||
if node.type in ("struct_specifier", "class_specifier", "class_definition", "enum_specifier", "enum_definition", "namespace_definition"):
|
||||
for child in node.children:
|
||||
if child.type in ("type_identifier", "identifier", "namespace_identifier"):
|
||||
return code[child.start_byte:child.end_byte]
|
||||
return code_bytes[child.start_byte:child.end_byte].decode("utf8", errors="replace")
|
||||
return ""
|
||||
|
||||
def get_skeleton(self, code: str, path: Optional[str] = None) -> str:
|
||||
"""
|
||||
Returns a skeleton of a Python file (preserving docstrings, stripping function bodies).
|
||||
"""
|
||||
code_bytes = code.encode("utf8")
|
||||
tree = self.get_cached_tree(path, code)
|
||||
edits: List[Tuple[int, int, str]] = []
|
||||
|
||||
@@ -170,10 +171,10 @@ class ASTParser:
|
||||
walk(tree.root_node)
|
||||
# Apply edits in reverse to maintain byte offsets
|
||||
edits.sort(key=lambda x: x[0], reverse=True)
|
||||
code_bytes = bytearray(code, "utf8")
|
||||
code_bytearray = bytearray(code_bytes)
|
||||
for start, end, replacement in edits:
|
||||
code_bytes[start:end] = bytes(replacement, "utf8")
|
||||
return code_bytes.decode("utf8")
|
||||
code_bytearray[start:end] = bytes(replacement, "utf8")
|
||||
return code_bytearray.decode("utf8")
|
||||
|
||||
def get_curated_view(self, code: str, path: Optional[str] = None) -> str:
|
||||
"""
|
||||
@@ -181,6 +182,7 @@ class ASTParser:
|
||||
Preserves function bodies if they have @core_logic decorator or # [HOT] comment.
|
||||
Otherwise strips bodies but preserves docstrings.
|
||||
"""
|
||||
code_bytes = code.encode("utf8")
|
||||
tree = self.get_cached_tree(path, code)
|
||||
edits: List[Tuple[int, int, str]] = []
|
||||
|
||||
@@ -197,7 +199,7 @@ class ASTParser:
|
||||
for child in parent.children:
|
||||
if child.type == "decorator":
|
||||
# decorator -> ( '@', identifier ) or ( '@', call )
|
||||
if "@core_logic" in code[child.start_byte:child.end_byte]:
|
||||
if b"@core_logic" in code_bytes[child.start_byte:child.end_byte]:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -207,8 +209,8 @@ class ASTParser:
|
||||
while stack:
|
||||
curr = stack.pop()
|
||||
if curr.type == "comment":
|
||||
comment_text = code[curr.start_byte:curr.end_byte]
|
||||
if "[HOT]" in comment_text:
|
||||
comment_bytes = code_bytes[curr.start_byte:curr.end_byte]
|
||||
if b"[HOT]" in comment_bytes:
|
||||
return True
|
||||
for child in curr.children:
|
||||
stack.append(child)
|
||||
@@ -241,16 +243,17 @@ class ASTParser:
|
||||
walk(tree.root_node)
|
||||
# Apply edits in reverse to maintain byte offsets
|
||||
edits.sort(key=lambda x: x[0], reverse=True)
|
||||
code_bytes = bytearray(code, "utf8")
|
||||
code_bytearray = bytearray(code_bytes)
|
||||
for start, end, replacement in edits:
|
||||
code_bytes[start:end] = bytes(replacement, "utf8")
|
||||
return code_bytes.decode("utf8")
|
||||
code_bytearray[start:end] = bytes(replacement, "utf8")
|
||||
return code_bytearray.decode("utf8")
|
||||
|
||||
def get_targeted_view(self, code: str, function_names: List[str], path: Optional[str] = None) -> str:
|
||||
"""
|
||||
Returns a targeted view of the code including only the specified functions
|
||||
and their dependencies up to depth 2.
|
||||
"""
|
||||
code_bytes = code.encode("utf8")
|
||||
tree = self.get_cached_tree(path, code)
|
||||
all_functions = {}
|
||||
|
||||
@@ -258,13 +261,13 @@ class ASTParser:
|
||||
if node.type == "function_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
func_name = code[name_node.start_byte:name_node.end_byte]
|
||||
func_name = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace")
|
||||
full_name = f"{class_name}.{func_name}" if class_name else func_name
|
||||
all_functions[full_name] = node
|
||||
elif node.type == "class_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
cname = code[name_node.start_byte:name_node.end_byte]
|
||||
cname = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace")
|
||||
full_cname = f"{class_name}.{cname}" if class_name else cname
|
||||
body = node.child_by_field_name("body")
|
||||
if body:
|
||||
@@ -282,11 +285,11 @@ class ASTParser:
|
||||
func_node = n.child_by_field_name("function")
|
||||
if func_node:
|
||||
if func_node.type == "identifier":
|
||||
calls.add(code[func_node.start_byte:func_node.end_byte])
|
||||
calls.add(code_bytes[func_node.start_byte:func_node.end_byte].decode("utf8", errors="replace"))
|
||||
elif func_node.type == "attribute":
|
||||
attr_node = func_node.child_by_field_name("attribute")
|
||||
if attr_node:
|
||||
calls.add(code[attr_node.start_byte:attr_node.end_byte])
|
||||
calls.add(code_bytes[attr_node.start_byte:attr_node.end_byte].decode("utf8", errors="replace"))
|
||||
for child in n.children:
|
||||
walk_calls(child)
|
||||
walk_calls(node)
|
||||
@@ -329,12 +332,12 @@ class ASTParser:
|
||||
def check_for_targeted(node, parent_class=None):
|
||||
if node.type == "function_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
fname = code[name_node.start_byte:name_node.end_byte] if name_node else ""
|
||||
fname = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace") if name_node else ""
|
||||
fullname = f"{parent_class}.{fname}" if parent_class else fname
|
||||
return fullname in all_found
|
||||
if node.type == "class_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
cname = code[name_node.start_byte:name_node.end_byte] if name_node else ""
|
||||
cname = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace") if name_node else ""
|
||||
full_cname = f"{parent_class}.{cname}" if parent_class else cname
|
||||
body = node.child_by_field_name("body")
|
||||
if body:
|
||||
@@ -350,7 +353,7 @@ class ASTParser:
|
||||
def walk_edits(node, parent_class=None):
|
||||
if node.type == "function_definition":
|
||||
name_node = node.child_by_field_name("name")
|
||||
fname = code[name_node.start_byte:name_node.end_byte] if name_node else ""
|
||||
fname = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace") if name_node else ""
|
||||
fullname = f"{parent_class}.{fname}" if parent_class else fname
|
||||
if fullname in all_found:
|
||||
body = node.child_by_field_name("body")
|
||||
@@ -376,7 +379,7 @@ class ASTParser:
|
||||
if node.type == "class_definition":
|
||||
if check_for_targeted(node, parent_class):
|
||||
name_node = node.child_by_field_name("name")
|
||||
cname = code[name_node.start_byte:name_node.end_byte] if name_node else ""
|
||||
cname = code_bytes[name_node.start_byte:name_node.end_byte].decode("utf8", errors="replace") if name_node else ""
|
||||
full_cname = f"{parent_class}.{cname}" if parent_class else cname
|
||||
body = node.child_by_field_name("body")
|
||||
if body:
|
||||
@@ -400,10 +403,10 @@ class ASTParser:
|
||||
|
||||
walk_edits(tree.root_node)
|
||||
edits.sort(key=lambda x: x[0], reverse=True)
|
||||
code_bytes = bytearray(code, "utf8")
|
||||
code_bytearray = bytearray(code_bytes)
|
||||
for start, end, replacement in edits:
|
||||
code_bytes[start:end] = bytes(replacement, "utf8")
|
||||
result = code_bytes.decode("utf8")
|
||||
code_bytearray[start:end] = bytes(replacement, "utf8")
|
||||
result = code_bytearray.decode("utf8")
|
||||
result = re.sub(r'\n\s*\n\s*\n+', '\n\n', result)
|
||||
return result.strip() + "\n"
|
||||
|
||||
@@ -412,6 +415,7 @@ class ASTParser:
|
||||
Returns the full source code for a specific definition by name.
|
||||
Supports 'ClassName::method' or 'method' for C++.
|
||||
"""
|
||||
code_bytes = code.encode("utf8")
|
||||
tree = self.get_cached_tree(path, code)
|
||||
|
||||
parts = re.split(r'::|\.', name)
|
||||
@@ -429,9 +433,9 @@ class ASTParser:
|
||||
check_node = sub
|
||||
break
|
||||
|
||||
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
|
||||
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration", "field_declaration")
|
||||
if is_interesting:
|
||||
node_name = self._get_name(check_node, code)
|
||||
node_name = self._get_name(check_node, code_bytes)
|
||||
if node_name == target:
|
||||
if len(target_parts) == 1:
|
||||
return check_node if child.type != "field_declaration" else child
|
||||
@@ -442,14 +446,14 @@ class ASTParser:
|
||||
body = check_node.child_by_field_name("body")
|
||||
if not body and check_node.type == "template_declaration":
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
|
||||
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition"):
|
||||
body = sub.child_by_field_name("body")
|
||||
break
|
||||
if body:
|
||||
found = walk(body, next_parts)
|
||||
if found: return found
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list", "enum_body"):
|
||||
found = walk(sub, next_parts)
|
||||
if found: return found
|
||||
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
|
||||
@@ -458,8 +462,8 @@ class ASTParser:
|
||||
return None
|
||||
|
||||
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if self._get_name(node, code) == target:
|
||||
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration"):
|
||||
if self._get_name(node, code_bytes) == target:
|
||||
return node
|
||||
for child in node.children:
|
||||
res = deep_search(child, target)
|
||||
@@ -471,7 +475,7 @@ class ASTParser:
|
||||
found_node = deep_search(tree.root_node, name)
|
||||
|
||||
if found_node:
|
||||
return code[found_node.start_byte:found_node.end_byte]
|
||||
return code_bytes[found_node.start_byte:found_node.end_byte].decode("utf8", errors="replace")
|
||||
return f"ERROR: definition '{name}' not found"
|
||||
|
||||
def get_signature(self, code: str, name: str, path: Optional[str] = None) -> str:
|
||||
@@ -479,6 +483,7 @@ class ASTParser:
|
||||
Returns only the signature part of a function or method.
|
||||
For C/C++, this is the code from the start of the definition until the block start '{'.
|
||||
"""
|
||||
code_bytes = code.encode("utf8")
|
||||
tree = self.get_cached_tree(path, code)
|
||||
|
||||
parts = re.split(r'::|\.', name)
|
||||
@@ -496,9 +501,9 @@ class ASTParser:
|
||||
check_node = sub
|
||||
break
|
||||
|
||||
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
|
||||
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration", "field_declaration")
|
||||
if is_interesting:
|
||||
node_name = self._get_name(check_node, code)
|
||||
node_name = self._get_name(check_node, code_bytes)
|
||||
if node_name == target:
|
||||
if len(target_parts) == 1:
|
||||
return check_node if child.type != "field_declaration" else child
|
||||
@@ -509,14 +514,14 @@ class ASTParser:
|
||||
body = check_node.child_by_field_name("body")
|
||||
if not body and check_node.type == "template_declaration":
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
|
||||
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition"):
|
||||
body = sub.child_by_field_name("body")
|
||||
break
|
||||
if body:
|
||||
found = walk(body, next_parts)
|
||||
if found: return found
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list", "enum_body"):
|
||||
found = walk(sub, next_parts)
|
||||
if found: return found
|
||||
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
|
||||
@@ -526,7 +531,7 @@ class ASTParser:
|
||||
|
||||
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||
if node.type in ("function_definition", "template_declaration"):
|
||||
if self._get_name(node, code) == target:
|
||||
if self._get_name(node, code_bytes) == target:
|
||||
return node
|
||||
for child in node.children:
|
||||
res = deep_search(child, target)
|
||||
@@ -547,8 +552,8 @@ class ASTParser:
|
||||
|
||||
body = target_node.child_by_field_name("body")
|
||||
if body:
|
||||
return code[found_node.start_byte:body.start_byte].strip()
|
||||
return code[found_node.start_byte:found_node.end_byte].strip()
|
||||
return code_bytes[found_node.start_byte:body.start_byte].decode("utf8", errors="replace").strip()
|
||||
return code_bytes[found_node.start_byte:found_node.end_byte].decode("utf8", errors="replace").strip()
|
||||
|
||||
return f"ERROR: signature for '{name}' not found"
|
||||
|
||||
@@ -556,6 +561,7 @@ class ASTParser:
|
||||
"""
|
||||
Returns a hierarchical outline of the code (classes, structs, functions, methods).
|
||||
"""
|
||||
code_bytes = code.encode("utf8")
|
||||
tree = self.get_cached_tree(path, code)
|
||||
output = []
|
||||
|
||||
@@ -570,7 +576,7 @@ class ASTParser:
|
||||
label = "[Method]" if indent > 0 else "[Func]"
|
||||
|
||||
if label:
|
||||
name = self._get_name(node, code)
|
||||
name = self._get_name(node, code_bytes)
|
||||
if name:
|
||||
start = node.start_point.row + 1
|
||||
end = node.end_point.row + 1
|
||||
@@ -591,6 +597,7 @@ class ASTParser:
|
||||
"""
|
||||
Surgically replace the definition of a class or function by name.
|
||||
"""
|
||||
code_bytes = code.encode("utf8")
|
||||
tree = self.get_cached_tree(path, code)
|
||||
|
||||
parts = re.split(r'::|\.', name)
|
||||
@@ -608,9 +615,9 @@ class ASTParser:
|
||||
check_node = sub
|
||||
break
|
||||
|
||||
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration", "field_declaration")
|
||||
is_interesting = check_node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration", "field_declaration")
|
||||
if is_interesting:
|
||||
node_name = self._get_name(check_node, code)
|
||||
node_name = self._get_name(check_node, code_bytes)
|
||||
if node_name == target:
|
||||
if len(target_parts) == 1:
|
||||
return check_node if child.type != "field_declaration" else child
|
||||
@@ -621,14 +628,14 @@ class ASTParser:
|
||||
body = check_node.child_by_field_name("body")
|
||||
if not body and check_node.type == "template_declaration":
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier"):
|
||||
if sub.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition"):
|
||||
body = sub.child_by_field_name("body")
|
||||
break
|
||||
if body:
|
||||
found = walk(body, next_parts)
|
||||
if found: return found
|
||||
for sub in check_node.children:
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list"):
|
||||
if sub.type in ("field_declaration_list", "class_body", "declaration_list", "enum_body"):
|
||||
found = walk(sub, next_parts)
|
||||
if found: return found
|
||||
elif child.type in ("module", "translation_unit", "namespace_definition", "declaration_list", "field_declaration_list", "class_body"):
|
||||
@@ -637,8 +644,8 @@ class ASTParser:
|
||||
return None
|
||||
|
||||
def deep_search(node: tree_sitter.Node, target: str) -> Optional[tree_sitter.Node]:
|
||||
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "namespace_definition", "template_declaration"):
|
||||
if self._get_name(node, code) == target:
|
||||
if node.type in ("function_definition", "class_definition", "class_specifier", "struct_specifier", "enum_specifier", "enum_definition", "namespace_definition", "template_declaration"):
|
||||
if self._get_name(node, code_bytes) == target:
|
||||
return node
|
||||
for child in node.children:
|
||||
res = deep_search(child, target)
|
||||
@@ -650,9 +657,9 @@ class ASTParser:
|
||||
found_node = deep_search(tree.root_node, name)
|
||||
|
||||
if found_node:
|
||||
code_bytes = bytearray(code, "utf8")
|
||||
code_bytes[found_node.start_byte:found_node.end_byte] = bytes(new_content, "utf8")
|
||||
return code_bytes.decode("utf8")
|
||||
code_bytearray = bytearray(code_bytes)
|
||||
code_bytearray[found_node.start_byte:found_node.end_byte] = bytes(new_content, "utf8")
|
||||
return code_bytearray.decode("utf8")
|
||||
return f"ERROR: definition '{name}' not found"
|
||||
|
||||
def reset_client() -> None:
|
||||
|
||||
Reference in New Issue
Block a user