Files
manual_slop/src/multi_agent_conductor.py
Ed_ 90670b9671 feat(tier4): Integrate patch generation into GUI workflow
- Add patch_callback parameter throughout the tool execution chain
- Add _render_patch_modal() to gui_2.py with colored diff display
- Add patch modal state variables to App.__init__
- Add request_patch_from_tier4() to trigger patch generation
- Add run_tier4_patch_callback() to ai_client.py
- Update shell_runner to accept and execute patch_callback
- Diff colors: green for additions, red for deletions, cyan for headers
- 36 tests passing
2026-03-07 00:26:34 -05:00

445 lines
16 KiB
Python

from src import ai_client
import json
import threading
import time
import traceback
from typing import List, Optional, Tuple, Callable
from dataclasses import asdict
from src import events
from src import models
from src.models import Ticket, Track, WorkerContext
from src.file_cache import ASTParser
from pathlib import Path
from src.dag_engine import TrackDAG, ExecutionEngine
class WorkerPool:
    """
    Manages a pool of worker threads with a concurrency limit.

    Concurrency is bounded twice: `spawn` refuses new threads once the
    active registry reaches `max_workers`, and a semaphore additionally
    gates how many wrappers run their target simultaneously.
    """

    def __init__(self, max_workers: int = 4):
        # Hard cap on simultaneously registered worker threads.
        self.max_workers = max_workers
        # ticket_id -> running thread; guarded by _lock.
        self._active: dict[str, threading.Thread] = {}
        self._lock = threading.Lock()
        self._semaphore = threading.Semaphore(max_workers)

    def spawn(self, ticket_id: str, target: Callable, args: tuple) -> Optional[threading.Thread]:
        """
        Spawns a new worker thread if the pool is not full.
        Returns the thread object or None if full.
        """
        with self._lock:
            if len(self._active) >= self.max_workers:
                return None

        def wrapper(*a, **kw):
            try:
                with self._semaphore:
                    target(*a, **kw)
            finally:
                # Always deregister, even if the target raised.
                with self._lock:
                    self._active.pop(ticket_id, None)

        t = threading.Thread(target=wrapper, args=args, daemon=True)
        with self._lock:
            self._active[ticket_id] = t
        t.start()
        return t

    def join_all(self, timeout: Optional[float] = None) -> None:
        """Join every active worker (each with *timeout*), then drop finished ones.

        Fix: the registry was previously cleared unconditionally, which
        forgot threads still running after a timed-out join and allowed the
        pool to exceed max_workers. Only non-alive threads are removed now
        (finished threads also remove themselves in the wrapper's finally).
        """
        with self._lock:
            joined = list(self._active.items())
        for _, t in joined:
            t.join(timeout=timeout)
        with self._lock:
            for ticket_id, t in joined:
                if not t.is_alive():
                    self._active.pop(ticket_id, None)

    def get_active_count(self) -> int:
        """Number of currently registered worker threads."""
        with self._lock:
            return len(self._active)

    def is_full(self) -> bool:
        """True when spawn() would refuse a new worker."""
        return self.get_active_count() >= self.max_workers
