diff --git a/multi_agent_conductor.py b/multi_agent_conductor.py
index b5a1f93..e767cc9 100644
--- a/multi_agent_conductor.py
+++ b/multi_agent_conductor.py
@@ -1,11 +1,12 @@
 import ai_client
 import json
 import asyncio
-from typing import List, Optional
+from typing import List, Optional, Tuple
 from dataclasses import asdict
 import events
 from models import Ticket, Track, WorkerContext
 from file_cache import ASTParser
+from pathlib import Path
 from dag_engine import TrackDAG, ExecutionEngine
 
 
@@ -181,6 +182,57 @@ def confirm_execution(payload: str, event_queue: events.AsyncEventQueue, ticket_
     return False
+
+def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.AsyncEventQueue, ticket_id: str) -> Tuple[bool, str, str]:
+    """
+    Pushes a spawn approval request to the GUI and waits for response.
+    Returns (approved, modified_prompt, modified_context)
+    """
+    import threading
+    import time
+    import asyncio
+
+    dialog_container = [None]
+
+    task = {
+        "action": "mma_spawn_approval",
+        "ticket_id": ticket_id,
+        "role": role,
+        "prompt": prompt,
+        "context_md": context_md,
+        "dialog_container": dialog_container
+    }
+
+    # Push to queue
+    try:
+        loop = asyncio.get_event_loop()
+        if loop.is_running():
+            asyncio.run_coroutine_threadsafe(event_queue.put("mma_spawn_approval", task), loop)
+        else:
+            event_queue._queue.put_nowait(("mma_spawn_approval", task))
+    except Exception:
+        # Fallback if no loop
+        event_queue._queue.put_nowait(("mma_spawn_approval", task))
+
+    # Wait for the GUI to create the dialog and for the user to respond
+    start = time.time()
+    while dialog_container[0] is None and time.time() - start < 60:
+        time.sleep(0.1)
+
+    if dialog_container[0]:
+        approved, final_payload = dialog_container[0].wait()
+
+        # Extract modifications from final_payload if it's a dict
+        modified_prompt = prompt
+        modified_context = context_md
+
+        if isinstance(final_payload, dict):
+            modified_prompt = final_payload.get("prompt", prompt)
+            modified_context = final_payload.get("context_md", context_md)
+
+        return approved, modified_prompt, modified_context
+
+    return False, prompt, context_md
 
 def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] = None, event_queue: events.AsyncEventQueue = None, engine: Optional['ConductorEngine'] = None, md_content: str = ""):
     """
     Simulates the lifecycle of a single agent working on a ticket.
     """
@@ -202,10 +254,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
         for i, file_path in enumerate(context_files):
             try:
                 abs_path = Path(file_path)
-                if not abs_path.is_absolute() and engine:
-                    # Resolve relative to project base if possible
-                    # (This is a bit simplified, but helps)
-                    pass
+                # (This is a bit simplified, but helps)
                 with open(file_path, 'r', encoding='utf-8') as f:
                     content = f.read()
                     if i == 0:
@@ -229,6 +278,22 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
             "start your response with 'BLOCKED' and explain why."
         )
 
+    # HITL Clutch: call confirm_spawn if event_queue is provided
+    if event_queue:
+        approved, modified_prompt, modified_context = confirm_spawn(
+            role="Tier 3 Worker",
+            prompt=user_message,
+            context_md=md_content,
+            event_queue=event_queue,
+            ticket_id=ticket.id
+        )
+        if not approved:
+            ticket.mark_blocked("Spawn rejected by user.")
+            return "BLOCKED: Spawn rejected by user."
+
+        user_message = modified_prompt
+        md_content = modified_context
+
     # HITL Clutch: pass the queue and ticket_id to confirm_execution
     def clutch_callback(payload: str) -> bool:
         if not event_queue:
@@ -246,9 +311,6 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
         # Update usage in engine if provided
         if engine:
             stats = {} # ai_client.get_token_stats() is not available
