feat(mma): Implement spawn interception in multi_agent_conductor.py
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import ai_client
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import asdict
|
||||
import events
|
||||
from models import Ticket, Track, WorkerContext
|
||||
from file_cache import ASTParser
|
||||
from pathlib import Path
|
||||
|
||||
from dag_engine import TrackDAG, ExecutionEngine
|
||||
|
||||
@@ -181,6 +182,57 @@ def confirm_execution(payload: str, event_queue: events.AsyncEventQueue, ticket_
|
||||
|
||||
return False
|
||||
|
||||
def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.AsyncEventQueue, ticket_id: str) -> Tuple[bool, str, str]:
|
||||
"""
|
||||
Pushes a spawn approval request to the GUI and waits for response.
|
||||
Returns (approved, modified_prompt, modified_context)
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
dialog_container = [None]
|
||||
|
||||
task = {
|
||||
"action": "mma_spawn_approval",
|
||||
"ticket_id": ticket_id,
|
||||
"role": role,
|
||||
"prompt": prompt,
|
||||
"context_md": context_md,
|
||||
"dialog_container": dialog_container
|
||||
}
|
||||
|
||||
# Push to queue
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(event_queue.put("mma_spawn_approval", task), loop)
|
||||
else:
|
||||
event_queue._queue.put_nowait(("mma_spawn_approval", task))
|
||||
except Exception:
|
||||
# Fallback if no loop
|
||||
event_queue._queue.put_nowait(("mma_spawn_approval", task))
|
||||
|
||||
# Wait for the GUI to create the dialog and for the user to respond
|
||||
start = time.time()
|
||||
while dialog_container[0] is None and time.time() - start < 60:
|
||||
time.sleep(0.1)
|
||||
|
||||
if dialog_container[0]:
|
||||
approved, final_payload = dialog_container[0].wait()
|
||||
|
||||
# Extract modifications from final_payload if it's a dict
|
||||
modified_prompt = prompt
|
||||
modified_context = context_md
|
||||
|
||||
if isinstance(final_payload, dict):
|
||||
modified_prompt = final_payload.get("prompt", prompt)
|
||||
modified_context = final_payload.get("context_md", context_md)
|
||||
|
||||
return approved, modified_prompt, modified_context
|
||||
|
||||
return False, prompt, context_md
|
||||
|
||||
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] = None, event_queue: events.AsyncEventQueue = None, engine: Optional['ConductorEngine'] = None, md_content: str = ""):
|
||||
"""
|
||||
Simulates the lifecycle of a single agent working on a ticket.
|
||||
@@ -202,10 +254,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
||||
for i, file_path in enumerate(context_files):
|
||||
try:
|
||||
abs_path = Path(file_path)
|
||||
if not abs_path.is_absolute() and engine:
|
||||
# Resolve relative to project base if possible
|
||||
# (This is a bit simplified, but helps)
|
||||
pass
|
||||
# (This is a bit simplified, but helps)
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
if i == 0:
|
||||
@@ -229,6 +278,22 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
||||
"start your response with 'BLOCKED' and explain why."
|
||||
)
|
||||
|
||||
# HITL Clutch: call confirm_spawn if event_queue is provided
|
||||
if event_queue:
|
||||
approved, modified_prompt, modified_context = confirm_spawn(
|
||||
role="Tier 3 Worker",
|
||||
prompt=user_message,
|
||||
context_md=md_content,
|
||||
event_queue=event_queue,
|
||||
ticket_id=ticket.id
|
||||
)
|
||||
if not approved:
|
||||
ticket.mark_blocked("Spawn rejected by user.")
|
||||
return "BLOCKED: Spawn rejected by user."
|
||||
|
||||
user_message = modified_prompt
|
||||
md_content = modified_context
|
||||
|
||||
# HITL Clutch: pass the queue and ticket_id to confirm_execution
|
||||
def clutch_callback(payload: str) -> bool:
|
||||
if not event_queue:
|
||||
@@ -246,9 +311,6 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
||||
# Update usage in engine if provided
|
||||
if engine:
|
||||
stats = {} # ai_client.get_token_stats() is not available
|
||||
# ai_client provides aggregate stats, for granular tier tracking
|
||||
# we'd need to diff before/after or have ai_client return usage per call.
|
||||
# For Phase 4, we'll use a simplified diff approach.
|
||||
engine.tier_usage["Tier 3"]["input"] += stats.get("prompt_tokens", 0)
|
||||
engine.tier_usage["Tier 3"]["output"] += stats.get("candidates_tokens", 0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user