feat(mma): Implement spawn interception in multi_agent_conductor.py
This commit is contained in:
@@ -1,11 +1,12 @@
|
|||||||
import ai_client
|
import ai_client
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Tuple
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
import events
|
import events
|
||||||
from models import Ticket, Track, WorkerContext
|
from models import Ticket, Track, WorkerContext
|
||||||
from file_cache import ASTParser
|
from file_cache import ASTParser
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from dag_engine import TrackDAG, ExecutionEngine
|
from dag_engine import TrackDAG, ExecutionEngine
|
||||||
|
|
||||||
@@ -181,6 +182,57 @@ def confirm_execution(payload: str, event_queue: events.AsyncEventQueue, ticket_
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: "events.AsyncEventQueue", ticket_id: str, timeout: float = 60.0) -> Tuple[bool, str, str]:
    """
    Pushes a spawn approval request to the GUI and waits for response.

    The request is published on ``event_queue`` under the event name
    ``"mma_spawn_approval"``. Its payload carries a one-slot list
    (``dialog_container``) that the consumer is expected to fill with an
    object exposing ``wait()``, which must return ``(approved, final_payload)``.

    Args:
        role: Label of the agent about to be spawned (forwarded in the payload).
        prompt: Prompt proposed for the agent.
        context_md: Markdown context proposed for the agent.
        event_queue: Queue the approval request is pushed onto.
        ticket_id: Identifier of the ticket the spawn belongs to.
        timeout: Seconds to wait for the consumer to attach a dialog before
            giving up. Waiting on the dialog itself is not bounded here.

    Returns:
        ``(approved, modified_prompt, modified_context)``. When the dialog's
        final payload is a dict, its ``"prompt"`` / ``"context_md"`` entries
        override the originals. If no dialog is attached within ``timeout``,
        ``(False, prompt, context_md)`` is returned.
    """
    import time

    dialog_container = [None]

    task = {
        "action": "mma_spawn_approval",
        "ticket_id": ticket_id,
        "role": role,
        "prompt": prompt,
        "context_md": context_md,
        "dialog_container": dialog_container,
    }

    # Push to queue
    try:
        loop = asyncio.get_event_loop()
        if loop.is_running():
            asyncio.run_coroutine_threadsafe(event_queue.put("mma_spawn_approval", task), loop)
        else:
            event_queue._queue.put_nowait(("mma_spawn_approval", task))
    except Exception:
        # Fallback if no loop: deliberately best-effort, enqueue directly.
        event_queue._queue.put_nowait(("mma_spawn_approval", task))

    # Wait for the GUI to create the dialog. A monotonic clock is used so a
    # wall-clock adjustment cannot shorten or extend the timeout.
    deadline = time.monotonic() + timeout
    while dialog_container[0] is None and time.monotonic() < deadline:
        time.sleep(0.1)

    dialog = dialog_container[0]
    # Compare against None (as the wait loop does) so that a dialog object
    # which happens to be falsy is not mistaken for a timeout.
    if dialog is None:
        return False, prompt, context_md

    approved, final_payload = dialog.wait()

    # Only a dict payload can carry user modifications; anything else keeps
    # the original prompt and context.
    if isinstance(final_payload, dict):
        return (
            approved,
            final_payload.get("prompt", prompt),
            final_payload.get("context_md", context_md),
        )

    return approved, prompt, context_md