class ConductorEngine:
    """
    Orchestrates the execution of tickets within a track.

    Drives a DAG-based scheduler (TrackDAG / ExecutionEngine), dispatches
    ready tickets to a WorkerPool, and pushes state snapshots to an
    optional GUI event queue.
    """

    def __init__(self, track: Track, event_queue: Optional[events.SyncEventQueue] = None, auto_queue: bool = False) -> None:
        # The track whose tickets this engine executes.
        self.track = track
        # Optional queue for GUI state updates; None disables all pushes.
        self.event_queue = event_queue
        # Per-tier token accounting; mutated under _tier_usage_lock.
        self.tier_usage = {
            "Tier 1": {"input": 0, "output": 0, "model": "gemini-3.1-pro-preview"},
            "Tier 2": {"input": 0, "output": 0, "model": "gemini-3-flash-preview"},
            "Tier 3": {"input": 0, "output": 0, "model": "gemini-2.5-flash-lite"},
            "Tier 4": {"input": 0, "output": 0, "model": "gemini-2.5-flash-lite"},
        }
        self.dag = TrackDAG(self.track.tickets)
        self.engine = ExecutionEngine(self.dag, auto_queue=auto_queue)
        # Load MMA config
        try:
            config = models.load_config()
            mma_cfg = config.get("mma", {})
            max_workers = mma_cfg.get("max_workers", 4)
        except Exception:
            # Best-effort: fall back to a sane default when config is missing/broken.
            max_workers = 4
        self.pool = WorkerPool(max_workers=max_workers)
        self._workers_lock = threading.Lock()
        # NOTE(review): _active_workers appears unused in this file — the pool
        # tracks active threads itself; confirm before removing.
        self._active_workers: dict[str, threading.Thread] = {}
        self._tier_usage_lock = threading.Lock()

    def update_usage(self, tier: str, input_tokens: int, output_tokens: int) -> None:
        """Thread-safely accumulate token usage for a known tier.

        Unknown tier names are silently ignored.
        """
        with self._tier_usage_lock:
            if tier in self.tier_usage:
                self.tier_usage[tier]["input"] += input_tokens
                self.tier_usage[tier]["output"] += output_tokens

    def _push_state(self, status: str = "running", active_tier: Optional[str] = None) -> None:
        """Push a full state snapshot to the GUI event queue (no-op without a queue)."""
        if not self.event_queue:
            return
        payload = {
            "status": status,
            "active_tier": active_tier,
            "tier_usage": self.tier_usage,
            "track": {
                "id": self.track.id,
                "title": self.track.description,
            },
            # Tickets are serialized to plain dicts for the GUI.
            "tickets": [asdict(t) for t in self.track.tickets]
        }
        self.event_queue.put("mma_state_update", payload)

    def parse_json_tickets(self, json_str: str) -> None:
        """
        Parses a JSON string of ticket definitions (Godot ECS Flat List format)
        and populates the Track's ticket list.

        Errors (invalid JSON, non-list input, missing required fields) are
        reported to stdout and leave the track unchanged or partially updated;
        nothing is raised to the caller.
        """
        try:
            data = json.loads(json_str)
            if not isinstance(data, list):
                print("Error: JSON input must be a list of ticket definitions.")
                return
            for ticket_data in data:
                # Construct Ticket object, using defaults for optional fields
                # ("id" and "description" are required and raise KeyError below).
                ticket = Ticket(
                    id=ticket_data["id"],
                    description=ticket_data["description"],
                    status=ticket_data.get("status", "todo"),
                    assigned_to=ticket_data.get("assigned_to", "unassigned"),
                    depends_on=ticket_data.get("depends_on", []),
                    step_mode=ticket_data.get("step_mode", False)
                )
                self.track.tickets.append(ticket)
            # Rebuild DAG and Engine after parsing new tickets
            self.dag = TrackDAG(self.track.tickets)
            self.engine = ExecutionEngine(self.dag, auto_queue=self.engine.auto_queue)
        except json.JSONDecodeError as e:
            print(f"Error parsing JSON tickets: {e}")
        except KeyError as e:
            print(f"Missing required field in ticket definition: {e}")

    def run(self, md_content: str = "", max_ticks: Optional[int] = None) -> None:
        """
        Main execution loop using the DAG engine.

        Each tick: ask the engine for ready tickets, detect completion or
        blockage, spawn workers for runnable tickets, then sleep one second.

        Args:
            md_content: The full markdown context (history + files) for AI workers.
            max_ticks: Optional limit on number of iterations (for testing).
        """
        self._push_state(status="running", active_tier="Tier 2 (Tech Lead)")
        import sys  # NOTE(review): appears unused in this method — confirm before removing.
        tick_count = 0
        while True:
            if max_ticks is not None and tick_count >= max_ticks:
                break
            tick_count += 1
            # 1. Identify ready tasks
            ready_tasks = self.engine.tick()
            # 2. Check for completion or blockage
            if not ready_tasks:
                all_done = all(t.status == "completed" for t in self.track.tickets)
                if all_done:
                    # Wait for any active pool threads to finish before declaring done
                    self.pool.join_all(timeout=5)
                    # Re-check: a late worker may have changed a ticket's status.
                    if all(t.status == "completed" for t in self.track.tickets):
                        print("Track completed successfully.")
                        self._push_state(status="done", active_tier=None)
                        break
                # Check if any tasks are in-progress
                if any(t.status == "in_progress" for t in self.track.tickets) or self.pool.get_active_count() > 0:
                    # Wait for tasks to complete
                    time.sleep(1)
                    continue
                print("No more executable tickets. Track is blocked or finished.")
                self._push_state(status="blocked", active_tier=None)
                break
            # 3. Process ready tasks
            # Only include those that should be running: either already in_progress or todo + auto_queue
            to_run = [t for t in ready_tasks if t.status == "in_progress" or (t.status == "todo" and not t.step_mode and self.engine.auto_queue)]
            # Handle those awaiting approval
            for ticket in ready_tasks:
                if ticket not in to_run and ticket.status == "todo":
                    print(f"Ticket {ticket.id} is ready and awaiting approval.")
                    self._push_state(active_tier=f"Awaiting Approval: {ticket.id}")
                    time.sleep(1)
            if to_run:
                for ticket in to_run:
                    if ticket.status == "todo":
                        # Only spawn if pool has capacity
                        if self.pool.is_full():
                            continue
                        # Escalation logic based on retry_count: each retry moves
                        # to a stronger model, capped at the last list entry.
                        models_list = ["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-3.1-pro-preview"]
                        model_idx = min(ticket.retry_count, len(models_list) - 1)
                        model_name = models_list[model_idx]
                        context = WorkerContext(
                            ticket_id=ticket.id,
                            model_name=model_name,
                            messages=[]
                        )
                        context_files = ticket.context_requirements if ticket.context_requirements else None
                        spawned = self.pool.spawn(
                            ticket.id,
                            run_worker_lifecycle,
                            (ticket, context, context_files, self.event_queue, self, md_content)
                        )
                        if spawned:
                            ticket.status = "in_progress"
                            print(f"Executing ticket {ticket.id}: {ticket.description}")
                            self._push_state(active_tier=f"Tier 3 (Worker): {ticket.id}")
            # 4. Retry and escalation logic for blocked tickets
            # (Check tickets that recently became blocked)
            for ticket in self.track.tickets:
                if ticket.status == 'blocked' and ticket.retry_count < 2:
                    # Simple check to see if we should retry
                    # Escalation is currently handled inside run_worker_lifecycle or via manual retry
                    pass
            self._push_state(active_tier="Tier 2 (Tech Lead)")
            time.sleep(1)
