feat(mma): Implement Deep AST-Driven Context Pruning and mark track complete
This commit is contained in:
@@ -20,7 +20,7 @@ This file tracks all major tracks for the project. Each track has its own detail
|
|||||||
1. [x] **Track: True Parallel Worker Execution (The DAG Realization)**
|
1. [x] **Track: True Parallel Worker Execution (The DAG Realization)**
|
||||||
*Link: [./tracks/true_parallel_worker_execution_20260306/](./tracks/true_parallel_worker_execution_20260306/)*
|
*Link: [./tracks/true_parallel_worker_execution_20260306/](./tracks/true_parallel_worker_execution_20260306/)*
|
||||||
|
|
||||||
2. [ ] **Track: Deep AST-Driven Context Pruning (RAG for Code)**
|
2. [x] **Track: Deep AST-Driven Context Pruning (RAG for Code)**
|
||||||
*Link: [./tracks/deep_ast_context_pruning_20260306/](./tracks/deep_ast_context_pruning_20260306/)*
|
*Link: [./tracks/deep_ast_context_pruning_20260306/](./tracks/deep_ast_context_pruning_20260306/)*
|
||||||
|
|
||||||
3. [ ] **Track: Visual DAG & Interactive Ticket Editing**
|
3. [ ] **Track: Visual DAG & Interactive Ticket Editing**
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ from pathlib import Path
|
|||||||
from typing import Optional, Any, List, Tuple, Dict
|
from typing import Optional, Any, List, Tuple, Dict
|
||||||
import tree_sitter
|
import tree_sitter
|
||||||
import tree_sitter_python
|
import tree_sitter_python
|
||||||
|
import re
|
||||||
|
|
||||||
|
_ast_cache: Dict[str, Tuple[float, tree_sitter.Tree]] = {}
|
||||||
|
|
||||||
class ASTParser:
|
class ASTParser:
|
||||||
"""
|
"""
|
||||||
@@ -28,11 +31,38 @@ class ASTParser:
|
|||||||
"""Parse the given code and return the tree-sitter Tree."""
|
"""Parse the given code and return the tree-sitter Tree."""
|
||||||
return self.parser.parse(bytes(code, "utf8"))
|
return self.parser.parse(bytes(code, "utf8"))
|
||||||
|
|
||||||
def get_skeleton(self, code: str) -> str:
|
def get_cached_tree(self, path: Optional[str], code: str) -> tree_sitter.Tree:
    """
    Get cached tree or parse and cache it.

    A cached tree is reused only when the path, the file's mtime and a digest
    of ``code`` all match. Callers apply the tree's byte offsets to ``code``,
    so a tree parsed from different text must never be returned, even when
    the mtime is unchanged or the path does not exist on disk.

    Args:
        path: File the code came from. When empty or None the code is parsed
            directly and nothing is cached.
        code: Source text to parse.

    Returns:
        The tree-sitter tree for ``code``.
    """
    if not path:
        return self.parse(code)

    import hashlib  # local import: only the cached path needs it

    try:
        mtime = Path(path).stat().st_mtime
    except (OSError, ValueError):
        # Missing/unreadable file or an invalid path string: fall back to a
        # neutral mtime and let the content digest decide cache validity.
        mtime = 0.0

    digest = hashlib.sha256(code.encode("utf8")).hexdigest()
    prefix = f"{path}\0"
    key = prefix + digest

    # NOTE(review): the cache is module-global and not keyed by parser
    # language -- confirm only one language is ever parsed per process.
    cached = _ast_cache.get(key)
    if cached is not None and cached[0] == mtime:
        # Re-insert so the entry becomes the most recently used one.
        _ast_cache[key] = _ast_cache.pop(key)
        return cached[1]

    tree = self.parse(code)

    # Keep at most one entry per path: drop trees parsed from older content.
    for stale_key in [k for k in _ast_cache if k == path or k.startswith(prefix)]:
        del _ast_cache[stale_key]

    # Bounded cache: evict least recently used entries (dicts keep insertion order).
    while len(_ast_cache) >= 10:
        del _ast_cache[next(iter(_ast_cache))]

    _ast_cache[key] = (mtime, tree)
    return tree
