feat(mma): Implement spawn interception in multi_agent_conductor.py

This commit is contained in:
2026-02-27 22:27:05 -05:00
parent c2c8732100
commit e293c5e302
2 changed files with 156 additions and 8 deletions

View File

@@ -1,11 +1,12 @@
import ai_client import ai_client
import json import json
import asyncio import asyncio
from typing import List, Optional from typing import List, Optional, Tuple
from dataclasses import asdict from dataclasses import asdict
import events import events
from models import Ticket, Track, WorkerContext from models import Ticket, Track, WorkerContext
from file_cache import ASTParser from file_cache import ASTParser
from pathlib import Path
from dag_engine import TrackDAG, ExecutionEngine from dag_engine import TrackDAG, ExecutionEngine
@@ -181,6 +182,57 @@ def confirm_execution(payload: str, event_queue: events.AsyncEventQueue, ticket_
return False return False
def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: "events.AsyncEventQueue", ticket_id: str, timeout: float = 60.0) -> Tuple[bool, str, str]:
    """
    Push a spawn approval request to the GUI and block until it is answered.

    An ``"mma_spawn_approval"`` event is queued whose payload carries a
    one-slot ``dialog_container`` list. The consumer of the event is expected
    to place a dialog object exposing ``wait() -> (approved, final_payload)``
    into slot 0; this function polls for that object and then calls ``wait()``.

    Args:
        role: Label of the worker being spawned; forwarded in the payload.
        prompt: Prompt proposed for the worker; the user may edit it.
        context_md: Markdown context proposed for the worker; also editable.
        event_queue: Queue the approval request is pushed onto.
        ticket_id: Identifier of the ticket the spawn belongs to.
        timeout: Seconds to wait for the dialog object to appear.

    Returns:
        ``(approved, prompt, context_md)``. The prompt and context are the
        edited values when the dialog returns a dict payload containing
        ``"prompt"`` / ``"context_md"``; otherwise the originals. If no
        dialog shows up within ``timeout`` the result is
        ``(False, prompt, context_md)``, i.e. indistinguishable from a
        rejection.
    """
    import asyncio
    import time

    event_name = "mma_spawn_approval"
    dialog_container = [None]
    task = {
        "action": event_name,
        "ticket_id": ticket_id,
        "role": role,
        "prompt": prompt,
        "context_md": context_md,
        "dialog_container": dialog_container,
    }

    # Push to queue.
    # NOTE(review): the direct `_queue.put_nowait` paths bypass the queue's
    # public API; whether that is safe from a non-loop thread depends on the
    # underlying queue type -- confirm against events.AsyncEventQueue.
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            asyncio.run_coroutine_threadsafe(event_queue.put(event_name, task), loop)
        else:
            event_queue._queue.put_nowait((event_name, task))
    except Exception:
        # Fallback if no loop is available in this thread.
        event_queue._queue.put_nowait((event_name, task))

    # Poll until the dialog object appears. A monotonic clock keeps the
    # deadline immune to wall-clock adjustments.
    start = time.monotonic()
    while dialog_container[0] is None and time.monotonic() - start < timeout:
        time.sleep(0.1)

    dialog = dialog_container[0]
    # Identity check, not truthiness: a dialog object that happens to be
    # falsy must still be honoured.
    if dialog is None:
        return False, prompt, context_md

    approved, final_payload = dialog.wait()
    if isinstance(final_payload, dict):
        # The user may have edited either field; missing keys keep the originals.
        return (
            approved,
            final_payload.get("prompt", prompt),
            final_payload.get("context_md", context_md),
        )
    return approved, prompt, context_md
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] = None, event_queue: events.AsyncEventQueue = None, engine: Optional['ConductorEngine'] = None, md_content: str = ""): def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] = None, event_queue: events.AsyncEventQueue = None, engine: Optional['ConductorEngine'] = None, md_content: str = ""):
""" """
Simulates the lifecycle of a single agent working on a ticket. Simulates the lifecycle of a single agent working on a ticket.
@@ -202,10 +254,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
for i, file_path in enumerate(context_files): for i, file_path in enumerate(context_files):
try: try:
abs_path = Path(file_path) abs_path = Path(file_path)
if not abs_path.is_absolute() and engine: # (This is a bit simplified, but helps)
# Resolve relative to project base if possible
# (This is a bit simplified, but helps)
pass
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, 'r', encoding='utf-8') as f:
content = f.read() content = f.read()
if i == 0: if i == 0:
@@ -229,6 +278,22 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
"start your response with 'BLOCKED' and explain why." "start your response with 'BLOCKED' and explain why."
) )
# HITL Clutch: call confirm_spawn if event_queue is provided
if event_queue:
approved, modified_prompt, modified_context = confirm_spawn(
role="Tier 3 Worker",
prompt=user_message,
context_md=md_content,
event_queue=event_queue,
ticket_id=ticket.id
)
if not approved:
ticket.mark_blocked("Spawn rejected by user.")
return "BLOCKED: Spawn rejected by user."
user_message = modified_prompt
md_content = modified_context
# HITL Clutch: pass the queue and ticket_id to confirm_execution # HITL Clutch: pass the queue and ticket_id to confirm_execution
def clutch_callback(payload: str) -> bool: def clutch_callback(payload: str) -> bool:
if not event_queue: if not event_queue:
@@ -246,9 +311,6 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
# Update usage in engine if provided # Update usage in engine if provided
if engine: if engine:
stats = {} # ai_client.get_token_stats() is not available stats = {} # ai_client.get_token_stats() is not available
# ai_client provides aggregate stats, for granular tier tracking
# we'd need to diff before/after or have ai_client return usage per call.
# For Phase 4, we'll use a simplified diff approach.
engine.tier_usage["Tier 3"]["input"] += stats.get("prompt_tokens", 0) engine.tier_usage["Tier 3"]["input"] += stats.get("prompt_tokens", 0)
engine.tier_usage["Tier 3"]["output"] += stats.get("candidates_tokens", 0) engine.tier_usage["Tier 3"]["output"] += stats.get("candidates_tokens", 0)

View File

@@ -0,0 +1,86 @@
import pytest
from unittest.mock import MagicMock, patch
import multi_agent_conductor
from models import Ticket, WorkerContext
import events
import asyncio
import concurrent.futures
class MockDialog:
    """Stand-in for the approval dialog that answers with a preset outcome."""

    def __init__(self, approved, final_payload=None):
        # Canned values handed back verbatim by wait().
        self.final_payload = final_payload
        self.approved = approved

    def wait(self):
        """Return the preset ``(approved, final_payload)`` pair without blocking."""
        outcome = (self.approved, self.final_payload)
        return outcome
@pytest.fixture
def mock_ai_client():
    """Patch ``ai_client.send`` for one test and yield the mock, which returns "Task completed"."""
    send_patcher = patch("ai_client.send", return_value="Task completed")
    with send_patcher as send_mock:
        yield send_mock
@pytest.mark.asyncio
async def test_confirm_spawn_pushed_to_queue():
    """confirm_spawn queues an approval event and returns the dialog's edited values."""
    event_queue = events.AsyncEventQueue()
    ticket_id = "T1"
    role = "Tier 3 Worker"
    prompt = "Original Prompt"
    context_md = "Original Context"

    # Start confirm_spawn in a thread since it blocks with time.sleep
    def run_confirm():
        return multi_agent_conductor.confirm_spawn(role, prompt, context_md, event_queue, ticket_id)

    loop = asyncio.get_running_loop()
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = loop.run_in_executor(executor, run_confirm)

        # Wait for the event to appear in the queue. The wait is bounded so a
        # missing event fails the test instead of hanging the run.
        event_name, payload = await asyncio.wait_for(event_queue.get(), timeout=5)
        assert event_name == "mma_spawn_approval"
        assert payload["ticket_id"] == ticket_id
        assert payload["role"] == role
        assert payload["prompt"] == prompt
        assert payload["context_md"] == context_md
        assert "dialog_container" in payload

        # Simulate GUI injecting a dialog
        payload["dialog_container"][0] = MockDialog(True, {"prompt": "Modified Prompt", "context_md": "Modified Context"})

        # Bounded for the same reason as above.
        approved, final_prompt, final_context = await asyncio.wait_for(future, timeout=5)

    assert approved is True
    assert final_prompt == "Modified Prompt"
    assert final_context == "Modified Context"
@patch("multi_agent_conductor.confirm_spawn")
def test_run_worker_lifecycle_approved(mock_confirm, mock_ai_client):
    """An approved spawn sends the user-modified prompt and context to ai_client.send."""
    worker_ticket = Ticket(id="T1", description="desc", status="todo", assigned_to="user")
    worker_context = WorkerContext(ticket_id="T1", model_name="model", messages=[])
    queue = events.AsyncEventQueue()

    # The (mocked) approval step hands back edited values.
    approval = (True, "Modified Prompt", "Modified Context")
    mock_confirm.return_value = approval

    multi_agent_conductor.run_worker_lifecycle(worker_ticket, worker_context, event_queue=queue)

    mock_confirm.assert_called_once()

    # Check that ai_client.send was called with modified values
    _, send_kwargs = mock_ai_client.call_args
    assert send_kwargs["user_message"] == "Modified Prompt"
    assert send_kwargs["md_content"] == "Modified Context"

    assert worker_ticket.status == "completed"
@patch("multi_agent_conductor.confirm_spawn")
def test_run_worker_lifecycle_rejected(mock_confirm, mock_ai_client):
    """A rejected spawn blocks the ticket and never reaches ai_client.send."""
    worker_ticket = Ticket(id="T1", description="desc", status="todo", assigned_to="user")
    worker_context = WorkerContext(ticket_id="T1", model_name="model", messages=[])
    queue = events.AsyncEventQueue()

    # The (mocked) approval step reports a rejection.
    rejection = (False, "Original Prompt", "Original Context")
    mock_confirm.return_value = rejection

    outcome = multi_agent_conductor.run_worker_lifecycle(worker_ticket, worker_context, event_queue=queue)

    mock_confirm.assert_called_once()
    mock_ai_client.assert_not_called()

    assert worker_ticket.status == "blocked"
    assert "Spawn rejected by user" in worker_ticket.blocked_reason
    assert "BLOCKED" in outcome