def _queue_put(event_queue: events.SyncEventQueue, event_name: str, payload) -> None:
    """Push (*event_name*, *payload*) onto *event_queue* from a worker thread.

    Pure delegation: SyncEventQueue.put is presumed to do its own locking,
    so this helper is safe to call off the GUI thread.
    """
    put_fn = event_queue.put
    put_fn(event_name, payload)
def confirm_execution(payload: str, event_queue: "events.SyncEventQueue", ticket_id: str, timeout: float = 60.0) -> bool:
    """
    Pushes an approval request to the GUI and waits for response.

    Args:
        payload: Human-readable description of the action awaiting approval.
        event_queue: Queue the GUI listens on; its put() must be thread-safe.
        ticket_id: The ticket this approval concerns.
        timeout: Seconds to wait for the GUI to attach a dialog before giving
            up (previously a hard-coded 60).

    Returns:
        True if the user approved; False on rejection or timeout.
    """
    # One-slot container: the GUI stores its dialog object here so this
    # worker thread can block on it.
    dialog_container = [None]
    task = {
        "action": "mma_step_approval",
        "ticket_id": ticket_id,
        "payload": payload,
        "dialog_container": dialog_container
    }
    event_queue.put("mma_step_approval", task)
    # Wait for the GUI to create the dialog and for the user to respond.
    # time.monotonic() is immune to system clock adjustments, unlike the
    # previous time.time()-based deadline.
    deadline = time.monotonic() + timeout
    while dialog_container[0] is None and time.monotonic() < deadline:
        time.sleep(0.1)
    if dialog_container[0]:
        approved, final_payload = dialog_container[0].wait()
        return approved
    return False