|
||||||
|
|
||||||
|
def get_skeleton(self, code: str, path: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Returns a skeleton of a Python file (preserving docstrings, stripping function bodies).
|
Returns a skeleton of a Python file (preserving docstrings, stripping function bodies).
|
||||||
"""
|
"""
|
||||||
tree = self.parse(code)
|
tree = self.get_cached_tree(path, code)
|
||||||
edits: List[Tuple[int, int, str]] = []
|
edits: List[Tuple[int, int, str]] = []
|
||||||
|
|
||||||
def is_docstring(node: tree_sitter.Node) -> bool:
|
def is_docstring(node: tree_sitter.Node) -> bool:
|
||||||
@@ -70,13 +100,13 @@ class ASTParser:
|
|||||||
code_bytes[start:end] = bytes(replacement, "utf8")
|
code_bytes[start:end] = bytes(replacement, "utf8")
|
||||||
return code_bytes.decode("utf8")
|
return code_bytes.decode("utf8")
|
||||||
|
|
||||||
def get_curated_view(self, code: str) -> str:
|
def get_curated_view(self, code: str, path: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Returns a curated skeleton of a Python file.
|
Returns a curated skeleton of a Python file.
|
||||||
Preserves function bodies if they have @core_logic decorator or # [HOT] comment.
|
Preserves function bodies if they have @core_logic decorator or # [HOT] comment.
|
||||||
Otherwise strips bodies but preserves docstrings.
|
Otherwise strips bodies but preserves docstrings.
|
||||||
"""
|
"""
|
||||||
tree = self.parse(code)
|
tree = self.get_cached_tree(path, code)
|
||||||
edits: List[Tuple[int, int, str]] = []
|
edits: List[Tuple[int, int, str]] = []
|
||||||
|
|
||||||
def is_docstring(node: tree_sitter.Node) -> bool:
|
def is_docstring(node: tree_sitter.Node) -> bool:
|
||||||
@@ -141,6 +171,167 @@ class ASTParser:
|
|||||||
code_bytes[start:end] = bytes(replacement, "utf8")
|
code_bytes[start:end] = bytes(replacement, "utf8")
|
||||||
return code_bytes.decode("utf8")
|
return code_bytes.decode("utf8")
|
||||||
|
|
||||||
|
def get_targeted_view(self, code: str, function_names: List[str], path: Optional[str] = None, max_depth: int = 2) -> str:
    """
    Returns a targeted view of the code including only the specified functions
    and their dependencies up to depth ``max_depth`` (2 by default).

    Args:
        code: Python source text to prune.
        function_names: Targets, either bare names ("func") or qualified names
            ("ClassName.method"). A bare name selects every function or
            method with that name.
        path: Optional file path, forwarded to get_cached_tree.
        max_depth: Number of call layers followed outward from the targets.

    Returns:
        The pruned source. Imports and module-level comments are kept.
        Selected functions keep their signature (and docstring, if any) with
        the rest of the body replaced by "...". A class is kept only when it
        contains a selected function. Everything else is removed.

    Calls are resolved by bare name only (for ``obj.method()`` the attribute
    name is used), so same-named functions in different classes are all
    treated as dependencies.
    """
    tree = self.get_cached_tree(path, code)

    # tree-sitter reports byte offsets into the UTF-8 encoding. Names must be
    # sliced from the encoded bytes: slicing the str drifts after the first
    # non-ASCII character.
    source = bytes(code, "utf8")

    def text_of(node) -> str:
        return source[node.start_byte:node.end_byte].decode("utf8")

    def qualified_name(node, scope) -> str:
        # Dotted name of a function/class node inside the given class scope.
        name_node = node.child_by_field_name("name")
        name = text_of(name_node) if name_node is not None else ""
        return f"{scope}.{name}" if scope else name

    all_functions = {}

    def collect_functions(node, class_name=None):
        if node.type == "function_definition":
            if node.child_by_field_name("name") is not None:
                all_functions[qualified_name(node, class_name)] = node
        elif node.type == "class_definition":
            if node.child_by_field_name("name") is not None:
                body = node.child_by_field_name("body")
                if body is not None:
                    collect_functions(body, qualified_name(node, class_name))
                return
        for child in node.children:
            collect_functions(child, class_name)

    collect_functions(tree.root_node)

    # Index by the last dotted component so calls resolve without rescanning
    # every known function for every call site.
    by_short_name = {}
    for full_name in all_functions:
        by_short_name.setdefault(full_name.split(".")[-1], []).append(full_name)

    def get_calls(root):
        # Bare names of everything called anywhere below ``root``.
        calls = set()
        stack = [root]
        while stack:
            n = stack.pop()
            if n.type == "call":
                func_node = n.child_by_field_name("function")
                if func_node is not None:
                    if func_node.type == "identifier":
                        calls.add(text_of(func_node))
                    elif func_node.type == "attribute":
                        attr_node = func_node.child_by_field_name("attribute")
                        if attr_node is not None:
                            calls.add(text_of(attr_node))
            stack.extend(n.children)
        return calls

    all_found = set()
    for target in function_names:
        if target in all_functions:
            all_found.add(target)
        else:
            all_found.update(by_short_name.get(target, ()))

    # Breadth-first expansion over the call graph, one layer per iteration.
    current_layer = set(all_found)
    for _ in range(max_depth):
        next_layer = set()
        for name in current_layer:
            for call in get_calls(all_functions[name]):
                for func_name in by_short_name.get(call, ()):
                    if func_name not in all_found:
                        all_found.add(func_name)
                        next_layer.add(func_name)
        current_layer = next_layer
        if not current_layer:
            break

    edits = []

    def is_docstring(n) -> bool:
        return (
            n.type == "expression_statement"
            and n.child_count > 0
            and n.children[0].type == "string"
        )

    def check_for_targeted(node, parent_class=None):
        # True when ``node`` is, or contains, a selected function.
        if node.type == "function_definition":
            return qualified_name(node, parent_class) in all_found
        if node.type == "class_definition":
            body = node.child_by_field_name("body")
            if body is not None:
                scope = qualified_name(node, parent_class)
                return any(check_for_targeted(child, scope) for child in body.children)
            return False
        return any(check_for_targeted(child, parent_class) for child in node.children)

    def walk_edits(node, parent_class=None):
        if node.type == "decorated_definition":
            # Decide on the wrapped definition: keep decorators together with
            # a kept definition, drop them together with a pruned one.
            definition = node.child_by_field_name("definition")
            if definition is not None and check_for_targeted(definition, parent_class):
                walk_edits(definition, parent_class)
            else:
                edits.append((node.start_byte, node.end_byte, ""))
            return
        if node.type == "function_definition":
            if qualified_name(node, parent_class) in all_found:
                body = node.child_by_field_name("body")
                if body is not None and body.type == "block":
                    first_stmt = next((c for c in body.children if c.type != "comment"), None)
                    if first_stmt is not None and is_docstring(first_stmt):
                        # Keep the docstring; elide whatever follows it.
                        if body.end_byte > first_stmt.end_byte:
                            indent = " " * body.start_point.column
                            edits.append((first_stmt.end_byte, body.end_byte, f"\n{indent}..."))
                    else:
                        edits.append((body.start_byte, body.end_byte, "..."))
            else:
                edits.append((node.start_byte, node.end_byte, ""))
            return
        if node.type == "class_definition":
            if check_for_targeted(node, parent_class):
                body = node.child_by_field_name("body")
                if body is not None:
                    scope = qualified_name(node, parent_class)
                    for child in body.children:
                        walk_edits(child, scope)
            else:
                edits.append((node.start_byte, node.end_byte, ""))
            return
        if node.type in ("import_statement", "import_from_statement"):
            return
        if node.type == "module":
            for child in node.children:
                walk_edits(child, parent_class)
        elif node.parent and node.parent.type == "module":
            # Any other module-level statement is dropped; comments stay.
            if node.type != "comment":
                edits.append((node.start_byte, node.end_byte, ""))
        else:
            for child in node.children:
                walk_edits(child, parent_class)

    walk_edits(tree.root_node)

    # Apply from the end of the file backwards so earlier offsets stay valid.
    edits.sort(key=lambda x: x[0], reverse=True)
    code_bytes = bytearray(code, "utf8")
    for start, end, replacement in edits:
        code_bytes[start:end] = bytes(replacement, "utf8")
    result = code_bytes.decode("utf8")
    # Collapse the blank-line runs left behind by removed definitions.
    result = re.sub(r'\n\s*\n\s*\n+', '\n\n', result)
    return result.strip() + "\n"
|
||||||
|
|
||||||
def reset_client() -> None:
|
def reset_client() -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -88,12 +88,14 @@ CONSTRAINTS:
|
|||||||
OUTPUT REQUIREMENT:
|
OUTPUT REQUIREMENT:
|
||||||
Return a JSON array of 'Tickets' in Godot ECS Flat List format.
|
Return a JSON array of 'Tickets' in Godot ECS Flat List format.
|
||||||
Include 'depends_on' pointers to construct an execution DAG (Directed Acyclic Graph).
|
Include 'depends_on' pointers to construct an execution DAG (Directed Acyclic Graph).
|
||||||
|
Include 'target_symbols' (list of strings) to specify which functions or classes should be extracted if using a targeted context view.
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id": "ticket_id",
|
"id": "ticket_id",
|
||||||
"type": "Ticket",
|
"type": "Ticket",
|
||||||
"goal": "Surgical implementation task",
|
"goal": "Surgical implementation task",
|
||||||
"target_file": "path/to/file",
|
"target_file": "path/to/file",
|
||||||
|
"target_symbols": ["function_name", "ClassName.method_name"],
|
||||||
"depends_on": ["other_ticket_id"],
|
"depends_on": ["other_ticket_id"],
|
||||||
"context_requirements": ["list_of_needed_skeletons"]
|
"context_requirements": ["list_of_needed_skeletons"]
|
||||||
},
|
},
|
||||||
|
|||||||
167
src/models.py
167
src/models.py
@@ -1,59 +1,35 @@
|
|||||||
from dataclasses import dataclass, field
|
from __future__ import annotations
|
||||||
from typing import List, Optional, Dict, Any
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
import os
|
|
||||||
import tomllib
|
import tomllib
|
||||||
from src import project_manager
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional, Dict, Any, Union
|
||||||
CONFIG_PATH: Path = Path(os.environ.get("SLOP_CONFIG", "config.toml"))
|
from pathlib import Path
|
||||||
DISC_ROLES: list[str] = ["User", "AI", "Vendor API", "System"]
|
|
||||||
AGENT_TOOL_NAMES: list[str] = [
|
|
||||||
"run_powershell",
|
|
||||||
"read_file",
|
|
||||||
"list_directory",
|
|
||||||
"search_files",
|
|
||||||
"get_file_summary",
|
|
||||||
"web_search",
|
|
||||||
"fetch_url",
|
|
||||||
"py_get_skeleton",
|
|
||||||
"py_get_code_outline",
|
|
||||||
"get_file_slice",
|
|
||||||
"py_get_definition",
|
|
||||||
"py_get_signature",
|
|
||||||
"py_get_class_summary",
|
|
||||||
"py_get_var_declaration",
|
|
||||||
"get_git_diff",
|
|
||||||
"py_find_usages",
|
|
||||||
"py_get_imports",
|
|
||||||
"py_check_syntax",
|
|
||||||
"py_get_hierarchy",
|
|
||||||
"py_get_docstring",
|
|
||||||
"get_tree",
|
|
||||||
"get_ui_performance",
|
|
||||||
# Mutating tools — disabled by default
|
|
||||||
"set_file_slice",
|
|
||||||
"py_update_definition",
|
|
||||||
"py_set_signature",
|
|
||||||
"py_set_var_declaration",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
CONFIG_PATH = Path("config.toml")
|
||||||
|
|
||||||
def load_config() -> dict[str, Any]:
    """Parse the TOML file located at CONFIG_PATH and return its contents."""
    with open(CONFIG_PATH, "rb") as config_file:
        parsed = tomllib.load(config_file)
    return parsed
|
||||||
|
|
||||||
|
# Global constants for agent tools
|
||||||
def parse_history_entries(
|
AGENT_TOOL_NAMES = [
|
||||||
history: list[str], roles: list[str] | None = None
|
"read_file",
|
||||||
) -> list[dict[str, Any]]:
|
"list_directory",
|
||||||
known = roles if roles is not None else DISC_ROLES
|
"search_files",
|
||||||
entries = []
|
"web_search",
|
||||||
for raw in history:
|
"fetch_url",
|
||||||
entry = project_manager.str_to_entry(raw, known)
|
"get_file_summary",
|
||||||
entries.append(entry)
|
"py_get_skeleton",
|
||||||
return entries
|
"py_get_code_outline",
|
||||||
|
"py_get_definition",
|
||||||
|
"py_get_signature",
|
||||||
|
"py_get_class_summary",
|
||||||
|
"py_get_var_declaration",
|
||||||
|
"py_get_docstring",
|
||||||
|
"py_find_usages",
|
||||||
|
"py_get_imports",
|
||||||
|
"py_check_syntax",
|
||||||
|
"py_get_hierarchy"
|
||||||
|
]
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Ticket:
|
class Ticket:
|
||||||
@@ -66,6 +42,7 @@ class Ticket:
|
|||||||
status: str = "todo"
|
status: str = "todo"
|
||||||
assigned_to: str = "unassigned"
|
assigned_to: str = "unassigned"
|
||||||
target_file: Optional[str] = None
|
target_file: Optional[str] = None
|
||||||
|
target_symbols: List[str] = field(default_factory=list)
|
||||||
context_requirements: List[str] = field(default_factory=list)
|
context_requirements: List[str] = field(default_factory=list)
|
||||||
depends_on: List[str] = field(default_factory=list)
|
depends_on: List[str] = field(default_factory=list)
|
||||||
blocked_reason: Optional[str] = None
|
blocked_reason: Optional[str] = None
|
||||||
@@ -92,6 +69,7 @@ class Ticket:
|
|||||||
"status": self.status,
|
"status": self.status,
|
||||||
"assigned_to": self.assigned_to,
|
"assigned_to": self.assigned_to,
|
||||||
"target_file": self.target_file,
|
"target_file": self.target_file,
|
||||||
|
"target_symbols": self.target_symbols,
|
||||||
"context_requirements": self.context_requirements,
|
"context_requirements": self.context_requirements,
|
||||||
"depends_on": self.depends_on,
|
"depends_on": self.depends_on,
|
||||||
"blocked_reason": self.blocked_reason,
|
"blocked_reason": self.blocked_reason,
|
||||||
@@ -107,6 +85,7 @@ class Ticket:
|
|||||||
status=data.get("status", "todo"),
|
status=data.get("status", "todo"),
|
||||||
assigned_to=data.get("assigned_to", ""),
|
assigned_to=data.get("assigned_to", ""),
|
||||||
target_file=data.get("target_file"),
|
target_file=data.get("target_file"),
|
||||||
|
target_symbols=data.get("target_symbols", []),
|
||||||
context_requirements=data.get("context_requirements", []),
|
context_requirements=data.get("context_requirements", []),
|
||||||
depends_on=data.get("depends_on", []),
|
depends_on=data.get("depends_on", []),
|
||||||
blocked_reason=data.get("blocked_reason"),
|
blocked_reason=data.get("blocked_reason"),
|
||||||
@@ -125,104 +104,78 @@ class Track:
|
|||||||
description: str
|
description: str
|
||||||
tickets: List[Ticket] = field(default_factory=list)
|
tickets: List[Ticket] = field(default_factory=list)
|
||||||
|
|
||||||
def get_executable_tickets(self) -> List[Ticket]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""
|
return {
|
||||||
Returns all 'todo' tickets whose dependencies are all 'completed'.
|
"id": self.id,
|
||||||
"""
|
"description": self.description,
|
||||||
# Map ticket IDs to their current status for efficient lookup
|
"tickets": [t.to_dict() for t in self.tickets],
|
||||||
status_map = {t.id: t.status for t in self.tickets}
|
}
|
||||||
executable = []
|
|
||||||
for ticket in self.tickets:
|
@classmethod
|
||||||
if ticket.status != "todo":
|
def from_dict(cls, data: Dict[str, Any]) -> "Track":
|
||||||
continue
|
return cls(
|
||||||
# Check if all dependencies are completed
|
id=data["id"],
|
||||||
all_deps_completed = True
|
description=data.get("description", ""),
|
||||||
for dep_id in ticket.depends_on:
|
tickets=[Ticket.from_dict(t) for t in data.get("tickets", [])],
|
||||||
# If a dependency is missing from the track, we treat it as not completed (or we could raise an error)
|
)
|
||||||
if status_map.get(dep_id) != "completed":
|
|
||||||
all_deps_completed = False
|
|
||||||
break
|
|
||||||
if all_deps_completed:
|
|
||||||
executable.append(ticket)
|
|
||||||
return executable
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class WorkerContext:
    """
    Per-worker state kept for the duration of a single ticket.
    """

    # Identifier of the ticket this worker is handling.
    ticket_id: str
    # Name of the model the worker runs against.
    model_name: str
    # Message history; every instance gets its own fresh list.
    messages: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class Metadata:
    """
    Identity and bookkeeping fields of a track.

    Timestamps may be ISO-8601 strings (as produced by from_dict), objects
    exposing ``isoformat()`` such as ``datetime``, or None when unknown.
    """

    id: str
    name: str
    # Defaults mirror from_dict so Metadata(id, name) remains constructible.
    status: str = "todo"
    created_at: Union[str, Any] = None
    updated_at: Union[str, Any] = None

    @staticmethod
    def _serialize_timestamp(value: Any) -> Optional[str]:
        """Return a serialisable form of a timestamp, keeping None as None."""
        if value is None:
            # str(None) would persist the literal text "None".
            return None
        isoformat = getattr(value, "isoformat", None)
        if callable(isoformat):
            return isoformat()
        return str(value)

    def to_dict(self) -> Dict[str, Any]:
        """Serialise to a plain dict; timestamps become ISO strings or None."""
        return {
            "id": self.id,
            "name": self.name,
            "status": self.status,
            "created_at": self._serialize_timestamp(self.created_at),
            "updated_at": self._serialize_timestamp(self.updated_at),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Metadata":
        """
        Build a Metadata from a dict.

        Only "id" is required (KeyError when absent). Timestamps are stored
        exactly as found in ``data``, without parsing.
        """
        return cls(
            id=data["id"],
            name=data.get("name", ""),
            status=data.get("status", "todo"),
            created_at=data.get("created_at"),
            updated_at=data.get("updated_at"),
        )
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class TrackState:
    """Serialisable state of a track: metadata, discussion entries and tickets."""

    metadata: Metadata
    discussion: List[str] = field(default_factory=list)
    tasks: List[Ticket] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict form; metadata and tickets are converted via their own to_dict."""
        payload: Dict[str, Any] = {}
        payload["metadata"] = self.metadata.to_dict()
        payload["discussion"] = self.discussion
        payload["tasks"] = [ticket.to_dict() for ticket in self.tasks]
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TrackState":
        """Rebuild a TrackState; "metadata" is required, the other keys default to empty lists."""
        meta = Metadata.from_dict(data["metadata"])
        entries = data.get("discussion", [])
        tickets = [Ticket.from_dict(raw) for raw in data.get("tasks", [])]
        return cls(metadata=meta, discussion=entries, tasks=tickets)
|
||||||
|
|||||||
@@ -313,6 +313,12 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
|||||||
ai_client.reset_session()
|
ai_client.reset_session()
|
||||||
ai_client.set_provider(ai_client.get_provider(), context.model_name)
|
ai_client.set_provider(ai_client.get_provider(), context.model_name)
|
||||||
context_injection = ""
|
context_injection = ""
|
||||||
|
tokens_before = 0
|
||||||
|
tokens_after = 0
|
||||||
|
|
||||||
|
def _count_tokens(text: str) -> int:
|
||||||
|
return len(text) // 4 # Rough estimate
|
||||||
|
|
||||||
if context_files:
|
if context_files:
|
||||||
parser = ASTParser(language="python")
|
parser = ASTParser(language="python")
|
||||||
for i, file_path in enumerate(context_files):
|
for i, file_path in enumerate(context_files):
|
||||||
@@ -321,14 +327,26 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
|||||||
# (This is a bit simplified, but helps)
|
# (This is a bit simplified, but helps)
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
|
|
||||||
|
tokens_before += _count_tokens(content)
|
||||||
|
|
||||||
if i == 0:
|
if i == 0:
|
||||||
view = parser.get_curated_view(content)
|
view = parser.get_curated_view(content, path=file_path)
|
||||||
|
elif ticket.target_file and Path(file_path).resolve() == Path(ticket.target_file).resolve() and ticket.target_symbols:
|
||||||
|
view = parser.get_targeted_view(content, ticket.target_symbols, path=file_path)
|
||||||
else:
|
else:
|
||||||
view = parser.get_skeleton(content)
|
view = parser.get_skeleton(content, path=file_path)
|
||||||
|
|
||||||
|
tokens_after += _count_tokens(view)
|
||||||
context_injection += f"\nFile: {file_path}\n{view}\n"
|
context_injection += f"\nFile: {file_path}\n{view}\n"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
context_injection += f"\nError reading {file_path}: {e}\n"
|
context_injection += f"\nError reading {file_path}: {e}\n"
|
||||||
# Build a prompt for the worker
|
|
||||||
|
if tokens_before > 0:
|
||||||
|
reduction = ((tokens_before - tokens_after) / tokens_before) * 100
|
||||||
|
print(f"[MMA] Context pruning for {ticket.id}: {tokens_before} -> {tokens_after} tokens ({reduction:.1f}% reduction)")
|
||||||
|
|
||||||
|
# Build a prompt for the worker
|
||||||
user_message = (
|
user_message = (
|
||||||
f"You are assigned to Ticket {ticket.id}.\n"
|
f"You are assigned to Ticket {ticket.id}.\n"
|
||||||
f"Task Description: {ticket.description}\n"
|
f"Task Description: {ticket.description}\n"
|
||||||
|
|||||||
@@ -63,3 +63,60 @@ def hot_func():
|
|||||||
assert 'print("keep me too")' in curated
|
assert 'print("keep me too")' in curated
|
||||||
assert '@core_logic' in curated
|
assert '@core_logic' in curated
|
||||||
assert '# [HOT]' in curated
|
assert '# [HOT]' in curated
|
||||||
|
|
||||||
|
def test_ast_parser_get_targeted_view() -> None:
|
||||||
|
"""Verify get_targeted_view includes targeted functions and dependencies."""
|
||||||
|
parser = ASTParser(language="python")
|
||||||
|
code = '''
|
||||||
|
import sys
|
||||||
|
|
||||||
|
def dep2():
|
||||||
|
"""Dep 2"""
|
||||||
|
print("dep2")
|
||||||
|
|
||||||
|
def dep1():
|
||||||
|
"""Dep 1"""
|
||||||
|
dep2()
|
||||||
|
|
||||||
|
def targeted():
|
||||||
|
"""Targeted"""
|
||||||
|
dep1()
|
||||||
|
|
||||||
|
def unrelated():
|
||||||
|
"""Unrelated"""
|
||||||
|
print("unrelated")
|
||||||
|
|
||||||
|
class MyClass:
|
||||||
|
def method1(self):
|
||||||
|
"""Method 1"""
|
||||||
|
targeted()
|
||||||
|
def method2(self):
|
||||||
|
"""Method 2"""
|
||||||
|
pass
|
||||||
|
'''
|
||||||
|
# Depth 0: targeted
|
||||||
|
# Depth 1: dep1 (called by targeted)
|
||||||
|
# Depth 2: dep2 (called by dep1)
|
||||||
|
view = parser.get_targeted_view(code, ["targeted"])
|
||||||
|
assert 'def targeted():' in view
|
||||||
|
assert '"""Targeted"""' in view
|
||||||
|
assert 'def dep1():' in view
|
||||||
|
assert '"""Dep 1"""' in view
|
||||||
|
assert 'def dep2():' in view
|
||||||
|
assert '"""Dep 2"""' in view
|
||||||
|
assert 'def unrelated():' not in view
|
||||||
|
assert 'class MyClass:' not in view
|
||||||
|
assert 'import sys' in view
|
||||||
|
|
||||||
|
# Test depth limit
|
||||||
|
# Depth 0: MyClass.method1
|
||||||
|
# Depth 1: targeted (called by method1)
|
||||||
|
# Depth 2: dep1 (called by targeted)
|
||||||
|
# Depth 3: dep2 (called by dep1) -> should be elided
|
||||||
|
view2 = parser.get_targeted_view(code, ["MyClass.method1"])
|
||||||
|
assert 'class MyClass:' in view2
|
||||||
|
assert 'def method1(self):' in view2
|
||||||
|
assert 'def targeted():' in view2
|
||||||
|
assert 'def dep1():' in view2
|
||||||
|
assert 'def dep2():' not in view2
|
||||||
|
assert 'def method2(self):' not in view2
|
||||||
|
|||||||
113
tests/test_context_pruner.py
Normal file
113
tests/test_context_pruner.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import pytest
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from src.file_cache import ASTParser
|
||||||
|
from src.models import Ticket, Track, WorkerContext
|
||||||
|
from src.multi_agent_conductor import run_worker_lifecycle
|
||||||
|
|
||||||
|
def test_targeted_extraction():
|
||||||
|
parser = ASTParser("python")
|
||||||
|
code = """
|
||||||
|
def func_a():
|
||||||
|
print("A")
|
||||||
|
func_b()
|
||||||
|
|
||||||
|
def func_b():
|
||||||
|
print("B")
|
||||||
|
func_c()
|
||||||
|
|
||||||
|
def func_c():
|
||||||
|
print("C")
|
||||||
|
|
||||||
|
def func_unrelated():
|
||||||
|
print("Unrelated")
|
||||||
|
"""
|
||||||
|
# Target func_a, should include func_b and func_c (depth 2)
|
||||||
|
result = parser.get_targeted_view(code, ["func_a"])
|
||||||
|
assert "def func_a():" in result
|
||||||
|
assert "def func_b():" in result
|
||||||
|
assert "def func_c():" in result
|
||||||
|
assert "def func_unrelated():" not in result
|
||||||
|
assert "print(" not in result # Bodies should be stripped
|
||||||
|
|
||||||
|
def test_class_targeted_extraction():
|
||||||
|
parser = ASTParser("python")
|
||||||
|
code = """
|
||||||
|
class MyClass:
|
||||||
|
def method_a(self):
|
||||||
|
self.method_b()
|
||||||
|
|
||||||
|
def method_b(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def method_unrelated(self):
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
result = parser.get_targeted_view(code, ["MyClass.method_a"])
|
||||||
|
assert "class MyClass:" in result
|
||||||
|
assert "def method_a(self):" in result
|
||||||
|
assert "def method_b(self):" in result
|
||||||
|
assert "def method_unrelated(self):" not in result
|
||||||
|
|
||||||
|
def test_ast_caching(tmp_path):
    """Repeated calls for an unchanged file reuse the parsed tree; a rewritten file is re-parsed."""
    import os  # local import: only this test needs to adjust file timestamps

    parser = ASTParser("python")
    file_path = tmp_path / "test_cache.py"
    code = "def test(): pass"
    file_path.write_text(code)

    # First call parses and caches; the second must produce the same view.
    view1 = parser.get_skeleton(code, path=str(file_path))
    view2 = parser.get_skeleton(code, path=str(file_path))
    assert view1 == view2

    # The very same tree object must be handed back while the file is unchanged.
    tree_a = parser.get_cached_tree(str(file_path), code)
    tree_b = parser.get_cached_tree(str(file_path), code)
    assert tree_a is tree_b

    # Update file: should invalidate cache. Bump the mtime explicitly rather
    # than sleeping, so the test is fast and does not depend on the
    # filesystem's timestamp resolution.
    new_code = "def test_new(): pass"
    file_path.write_text(new_code)
    stat = file_path.stat()
    os.utime(file_path, (stat.st_atime, stat.st_mtime + 10))

    view3 = parser.get_skeleton(new_code, path=str(file_path))
    assert "def test_new():" in view3
    assert "def test():" not in view3
|
||||||
|
|
||||||
|
def test_performance_large_file():
    """Skeleton extraction for a ~1000-line generated module finishes within 500ms."""
    parser = ASTParser("python")

    # 500 two-line function definitions, roughly 1000 lines in total.
    definitions = [f"def func_{i}():\n pass" for i in range(500)]
    code = "\n".join(definitions)

    started_at = time.time()
    parser.get_skeleton(code)
    duration = time.time() - started_at

    print(f"Large file parse duration: {duration*1000:.2f}ms")
    # Generous bound: even an uncached first parse should be well below it.
    assert duration < 0.5
|
||||||
|
|
||||||
|
def test_token_reduction_logging(capsys):
    """run_worker_lifecycle prints an "[MMA] Context pruning" summary for the ticket."""
    ticket = Ticket(
        id="T1",
        description="Test ticket",
        target_file="test.py",
        target_symbols=["func_a"],
        context_requirements=["test.py"],
    )
    context = WorkerContext(ticket_id="T1", model_name="gemini-2.5-flash", messages=[])

    filler = "\n".join([f"def func_{i}(): pass" for i in range(100)])
    code = "def func_a(): pass\n" + filler

    class _FakeFile:
        # Minimal file stand-in: a context manager whose read() yields the generated code.
        def read(self):
            return code

        def __enter__(self):
            return self

        def __exit__(self, *exc_info):
            return None

    def _fake_open(f, *args, **kwargs):
        return _FakeFile()

    # Replace file access and the AI client so nothing touches disk or network.
    with pytest.MonkeyPatch().context() as m:
        m.setattr("builtins.open", _fake_open)
        m.setattr("pathlib.Path.exists", lambda s: True)
        m.setattr("src.ai_client.send", lambda **kwargs: "DONE")

        run_worker_lifecycle(ticket, context, context_files=["test.py"])

    captured = capsys.readouterr()
    assert "[MMA] Context pruning for T1" in captured.out
    assert "reduction" in captured.out
|
||||||
Reference in New Issue
Block a user