-            # ai_client provides aggregate stats, for granular tier tracking
-            # we'd need to diff before/after or have ai_client return usage per call.
-            # For Phase 4, we'll use a simplified diff approach.
             engine.tier_usage["Tier 3"]["input"] += stats.get("prompt_tokens", 0)
             engine.tier_usage["Tier 3"]["output"] += stats.get("candidates_tokens", 0)
 
diff --git a/tests/test_spawn_interception.py b/tests/test_spawn_interception.py
new file mode 100644
index 0000000..d8ed3c5
--- /dev/null
+++ b/tests/test_spawn_interception.py
@@ -0,0 +1,86 @@
+import pytest
+from unittest.mock import MagicMock, patch
+import multi_agent_conductor
+from models import Ticket, WorkerContext
+import events
+import asyncio
+import concurrent.futures
+
+class MockDialog:
+    def __init__(self, approved, final_payload=None):
+        self.approved = approved
+        self.final_payload = final_payload
+    def wait(self):
+        return self.approved, self.final_payload
+
+@pytest.fixture
+def mock_ai_client():
+    with patch("ai_client.send") as mock_send:
+        mock_send.return_value = "Task completed"
+        yield mock_send
+
+@pytest.mark.asyncio
+async def test_confirm_spawn_pushed_to_queue():
+    event_queue = events.AsyncEventQueue()
+    ticket_id = "T1"
+    role = "Tier 3 Worker"
+    prompt = "Original Prompt"
+    context_md = "Original Context"
+
+    # Start confirm_spawn in a thread since it blocks with time.sleep
+    def run_confirm():
+        return multi_agent_conductor.confirm_spawn(role, prompt, context_md, event_queue, ticket_id)
+
+    loop = asyncio.get_running_loop()
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        future = loop.run_in_executor(executor, run_confirm)
+
+        # Wait for the event to appear in the queue
+        event_name, payload = await event_queue.get()
+        assert event_name == "mma_spawn_approval"
+        assert payload["ticket_id"] == ticket_id
+        assert payload["role"] == role
+        assert payload["prompt"] == prompt
+        assert payload["context_md"] == context_md
+        assert "dialog_container" in payload
+
+        # Simulate GUI injecting a dialog
+        payload["dialog_container"][0] = MockDialog(True, {"prompt": "Modified Prompt", "context_md": "Modified Context"})
+
+        approved, final_prompt, final_context = await future
+        assert approved is True
+        assert final_prompt == "Modified Prompt"
+        assert final_context == "Modified Context"
+
+@patch("multi_agent_conductor.confirm_spawn")
+def test_run_worker_lifecycle_approved(mock_confirm, mock_ai_client):
+    ticket = Ticket(id="T1", description="desc", status="todo", assigned_to="user")
+    context = WorkerContext(ticket_id="T1", model_name="model", messages=[])
+    event_queue = events.AsyncEventQueue()
+
+    mock_confirm.return_value = (True, "Modified Prompt", "Modified Context")
+
+    multi_agent_conductor.run_worker_lifecycle(ticket, context, event_queue=event_queue)
+
+    mock_confirm.assert_called_once()
+    # Check that ai_client.send was called with modified values
+    args, kwargs = mock_ai_client.call_args
+    assert kwargs["user_message"] == "Modified Prompt"
+    assert kwargs["md_content"] == "Modified Context"
+    assert ticket.status == "completed"
+
+@patch("multi_agent_conductor.confirm_spawn")
+def test_run_worker_lifecycle_rejected(mock_confirm, mock_ai_client):
+    ticket = Ticket(id="T1", description="desc", status="todo", assigned_to="user")
+    context = WorkerContext(ticket_id="T1", model_name="model", messages=[])
+    event_queue = events.AsyncEventQueue()
+
+    mock_confirm.return_value = (False, "Original Prompt", "Original Context")
+
+    result = multi_agent_conductor.run_worker_lifecycle(ticket, context, event_queue=event_queue)
+
+    mock_confirm.assert_called_once()
+    mock_ai_client.assert_not_called()
+    assert ticket.status == "blocked"
+    assert "Spawn rejected by user" in ticket.blocked_reason
+    assert "BLOCKED" in result