def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: "events.SyncEventQueue", ticket_id: str, timeout: float = 60.0) -> Tuple[bool, str, str]:
    """
    Pushes a spawn approval request to the GUI and waits for response.

    Args:
        role: Display name of the agent role being spawned.
        prompt: Proposed worker prompt (the user may edit it in the dialog).
        context_md: Proposed markdown context (also user-editable).
        event_queue: Queue the GUI listens on; its put() must be thread-safe.
        ticket_id: The ticket this spawn concerns.
        timeout: Seconds to wait for the GUI to attach a dialog before giving
            up (previously a hard-coded 60).

    Returns:
        (approved, modified_prompt, modified_context). On timeout the
        original prompt/context are returned with approved=False.
    """
    # One-slot container: the GUI stores its dialog object here so this
    # worker thread can block on it.
    dialog_container = [None]
    task = {
        "action": "mma_spawn_approval",
        "ticket_id": ticket_id,
        "role": role,
        "prompt": prompt,
        "context_md": context_md,
        "dialog_container": dialog_container
    }
    event_queue.put("mma_spawn_approval", task)
    # Wait for the GUI to create the dialog and for the user to respond.
    # time.monotonic() is immune to system clock adjustments, unlike the
    # previous time.time()-based deadline.
    deadline = time.monotonic() + timeout
    while dialog_container[0] is None and time.monotonic() < deadline:
        time.sleep(0.1)
    if dialog_container[0]:
        res = dialog_container[0].wait()
        if isinstance(res, dict):
            # Current protocol: dict with approval flags plus edited fields.
            approved = res.get("approved", False)
            abort = res.get("abort", False)
            modified_prompt = res.get("prompt", prompt)
            modified_context = res.get("context_md", context_md)
            return approved and not abort, modified_prompt, modified_context
        # Fallback for old tuple style if any: (approved, payload).
        approved, final_payload = res
        modified_prompt = prompt
        modified_context = context_md
        if isinstance(final_payload, dict):
            modified_prompt = final_payload.get("prompt", prompt)
            modified_context = final_payload.get("context_md", context_md)
        return approved, modified_prompt, modified_context
    return False, prompt, context_md
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] | None = None, event_queue: events.SyncEventQueue | None = None, engine: Optional['ConductorEngine'] = None, md_content: str = "") -> str:
    """
    Runs the lifecycle of a single agent working on a ticket.

    Calls the AI client and updates the ticket status based on the response.

    Args:
        ticket: The ticket to process.
        context: The worker context (supplies the model name for this attempt).
        context_files: List of files to include in the context.
        event_queue: Queue for pushing state updates and receiving approvals.
        engine: The conductor engine (used for token-usage accounting).
        md_content: The markdown context (history + files) for AI workers.

    Returns:
        The model's response text, or a "BLOCKED: ..." sentinel string when
        the spawn is rejected by the user.
    """
    # Enforce Context Amnesia: each ticket starts with a clean slate.
    ai_client.reset_session()
    ai_client.set_provider(ai_client.get_provider(), context.model_name)
    context_injection = ""
    tokens_before = 0
    tokens_after = 0

    def _count_tokens(text: str) -> int:
        # ~4 characters per token heuristic.
        return len(text) // 4  # Rough estimate

    if context_files:
        parser = ASTParser(language="python")
        for i, file_path in enumerate(context_files):
            try:
                # NOTE(review): result discarded — looks like leftover path
                # validation; confirm before removing.
                Path(file_path)
                # (This is a bit simplified, but helps)
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                tokens_before += _count_tokens(content)
                # Pruning policy: the first file gets the richest (curated)
                # view, the ticket's target file gets a symbol-targeted view,
                # and everything else is reduced to a skeleton.
                if i == 0:
                    view = parser.get_curated_view(content, path=file_path)
                elif ticket.target_file and Path(file_path).resolve() == Path(ticket.target_file).resolve() and ticket.target_symbols:
                    view = parser.get_targeted_view(content, ticket.target_symbols, path=file_path)
                else:
                    view = parser.get_skeleton(content, path=file_path)
                tokens_after += _count_tokens(view)
                context_injection += f"\nFile: {file_path}\n{view}\n"
            except Exception as e:
                # Best-effort: unreadable files are reported inline rather
                # than failing the whole worker.
                context_injection += f"\nError reading {file_path}: {e}\n"
    if tokens_before > 0:
        reduction = ((tokens_before - tokens_after) / tokens_before) * 100
        print(f"[MMA] Context pruning for {ticket.id}: {tokens_before} -> {tokens_after} tokens ({reduction:.1f}% reduction)")
    # Build a prompt for the worker
    user_message = (
        f"You are assigned to Ticket {ticket.id}.\n"
        f"Task Description: {ticket.description}\n"
    )
    if context_injection:
        user_message += f"\nContext Files:\n{context_injection}\n"
    user_message += (
        "Please complete this task. If you are blocked and cannot proceed, "
        "start your response with 'BLOCKED' and explain why."
    )
    # HITL Clutch: call confirm_spawn if event_queue is provided
    if event_queue:
        approved, modified_prompt, modified_context = confirm_spawn(
            role="Tier 3 Worker",
            prompt=user_message,
            context_md=md_content,
            event_queue=event_queue,
            ticket_id=ticket.id
        )
        if not approved:
            ticket.mark_blocked("Spawn rejected by user.")
            return "BLOCKED: Spawn rejected by user."
        # The user may have edited the prompt/context in the approval dialog.
        user_message = modified_prompt
        md_content = modified_context

    # HITL Clutch: pass the queue and ticket_id to confirm_execution
    def clutch_callback(payload: str) -> bool:
        # Auto-approve when no GUI queue is attached.
        if not event_queue:
            return True
        return confirm_execution(payload, event_queue, ticket.id)

    def stream_callback(chunk: str) -> None:
        # Forward streamed model output to the GUI, tagged per-worker.
        if event_queue:
            _queue_put(event_queue, 'mma_stream', {'stream_id': f'Tier 3 (Worker): {ticket.id}', 'text': chunk})

    # Save the global comms callback so it can be restored after send().
    old_comms_cb = ai_client.comms_log_callback

    def worker_comms_callback(entry: dict) -> None:
        # Mirror tool calls/results into the GUI stream, then chain to the
        # previously installed callback (if any).
        if event_queue:
            kind = entry.get("kind")
            payload = entry.get("payload", {})
            chunk = ""
            if kind == "tool_call":
                chunk = f"\n\n[TOOL CALL] {payload.get('name')}\n{json.dumps(payload.get('script') or payload.get('args'))}\n"
            elif kind == "tool_result":
                res = str(payload.get("output", ""))
                if len(res) > 500: res = res[:500] + "... (truncated)"
                chunk = f"\n[TOOL RESULT]\n{res}\n"
            if chunk:
                _queue_put(event_queue, "response", {"text": chunk, "stream_id": f"Tier 3 (Worker): {ticket.id}", "status": "streaming..."})
        if old_comms_cb:
            old_comms_cb(entry)

    ai_client.comms_log_callback = worker_comms_callback
    ai_client.set_current_tier("Tier 3")
    try:
        # Baseline index into the comms log so that only this worker's
        # entries are attributed when computing token usage below.
        comms_baseline = len(ai_client.get_comms_log())
        response = ai_client.send(
            md_content=md_content,
            user_message=user_message,
            base_dir=".",
            # The per-step approval clutch is only engaged in step mode.
            pre_tool_callback=clutch_callback if ticket.step_mode else None,
            qa_callback=ai_client.run_tier4_analysis,
            patch_callback=ai_client.run_tier4_patch_callback,
            stream_callback=stream_callback
        )
    finally:
        # Always restore the global callback/tier, even if send() raised.
        ai_client.comms_log_callback = old_comms_cb
        ai_client.set_current_tier(None)
    if event_queue:
        # Push via "response" event type — _process_event_queue wraps this
        # as {"action": "handle_ai_response", "payload": ...} for the GUI.
        try:
            response_payload = {
                "text": response,
                "stream_id": f"Tier 3 (Worker): {ticket.id}",
                "status": "done"
            }
            print(f"[MMA] Pushing Tier 3 response for {ticket.id}, stream_id={response_payload['stream_id']}")
            _queue_put(event_queue, "response", response_payload)
        except Exception as e:
            print(f"[MMA] ERROR pushing response to UI: {e}\n{traceback.format_exc()}")
    # Update usage in engine if provided
    if engine:
        # Sum token usage from only the "IN"/"response" comms entries that
        # appeared after this worker's baseline.
        _new_comms = ai_client.get_comms_log()[comms_baseline:]
        _resp_entries = [e for e in _new_comms if e.get("direction") == "IN" and e.get("kind") == "response"]
        _in_tokens = sum(e.get("payload", {}).get("usage", {}).get("input_tokens", 0) for e in _resp_entries)
        _out_tokens = sum(e.get("payload", {}).get("usage", {}).get("output_tokens", 0) for e in _resp_entries)
        engine.update_usage("Tier 3", _in_tokens, _out_tokens)
    # Any response containing "BLOCKED" (case-insensitive) blocks the ticket.
    if "BLOCKED" in response.upper():
        ticket.mark_blocked(response)
    else:
        ticket.mark_complete()
    return response