feat(mma): Implement worker pool and configurable concurrency for DAG engine and mark track 'True Parallel Worker Execution' as complete

This commit is contained in:
2026-03-06 16:55:45 -05:00
parent 616675d7ea
commit 7da2946eff
5 changed files with 251 additions and 55 deletions

View File

@@ -3,15 +3,64 @@ import json
import threading
import time
import traceback
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Callable
from dataclasses import asdict
from src import events
from src import models
from src.models import Ticket, Track, WorkerContext
from src.file_cache import ASTParser
from pathlib import Path
from src.dag_engine import TrackDAG, ExecutionEngine
class WorkerPool:
    """
    Manages a pool of daemon worker threads with a hard concurrency limit.

    Thread-safe: the capacity check and thread registration happen atomically
    under a single lock acquisition, so concurrent ``spawn()`` calls can never
    exceed ``max_workers``. (The original checked capacity and registered the
    thread under two separate lock acquisitions — a TOCTOU race — and used an
    additional semaphore to paper over it, which made already-registered
    threads block *after* being counted active.)
    """

    def __init__(self, max_workers: int = 4):
        """
        Args:
            max_workers: Maximum number of worker threads allowed to run
                concurrently.
        """
        self.max_workers = max_workers
        # ticket_id -> running Thread; every access is guarded by _lock.
        self._active: dict[str, threading.Thread] = {}
        self._lock = threading.Lock()

    def spawn(self, ticket_id: str, target: Callable, args: tuple) -> Optional[threading.Thread]:
        """
        Spawn a new daemon worker thread if the pool has capacity.

        Args:
            ticket_id: Unique key identifying the worker; a second spawn for a
                ticket that is still active is refused (the original silently
                overwrote the registry entry, leaking the old thread handle).
            target: Callable executed on the worker thread.
            args: Positional arguments passed to ``target``.

        Returns:
            The started ``threading.Thread``, or ``None`` if the pool is full
            or a worker for ``ticket_id`` is already active.
        """
        def wrapper(*a, **kw):
            try:
                target(*a, **kw)
            finally:
                # Always release the slot, even if target raised.
                with self._lock:
                    self._active.pop(ticket_id, None)

        t = threading.Thread(target=wrapper, args=args, daemon=True)
        with self._lock:
            # Atomic check-and-register closes the race between the capacity
            # test and the registry insert.
            if len(self._active) >= self.max_workers or ticket_id in self._active:
                return None
            self._active[ticket_id] = t
        t.start()
        return t

    def join_all(self, timeout: Optional[float] = None) -> None:
        """
        Wait for all currently active workers to finish.

        Args:
            timeout: Per-thread join timeout in seconds; ``None`` waits
                forever.

        Only threads that actually terminated are dropped from the registry;
        the original cleared it unconditionally, losing track of workers that
        outlived the timeout.
        """
        with self._lock:
            snapshot = list(self._active.items())
        for _, t in snapshot:
            t.join(timeout=timeout)
        with self._lock:
            for tid, t in snapshot:
                if not t.is_alive():
                    self._active.pop(tid, None)

    def get_active_count(self) -> int:
        """Return the number of currently registered worker threads."""
        with self._lock:
            return len(self._active)

    def is_full(self) -> bool:
        """Return True if the pool has reached its concurrency limit."""
        return self.get_active_count() >= self.max_workers