|
||||||
|
|
||||||
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] = None, event_queue: events.AsyncEventQueue = None, engine: Optional['ConductorEngine'] = None, md_content: str = ""):
|
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] = None, event_queue: events.AsyncEventQueue = None, engine: Optional['ConductorEngine'] = None, md_content: str = ""):
|
||||||
"""
|
"""
|
||||||
Simulates the lifecycle of a single agent working on a ticket.
|
Simulates the lifecycle of a single agent working on a ticket.
|
||||||
@@ -202,10 +254,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
|||||||
for i, file_path in enumerate(context_files):
|
for i, file_path in enumerate(context_files):
|
||||||
try:
|
try:
|
||||||
abs_path = Path(file_path)
|
abs_path = Path(file_path)
|
||||||
if not abs_path.is_absolute() and engine:
|
# (This is a bit simplified, but helps)
|
||||||
# Resolve relative to project base if possible
|
|
||||||
# (This is a bit simplified, but helps)
|
|
||||||
pass
|
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
with open(file_path, 'r', encoding='utf-8') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@@ -229,6 +278,22 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
|||||||
"start your response with 'BLOCKED' and explain why."
|
"start your response with 'BLOCKED' and explain why."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# HITL Clutch: call confirm_spawn if event_queue is provided
|
||||||
|
if event_queue:
|
||||||
|
approved, modified_prompt, modified_context = confirm_spawn(
|
||||||
|
role="Tier 3 Worker",
|
||||||
|
prompt=user_message,
|
||||||
|
context_md=md_content,
|
||||||
|
event_queue=event_queue,
|
||||||
|
ticket_id=ticket.id
|
||||||
|
)
|
||||||
|
if not approved:
|
||||||
|
ticket.mark_blocked("Spawn rejected by user.")
|
||||||
|
return "BLOCKED: Spawn rejected by user."
|
||||||
|
|
||||||
|
user_message = modified_prompt
|
||||||
|
md_content = modified_context
|
||||||
|
|
||||||
# HITL Clutch: pass the queue and ticket_id to confirm_execution
|
# HITL Clutch: pass the queue and ticket_id to confirm_execution
|
||||||
def clutch_callback(payload: str) -> bool:
|
def clutch_callback(payload: str) -> bool:
|
||||||
if not event_queue:
|
if not event_queue:
|
||||||
@@ -246,9 +311,6 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
|||||||
# Update usage in engine if provided
|
# Update usage in engine if provided
|
||||||
if engine:
|
if engine:
|
||||||
stats = {} # ai_client.get_token_stats() is not available
|
stats = {} # ai_client.get_token_stats() is not available
|
||||||
# ai_client provides aggregate stats, for granular tier tracking
|
|
||||||
# we'd need to diff before/after or have ai_client return usage per call.
|
|
||||||
# For Phase 4, we'll use a simplified diff approach.
|
|
||||||
engine.tier_usage["Tier 3"]["input"] += stats.get("prompt_tokens", 0)
|
engine.tier_usage["Tier 3"]["input"] += stats.get("prompt_tokens", 0)
|
||||||
engine.tier_usage["Tier 3"]["output"] += stats.get("candidates_tokens", 0)
|
engine.tier_usage["Tier 3"]["output"] += stats.get("candidates_tokens", 0)
|
||||||
|
|
||||||
|
|||||||
86
tests/test_spawn_interception.py
Normal file
86
tests/test_spawn_interception.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import multi_agent_conductor
|
||||||
|
from models import Ticket, WorkerContext
|
||||||
|
import events
|
||||||
|
import asyncio
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
class MockDialog:
    """Stand-in for the GUI approval dialog that answers with a canned result."""

    def __init__(self, approved, final_payload=None):
        self.approved, self.final_payload = approved, final_payload

    def wait(self):
        # Same (approved, payload) pair shape that confirm_spawn unpacks.
        outcome = (self.approved, self.final_payload)
        return outcome
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_ai_client():
    """Patch ``ai_client.send`` for the duration of a test and yield the mock."""
    patcher = patch("ai_client.send")
    send_mock = patcher.start()
    send_mock.return_value = "Task completed"
    try:
        yield send_mock
    finally:
        # Always restore the real ai_client.send, even if the test fails.
        patcher.stop()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_confirm_spawn_pushed_to_queue():
    """confirm_spawn publishes an approval event and returns the dialog's edits."""
    event_queue = events.AsyncEventQueue()
    ticket_id = "T1"
    role = "Tier 3 Worker"
    prompt = "Original Prompt"
    context_md = "Original Context"

    # Start confirm_spawn in a thread since it blocks with time.sleep
    def run_confirm():
        return multi_agent_conductor.confirm_spawn(role, prompt, context_md, event_queue, ticket_id)

    loop = asyncio.get_running_loop()
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = loop.run_in_executor(executor, run_confirm)

        # Wait for the event to appear in the queue. The wait is bounded so a
        # missing event fails the test instead of hanging it forever.
        event_name, payload = await asyncio.wait_for(event_queue.get(), timeout=5)
        assert event_name == "mma_spawn_approval"
        assert payload["ticket_id"] == ticket_id
        assert payload["role"] == role
        assert payload["prompt"] == prompt
        assert payload["context_md"] == context_md
        assert "dialog_container" in payload

        # Simulate GUI injecting a dialog
        payload["dialog_container"][0] = MockDialog(True, {"prompt": "Modified Prompt", "context_md": "Modified Context"})

        approved, final_prompt, final_context = await future
        assert approved is True
        assert final_prompt == "Modified Prompt"
        assert final_context == "Modified Context"
|
||||||
|
|
||||||
|
@patch("multi_agent_conductor.confirm_spawn")
def test_run_worker_lifecycle_approved(mock_confirm, mock_ai_client):
    """An approved spawn forwards the user-edited prompt and context to ai_client.send."""
    ticket = Ticket(id="T1", description="desc", status="todo", assigned_to="user")
    worker_ctx = WorkerContext(ticket_id="T1", model_name="model", messages=[])
    queue = events.AsyncEventQueue()

    # The user approves and edits both the prompt and the context.
    mock_confirm.return_value = (True, "Modified Prompt", "Modified Context")

    multi_agent_conductor.run_worker_lifecycle(ticket, worker_ctx, event_queue=queue)

    mock_confirm.assert_called_once()
    # ai_client.send must have received the modified values, not the originals.
    _, sent_kwargs = mock_ai_client.call_args
    assert sent_kwargs["user_message"] == "Modified Prompt"
    assert sent_kwargs["md_content"] == "Modified Context"
    assert ticket.status == "completed"
|
||||||
|
|
||||||
|
@patch("multi_agent_conductor.confirm_spawn")
def test_run_worker_lifecycle_rejected(mock_confirm, mock_ai_client):
    """A rejected spawn blocks the ticket and never calls ai_client.send."""
    ticket = Ticket(id="T1", description="desc", status="todo", assigned_to="user")
    worker_ctx = WorkerContext(ticket_id="T1", model_name="model", messages=[])
    queue = events.AsyncEventQueue()

    # The user rejects the spawn; prompt and context come back unchanged.
    mock_confirm.return_value = (False, "Original Prompt", "Original Context")

    outcome = multi_agent_conductor.run_worker_lifecycle(ticket, worker_ctx, event_queue=queue)

    mock_confirm.assert_called_once()
    mock_ai_client.assert_not_called()
    assert ticket.status == "blocked"
    assert "Spawn rejected by user" in ticket.blocked_reason
    assert "BLOCKED" in outcome
|
||||||
Reference in New Issue
Block a user