feat(conductor): Implement abort checks in worker lifecycle and kill_worker method

This commit is contained in:
2026-03-07 16:06:56 -05:00
parent da011fbc57
commit 597e6b51e2
3 changed files with 119 additions and 0 deletions

View File

@@ -98,6 +98,21 @@ class ConductorEngine:
self.tier_usage[tier]["input"] += input_tokens self.tier_usage[tier]["input"] += input_tokens
self.tier_usage[tier]["output"] += output_tokens self.tier_usage[tier]["output"] += output_tokens
def kill_worker(self, ticket_id: str) -> None:
"""Sets the abort event for a worker and attempts to join its thread."""
if ticket_id in self._abort_events:
print(f"[MMA] Setting abort event for {ticket_id}")
self._abort_events[ticket_id].set()
with self._workers_lock:
thread = self._active_workers.get(ticket_id)
if thread:
print(f"[MMA] Joining thread for {ticket_id}")
thread.join(timeout=1.0)
with self._workers_lock:
self._active_workers.pop(ticket_id, None)
def _push_state(self, status: str = "running", active_tier: str = None) -> None: def _push_state(self, status: str = "running", active_tier: str = None) -> None:
if not self.event_queue: if not self.event_queue:
return return
@@ -221,6 +236,8 @@ class ConductorEngine:
) )
if spawned: if spawned:
with self._workers_lock:
self._active_workers[ticket.id] = spawned
ticket.status = "in_progress" ticket.status = "in_progress"
_queue_put(self.event_queue, "ticket_started", {"ticket_id": ticket.id, "timestamp": time.time()}) _queue_put(self.event_queue, "ticket_started", {"ticket_id": ticket.id, "timestamp": time.time()})
print(f"Executing ticket {ticket.id}: {ticket.description}") print(f"Executing ticket {ticket.id}: {ticket.description}")
@@ -317,6 +334,17 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
# Enforce Context Amnesia: each ticket starts with a clean slate. # Enforce Context Amnesia: each ticket starts with a clean slate.
ai_client.reset_session() ai_client.reset_session()
ai_client.set_provider(ai_client.get_provider(), context.model_name) ai_client.set_provider(ai_client.get_provider(), context.model_name)
# Check for abort BEFORE any major work
if engine and hasattr(engine, "_abort_events"):
abort_event = engine._abort_events.get(ticket.id)
if abort_event and abort_event.is_set():
print(f"[MMA] Ticket {ticket.id} aborted early.")
ticket.status = "killed"
if event_queue:
_queue_put(event_queue, "ticket_completed", {"ticket_id": ticket.id, "timestamp": time.time()})
return "ABORTED"
context_injection = "" context_injection = ""
tokens_before = 0 tokens_before = 0
tokens_after = 0 tokens_after = 0
@@ -383,6 +411,12 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
def clutch_callback(payload: str) -> bool: def clutch_callback(payload: str) -> bool:
if not event_queue: if not event_queue:
return True return True
# SECONDARY CHECK: Before executing any tool, check abort
if engine and hasattr(engine, "_abort_events"):
abort_event = engine._abort_events.get(ticket.id)
if abort_event and abort_event.is_set():
print(f"[MMA] Ticket {ticket.id} aborted during clutch_callback.")
return False # Reject tool execution
return confirm_execution(payload, event_queue, ticket.id) return confirm_execution(payload, event_queue, ticket.id)
def stream_callback(chunk: str) -> None: def stream_callback(chunk: str) -> None:
@@ -423,6 +457,17 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
finally: finally:
ai_client.comms_log_callback = old_comms_cb ai_client.comms_log_callback = old_comms_cb
ai_client.set_current_tier(None) ai_client.set_current_tier(None)
# THIRD CHECK: After blocking send() returns
if engine and hasattr(engine, "_abort_events"):
abort_event = engine._abort_events.get(ticket.id)
if abort_event and abort_event.is_set():
print(f"[MMA] Ticket {ticket.id} aborted after AI call.")
ticket.status = "killed"
if event_queue:
_queue_put(event_queue, "ticket_completed", {"ticket_id": ticket.id, "timestamp": time.time()})
return "ABORTED"
if event_queue: if event_queue:
# Push via "response" event type — _process_event_queue wraps this # Push via "response" event type — _process_event_queue wraps this
# as {"action": "handle_ai_response", "payload": ...} for the GUI. # as {"action": "handle_ai_response", "payload": ...} for the GUI.

View File

@@ -1,5 +1,7 @@
import pytest import pytest
from unittest.mock import MagicMock from unittest.mock import MagicMock
import threading
import time
from src.multi_agent_conductor import ConductorEngine from src.multi_agent_conductor import ConductorEngine
from src.models import Track from src.models import Track
@@ -17,3 +19,35 @@ def test_conductor_engine_initializes_empty_worker_and_abort_dicts() -> None:
# Verify _active_workers and _abort_events are empty dictionaries # Verify _active_workers and _abort_events are empty dictionaries
assert engine._active_workers == {} assert engine._active_workers == {}
assert engine._abort_events == {} assert engine._abort_events == {}
def test_kill_worker_sets_abort_and_joins_thread() -> None:
"""
Test kill_worker: mock a running thread in _active_workers, call kill_worker,
assert abort_event is set and thread is joined.
"""
mock_track = MagicMock(spec=Track)
mock_track.tickets = []
engine = ConductorEngine(track=mock_track)
ticket_id = "test-ticket"
abort_event = threading.Event()
engine._abort_events[ticket_id] = abort_event
# Create a thread that waits for the abort event
def worker():
abort_event.wait(timeout=2.0)
thread = threading.Thread(target=worker)
thread.start()
with engine._workers_lock:
engine._active_workers[ticket_id] = thread
# Call kill_worker
engine.kill_worker(ticket_id)
# Assertions
assert abort_event.is_set()
assert not thread.is_alive()
with engine._workers_lock:
assert ticket_id not in engine._active_workers

View File

@@ -0,0 +1,40 @@
import unittest
from unittest.mock import MagicMock, patch
import threading
import json
import time
from src.multi_agent_conductor import run_worker_lifecycle
from src.models import Ticket, WorkerContext
class TestRunWorkerLifecycleAbort(unittest.TestCase):
def test_run_worker_lifecycle_returns_early_on_abort(self):
"""
Test that run_worker_lifecycle returns early and marks ticket as 'killed'
if the abort event is set for the ticket.
"""
# Mock ai_client.send
with patch('src.ai_client.send') as mock_send:
# Mock ticket and context
ticket = Ticket(id="T-001", description="Test task")
ticket = Ticket(id="T-001", description="Test task")
context = WorkerContext(ticket_id="T-001", model_name="test-model")
# Mock engine with _abort_events dict
mock_engine = MagicMock()
abort_event = threading.Event()
mock_engine._abort_events = {"T-001": abort_event}
# Set abort event
abort_event.set()
# Call run_worker_lifecycle
# md_content is expected to be passed if called like in ConductorEngine
run_worker_lifecycle(ticket, context, engine=mock_engine, md_content="test context")
# Assert ticket status is 'killed'
self.assertEqual(ticket.status, "killed")
# Also assert ai_client.send was NOT called
mock_send.assert_not_called()
if __name__ == "__main__":
unittest.main()