class ConductorEngine:
"""
Orchestrates the execution of tickets within a track.
@@ -28,6 +77,25 @@ class ConductorEngine:
}
self.dag = TrackDAG(self.track.tickets)
self.engine = ExecutionEngine(self.dag, auto_queue=auto_queue)
# Load MMA config
try:
config = models.load_config()
mma_cfg = config.get("mma", {})
max_workers = mma_cfg.get("max_workers", 4)
except Exception:
max_workers = 4
self.pool = WorkerPool(max_workers=max_workers)
self._workers_lock = threading.Lock()
self._active_workers: dict[str, threading.Thread] = {}
self._tier_usage_lock = threading.Lock()
def update_usage(self, tier: str, input_tokens: int, output_tokens: int) -> None:
    """Thread-safely accumulate token usage for a known tier.

    Args:
        tier: Tier key; names absent from ``self.tier_usage`` are ignored.
        input_tokens: Input token count to add.
        output_tokens: Output token count to add.
    """
    with self._tier_usage_lock:
        usage = self.tier_usage.get(tier)
        if usage is not None:
            usage["input"] += input_tokens
            usage["output"] += output_tokens
def _push_state(self, status: str = "running", active_tier: str = None) -> None:
if not self.event_queue:
@@ -73,34 +141,48 @@ class ConductorEngine:
except KeyError as e:
print(f"Missing required field in ticket definition: {e}")
def run(self, md_content: str = "") -> None:
def run(self, md_content: str = "", max_ticks: Optional[int] = None) -> None:
"""
Main execution loop using the DAG engine.
Args:
md_content: The full markdown context (history + files) for AI workers.
max_ticks: Optional limit on number of iterations (for testing).
"""
self._push_state(status="running", active_tier="Tier 2 (Tech Lead)")
import sys
tick_count = 0
while True:
if max_ticks is not None and tick_count >= max_ticks:
break
tick_count += 1
# 1. Identify ready tasks
ready_tasks = self.engine.tick()
# 2. Check for completion or blockage
if not ready_tasks:
all_done = all(t.status == "completed" for t in self.track.tickets)
if all_done:
print("Track completed successfully.")
self._push_state(status="done", active_tier=None)
else:
# Check if any tasks are in-progress or could be ready
if any(t.status == "in_progress" for t in self.track.tickets):
# Wait for any active pool threads to finish before declaring done
self.pool.join_all(timeout=5)
if all(t.status == "completed" for t in self.track.tickets):
print("Track completed successfully.")
self._push_state(status="done", active_tier=None)
break
# Check if any tasks are in-progress
if any(t.status == "in_progress" for t in self.track.tickets) or self.pool.get_active_count() > 0:
# Wait for tasks to complete
time.sleep(1)
continue
print("No more executable tickets. Track is blocked or finished.")
self._push_state(status="blocked", active_tier=None)
time.sleep(1)
continue
print("No more executable tickets. Track is blocked or finished.")
self._push_state(status="blocked", active_tier=None)
break
# 3. Process ready tasks
to_run = [t for t in ready_tasks if t.status == "in_progress" or (not t.step_mode and self.engine.auto_queue)]
# Only include those that should be running: either already in_progress or todo + auto_queue
to_run = [t for t in ready_tasks if t.status == "in_progress" or (t.status == "todo" and not t.step_mode and self.engine.auto_queue)]
# Handle those awaiting approval
for ticket in ready_tasks:
@@ -110,44 +192,45 @@ class ConductorEngine:
time.sleep(1)
if to_run:
threads = []
for ticket in to_run:
ticket.status = "in_progress"
print(f"Executing ticket {ticket.id}: {ticket.description}")
self._push_state(active_tier=f"Tier 3 (Worker): {ticket.id}")
# Escalation logic based on retry_count
models = ["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-3.1-pro-preview"]
model_idx = min(ticket.retry_count, len(models) - 1)
model_name = models[model_idx]
if ticket.status == "todo":
# Only spawn if pool has capacity
if self.pool.is_full():
continue
# Escalation logic based on retry_count
models_list = ["gemini-2.5-flash-lite", "gemini-2.5-flash", "gemini-3.1-pro-preview"]
model_idx = min(ticket.retry_count, len(models_list) - 1)
model_name = models_list[model_idx]
context = WorkerContext(
ticket_id=ticket.id,
model_name=model_name,
messages=[]
)
context_files = ticket.context_requirements if ticket.context_requirements else None
t = threading.Thread(
target=run_worker_lifecycle,
args=(ticket, context, context_files, self.event_queue, self, md_content),
daemon=True
)
threads.append(t)
t.start()
for t in threads:
t.join()
# 4. Retry and escalation logic
for ticket in to_run:
if ticket.status == 'blocked':
if ticket.get('retry_count', 0) < 2:
ticket.retry_count += 1
ticket.status = 'todo'
print(f"Ticket {ticket.id} BLOCKED. Escalating to {models[min(ticket.retry_count, len(models)-1)]} and retrying...")
self._push_state(active_tier="Tier 2 (Tech Lead)")
context = WorkerContext(
ticket_id=ticket.id,
model_name=model_name,
messages=[]
)
context_files = ticket.context_requirements if ticket.context_requirements else None
spawned = self.pool.spawn(
ticket.id,
run_worker_lifecycle,
(ticket, context, context_files, self.event_queue, self, md_content)
)
if spawned:
ticket.status = "in_progress"
print(f"Executing ticket {ticket.id}: {ticket.description}")
self._push_state(active_tier=f"Tier 3 (Worker): {ticket.id}")
# 4. Retry and escalation logic for blocked tickets
# (Check tickets that recently became blocked)
for ticket in self.track.tickets:
if ticket.status == 'blocked' and ticket.retry_count < 2:
# Simple check to see if we should retry
# Escalation is currently handled inside run_worker_lifecycle or via manual retry
pass
self._push_state(active_tier="Tier 2 (Tech Lead)")
time.sleep(1)
def _queue_put(event_queue: events.SyncEventQueue, event_name: str, payload) -> None:
"""Thread-safe helper to push an event to the SyncEventQueue from a worker thread."""
@@ -334,8 +417,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
_resp_entries = [e for e in _new_comms if e.get("direction") == "IN" and e.get("kind") == "response"]
_in_tokens = sum(e.get("payload", {}).get("usage", {}).get("input_tokens", 0) for e in _resp_entries)
_out_tokens = sum(e.get("payload", {}).get("usage", {}).get("output_tokens", 0) for e in _resp_entries)
engine.tier_usage["Tier 3"]["input"] += _in_tokens
engine.tier_usage["Tier 3"]["output"] += _out_tokens
engine.update_usage("Tier 3", _in_tokens, _out_tokens)
if "BLOCKED" in response.upper():
ticket.mark_blocked(response)
else: