feat(mma): Implement Deep AST-Driven Context Pruning and mark track complete
This commit is contained in:
@@ -63,3 +63,60 @@ def hot_func():
|
||||
assert 'print("keep me too")' in curated
|
||||
assert '@core_logic' in curated
|
||||
assert '# [HOT]' in curated
|
||||
|
||||
def test_ast_parser_get_targeted_view() -> None:
|
||||
"""Verify get_targeted_view includes targeted functions and dependencies."""
|
||||
parser = ASTParser(language="python")
|
||||
code = '''
|
||||
import sys
|
||||
|
||||
def dep2():
|
||||
"""Dep 2"""
|
||||
print("dep2")
|
||||
|
||||
def dep1():
|
||||
"""Dep 1"""
|
||||
dep2()
|
||||
|
||||
def targeted():
|
||||
"""Targeted"""
|
||||
dep1()
|
||||
|
||||
def unrelated():
|
||||
"""Unrelated"""
|
||||
print("unrelated")
|
||||
|
||||
class MyClass:
|
||||
def method1(self):
|
||||
"""Method 1"""
|
||||
targeted()
|
||||
def method2(self):
|
||||
"""Method 2"""
|
||||
pass
|
||||
'''
|
||||
# Depth 0: targeted
|
||||
# Depth 1: dep1 (called by targeted)
|
||||
# Depth 2: dep2 (called by dep1)
|
||||
view = parser.get_targeted_view(code, ["targeted"])
|
||||
assert 'def targeted():' in view
|
||||
assert '"""Targeted"""' in view
|
||||
assert 'def dep1():' in view
|
||||
assert '"""Dep 1"""' in view
|
||||
assert 'def dep2():' in view
|
||||
assert '"""Dep 2"""' in view
|
||||
assert 'def unrelated():' not in view
|
||||
assert 'class MyClass:' not in view
|
||||
assert 'import sys' in view
|
||||
|
||||
# Test depth limit
|
||||
# Depth 0: MyClass.method1
|
||||
# Depth 1: targeted (called by method1)
|
||||
# Depth 2: dep1 (called by targeted)
|
||||
# Depth 3: dep2 (called by dep1) -> should be elided
|
||||
view2 = parser.get_targeted_view(code, ["MyClass.method1"])
|
||||
assert 'class MyClass:' in view2
|
||||
assert 'def method1(self):' in view2
|
||||
assert 'def targeted():' in view2
|
||||
assert 'def dep1():' in view2
|
||||
assert 'def dep2():' not in view2
|
||||
assert 'def method2(self):' not in view2
|
||||
|
||||
113
tests/test_context_pruner.py
Normal file
113
tests/test_context_pruner.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import pytest
|
||||
import time
|
||||
from pathlib import Path
|
||||
from src.file_cache import ASTParser
|
||||
from src.models import Ticket, Track, WorkerContext
|
||||
from src.multi_agent_conductor import run_worker_lifecycle
|
||||
|
||||
def test_targeted_extraction():
|
||||
parser = ASTParser("python")
|
||||
code = """
|
||||
def func_a():
|
||||
print("A")
|
||||
func_b()
|
||||
|
||||
def func_b():
|
||||
print("B")
|
||||
func_c()
|
||||
|
||||
def func_c():
|
||||
print("C")
|
||||
|
||||
def func_unrelated():
|
||||
print("Unrelated")
|
||||
"""
|
||||
# Target func_a, should include func_b and func_c (depth 2)
|
||||
result = parser.get_targeted_view(code, ["func_a"])
|
||||
assert "def func_a():" in result
|
||||
assert "def func_b():" in result
|
||||
assert "def func_c():" in result
|
||||
assert "def func_unrelated():" not in result
|
||||
assert "print(" not in result # Bodies should be stripped
|
||||
|
||||
def test_class_targeted_extraction():
|
||||
parser = ASTParser("python")
|
||||
code = """
|
||||
class MyClass:
|
||||
def method_a(self):
|
||||
self.method_b()
|
||||
|
||||
def method_b(self):
|
||||
pass
|
||||
|
||||
def method_unrelated(self):
|
||||
pass
|
||||
"""
|
||||
result = parser.get_targeted_view(code, ["MyClass.method_a"])
|
||||
assert "class MyClass:" in result
|
||||
assert "def method_a(self):" in result
|
||||
assert "def method_b(self):" in result
|
||||
assert "def method_unrelated(self):" not in result
|
||||
|
||||
def test_ast_caching(tmp_path):
|
||||
parser = ASTParser("python")
|
||||
file_path = tmp_path / "test_cache.py"
|
||||
code = "def test(): pass"
|
||||
file_path.write_text(code)
|
||||
|
||||
# First call: parses and caches
|
||||
start = time.time()
|
||||
view1 = parser.get_skeleton(code, path=str(file_path))
|
||||
duration1 = time.time() - start
|
||||
|
||||
# Second call: should use cache
|
||||
start = time.time()
|
||||
view2 = parser.get_skeleton(code, path=str(file_path))
|
||||
duration2 = time.time() - start
|
||||
|
||||
assert view1 == view2
|
||||
# duration2 should be significantly faster, but let's just check it works
|
||||
|
||||
# Update file: should invalidate cache
|
||||
time.sleep(0.1)
|
||||
new_code = "def test_new(): pass"
|
||||
file_path.write_text(new_code)
|
||||
view3 = parser.get_skeleton(new_code, path=str(file_path))
|
||||
assert "def test_new():" in view3
|
||||
assert "def test():" not in view3
|
||||
|
||||
def test_performance_large_file():
|
||||
parser = ASTParser("python")
|
||||
# Generate a large file (approx 1000 lines)
|
||||
code = "\n".join([f"def func_{i}():\n pass" for i in range(500)])
|
||||
|
||||
start = time.time()
|
||||
parser.get_skeleton(code)
|
||||
duration = time.time() - start
|
||||
|
||||
print(f"Large file parse duration: {duration*1000:.2f}ms")
|
||||
assert duration < 0.5 # Should be well under 500ms even for first parse
|
||||
|
||||
def test_token_reduction_logging(capsys):
|
||||
ticket = Ticket(
|
||||
id="T1",
|
||||
description="Test ticket",
|
||||
target_file="test.py",
|
||||
target_symbols=["func_a"],
|
||||
context_requirements=["test.py"]
|
||||
)
|
||||
context = WorkerContext(ticket_id="T1", model_name="gemini-2.5-flash", messages=[])
|
||||
|
||||
code = "def func_a(): pass\n" + "\n".join([f"def func_{i}(): pass" for i in range(100)])
|
||||
|
||||
# Mock open to return our large code
|
||||
with pytest.MonkeyPatch().context() as m:
|
||||
m.setattr("builtins.open", lambda f, *args, **kwargs: type('obj', (object,), {'read': lambda s: code, '__enter__': lambda s: s, '__exit__': lambda s, *a: None})())
|
||||
m.setattr("pathlib.Path.exists", lambda s: True)
|
||||
m.setattr("src.ai_client.send", lambda **kwargs: "DONE")
|
||||
|
||||
run_worker_lifecycle(ticket, context, context_files=["test.py"])
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "[MMA] Context pruning for T1" in captured.out
|
||||
assert "reduction" in captured.out
|
||||
Reference in New Issue
Block a user