Compare commits
3 Commits
beae82860a
...
597e6b51e2
| Author | SHA1 | Date | |
|---|---|---|---|
| 597e6b51e2 | |||
| da011fbc57 | |||
| 5f7909121d |
@@ -89,6 +89,7 @@ class ConductorEngine:
|
|||||||
self.pool = WorkerPool(max_workers=max_workers)
|
self.pool = WorkerPool(max_workers=max_workers)
|
||||||
self._workers_lock = threading.Lock()
|
self._workers_lock = threading.Lock()
|
||||||
self._active_workers: dict[str, threading.Thread] = {}
|
self._active_workers: dict[str, threading.Thread] = {}
|
||||||
|
self._abort_events: dict[str, threading.Event] = {}
|
||||||
self._tier_usage_lock = threading.Lock()
|
self._tier_usage_lock = threading.Lock()
|
||||||
|
|
||||||
def update_usage(self, tier: str, input_tokens: int, output_tokens: int) -> None:
|
def update_usage(self, tier: str, input_tokens: int, output_tokens: int) -> None:
|
||||||
@@ -97,6 +98,21 @@ class ConductorEngine:
|
|||||||
self.tier_usage[tier]["input"] += input_tokens
|
self.tier_usage[tier]["input"] += input_tokens
|
||||||
self.tier_usage[tier]["output"] += output_tokens
|
self.tier_usage[tier]["output"] += output_tokens
|
||||||
|
|
||||||
|
def kill_worker(self, ticket_id: str) -> None:
|
||||||
|
"""Sets the abort event for a worker and attempts to join its thread."""
|
||||||
|
if ticket_id in self._abort_events:
|
||||||
|
print(f"[MMA] Setting abort event for {ticket_id}")
|
||||||
|
self._abort_events[ticket_id].set()
|
||||||
|
|
||||||
|
with self._workers_lock:
|
||||||
|
thread = self._active_workers.get(ticket_id)
|
||||||
|
|
||||||
|
if thread:
|
||||||
|
print(f"[MMA] Joining thread for {ticket_id}")
|
||||||
|
thread.join(timeout=1.0)
|
||||||
|
with self._workers_lock:
|
||||||
|
self._active_workers.pop(ticket_id, None)
|
||||||
|
|
||||||
def _push_state(self, status: str = "running", active_tier: str = None) -> None:
|
def _push_state(self, status: str = "running", active_tier: str = None) -> None:
|
||||||
if not self.event_queue:
|
if not self.event_queue:
|
||||||
return
|
return
|
||||||
@@ -210,6 +226,9 @@ class ConductorEngine:
|
|||||||
)
|
)
|
||||||
context_files = ticket.context_requirements if ticket.context_requirements else None
|
context_files = ticket.context_requirements if ticket.context_requirements else None
|
||||||
|
|
||||||
|
# Initialize abort event before spawning
|
||||||
|
self._abort_events[ticket.id] = threading.Event()
|
||||||
|
|
||||||
spawned = self.pool.spawn(
|
spawned = self.pool.spawn(
|
||||||
ticket.id,
|
ticket.id,
|
||||||
run_worker_lifecycle,
|
run_worker_lifecycle,
|
||||||
@@ -217,6 +236,8 @@ class ConductorEngine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if spawned:
|
if spawned:
|
||||||
|
with self._workers_lock:
|
||||||
|
self._active_workers[ticket.id] = spawned
|
||||||
ticket.status = "in_progress"
|
ticket.status = "in_progress"
|
||||||
_queue_put(self.event_queue, "ticket_started", {"ticket_id": ticket.id, "timestamp": time.time()})
|
_queue_put(self.event_queue, "ticket_started", {"ticket_id": ticket.id, "timestamp": time.time()})
|
||||||
print(f"Executing ticket {ticket.id}: {ticket.description}")
|
print(f"Executing ticket {ticket.id}: {ticket.description}")
|
||||||
@@ -313,6 +334,17 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
|||||||
# Enforce Context Amnesia: each ticket starts with a clean slate.
|
# Enforce Context Amnesia: each ticket starts with a clean slate.
|
||||||
ai_client.reset_session()
|
ai_client.reset_session()
|
||||||
ai_client.set_provider(ai_client.get_provider(), context.model_name)
|
ai_client.set_provider(ai_client.get_provider(), context.model_name)
|
||||||
|
|
||||||
|
# Check for abort BEFORE any major work
|
||||||
|
if engine and hasattr(engine, "_abort_events"):
|
||||||
|
abort_event = engine._abort_events.get(ticket.id)
|
||||||
|
if abort_event and abort_event.is_set():
|
||||||
|
print(f"[MMA] Ticket {ticket.id} aborted early.")
|
||||||
|
ticket.status = "killed"
|
||||||
|
if event_queue:
|
||||||
|
_queue_put(event_queue, "ticket_completed", {"ticket_id": ticket.id, "timestamp": time.time()})
|
||||||
|
return "ABORTED"
|
||||||
|
|
||||||
context_injection = ""
|
context_injection = ""
|
||||||
tokens_before = 0
|
tokens_before = 0
|
||||||
tokens_after = 0
|
tokens_after = 0
|
||||||
@@ -379,6 +411,12 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
|||||||
def clutch_callback(payload: str) -> bool:
|
def clutch_callback(payload: str) -> bool:
|
||||||
if not event_queue:
|
if not event_queue:
|
||||||
return True
|
return True
|
||||||
|
# SECONDARY CHECK: Before executing any tool, check abort
|
||||||
|
if engine and hasattr(engine, "_abort_events"):
|
||||||
|
abort_event = engine._abort_events.get(ticket.id)
|
||||||
|
if abort_event and abort_event.is_set():
|
||||||
|
print(f"[MMA] Ticket {ticket.id} aborted during clutch_callback.")
|
||||||
|
return False # Reject tool execution
|
||||||
return confirm_execution(payload, event_queue, ticket.id)
|
return confirm_execution(payload, event_queue, ticket.id)
|
||||||
|
|
||||||
def stream_callback(chunk: str) -> None:
|
def stream_callback(chunk: str) -> None:
|
||||||
@@ -419,6 +457,17 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
|||||||
finally:
|
finally:
|
||||||
ai_client.comms_log_callback = old_comms_cb
|
ai_client.comms_log_callback = old_comms_cb
|
||||||
ai_client.set_current_tier(None)
|
ai_client.set_current_tier(None)
|
||||||
|
|
||||||
|
# THIRD CHECK: After blocking send() returns
|
||||||
|
if engine and hasattr(engine, "_abort_events"):
|
||||||
|
abort_event = engine._abort_events.get(ticket.id)
|
||||||
|
if abort_event and abort_event.is_set():
|
||||||
|
print(f"[MMA] Ticket {ticket.id} aborted after AI call.")
|
||||||
|
ticket.status = "killed"
|
||||||
|
if event_queue:
|
||||||
|
_queue_put(event_queue, "ticket_completed", {"ticket_id": ticket.id, "timestamp": time.time()})
|
||||||
|
return "ABORTED"
|
||||||
|
|
||||||
if event_queue:
|
if event_queue:
|
||||||
# Push via "response" event type — _process_event_queue wraps this
|
# Push via "response" event type — _process_event_queue wraps this
|
||||||
# as {"action": "handle_ai_response", "payload": ...} for the GUI.
|
# as {"action": "handle_ai_response", "payload": ...} for the GUI.
|
||||||
|
|||||||
32
tests/test_conductor_abort_event.py
Normal file
32
tests/test_conductor_abort_event.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from src.multi_agent_conductor import ConductorEngine
|
||||||
|
from src.models import Ticket, Track
|
||||||
|
import threading
|
||||||
|
|
||||||
|
def test_conductor_abort_event_populated():
|
||||||
|
"""
|
||||||
|
Test that ConductorEngine populates _abort_events when spawning a worker.
|
||||||
|
"""
|
||||||
|
# 1. Mock WorkerPool.spawn to return a mock thread
|
||||||
|
with patch('src.multi_agent_conductor.WorkerPool.spawn') as mock_spawn:
|
||||||
|
mock_spawn.return_value = MagicMock(spec=threading.Thread)
|
||||||
|
|
||||||
|
# 2. Mock ExecutionEngine.tick
|
||||||
|
with patch('src.multi_agent_conductor.ExecutionEngine.tick') as mock_tick:
|
||||||
|
ticket_id = "test-ticket"
|
||||||
|
ticket = Ticket(id=ticket_id, description="Test description", status="todo")
|
||||||
|
mock_tick.return_value = [ticket]
|
||||||
|
|
||||||
|
mock_track = Track(id="test-track", description="Test Track", tickets=[ticket])
|
||||||
|
|
||||||
|
# 3. Set auto_queue=True
|
||||||
|
mock_queue = MagicMock()
|
||||||
|
engine = ConductorEngine(track=mock_track, event_queue=mock_queue, auto_queue=True)
|
||||||
|
|
||||||
|
# 4. Call ConductorEngine.run(max_ticks=1)
|
||||||
|
engine.run(max_ticks=1)
|
||||||
|
|
||||||
|
# 5. Assert that self._abort_events has an entry for the ticket ID
|
||||||
|
assert ticket_id in engine._abort_events
|
||||||
|
assert isinstance(engine._abort_events[ticket_id], threading.Event)
|
||||||
53
tests/test_conductor_engine_abort.py
Normal file
53
tests/test_conductor_engine_abort.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from src.multi_agent_conductor import ConductorEngine
|
||||||
|
from src.models import Track
|
||||||
|
|
||||||
|
def test_conductor_engine_initializes_empty_worker_and_abort_dicts() -> None:
|
||||||
|
"""
|
||||||
|
Test that ConductorEngine correctly initializes _active_workers and _abort_events as empty dictionaries.
|
||||||
|
"""
|
||||||
|
# Mock the track object
|
||||||
|
mock_track = MagicMock(spec=Track)
|
||||||
|
mock_track.tickets = []
|
||||||
|
|
||||||
|
# Initialize ConductorEngine
|
||||||
|
engine = ConductorEngine(track=mock_track)
|
||||||
|
|
||||||
|
# Verify _active_workers and _abort_events are empty dictionaries
|
||||||
|
assert engine._active_workers == {}
|
||||||
|
assert engine._abort_events == {}
|
||||||
|
|
||||||
|
def test_kill_worker_sets_abort_and_joins_thread() -> None:
|
||||||
|
"""
|
||||||
|
Test kill_worker: mock a running thread in _active_workers, call kill_worker,
|
||||||
|
assert abort_event is set and thread is joined.
|
||||||
|
"""
|
||||||
|
mock_track = MagicMock(spec=Track)
|
||||||
|
mock_track.tickets = []
|
||||||
|
engine = ConductorEngine(track=mock_track)
|
||||||
|
|
||||||
|
ticket_id = "test-ticket"
|
||||||
|
abort_event = threading.Event()
|
||||||
|
engine._abort_events[ticket_id] = abort_event
|
||||||
|
|
||||||
|
# Create a thread that waits for the abort event
|
||||||
|
def worker():
|
||||||
|
abort_event.wait(timeout=2.0)
|
||||||
|
|
||||||
|
thread = threading.Thread(target=worker)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
with engine._workers_lock:
|
||||||
|
engine._active_workers[ticket_id] = thread
|
||||||
|
|
||||||
|
# Call kill_worker
|
||||||
|
engine.kill_worker(ticket_id)
|
||||||
|
|
||||||
|
# Assertions
|
||||||
|
assert abort_event.is_set()
|
||||||
|
assert not thread.is_alive()
|
||||||
|
with engine._workers_lock:
|
||||||
|
assert ticket_id not in engine._active_workers
|
||||||
40
tests/test_run_worker_lifecycle_abort.py
Normal file
40
tests/test_run_worker_lifecycle_abort.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import threading
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from src.multi_agent_conductor import run_worker_lifecycle
|
||||||
|
from src.models import Ticket, WorkerContext
|
||||||
|
|
||||||
|
class TestRunWorkerLifecycleAbort(unittest.TestCase):
|
||||||
|
def test_run_worker_lifecycle_returns_early_on_abort(self):
|
||||||
|
"""
|
||||||
|
Test that run_worker_lifecycle returns early and marks ticket as 'killed'
|
||||||
|
if the abort event is set for the ticket.
|
||||||
|
"""
|
||||||
|
# Mock ai_client.send
|
||||||
|
with patch('src.ai_client.send') as mock_send:
|
||||||
|
# Mock ticket and context
|
||||||
|
ticket = Ticket(id="T-001", description="Test task")
|
||||||
|
ticket = Ticket(id="T-001", description="Test task")
|
||||||
|
context = WorkerContext(ticket_id="T-001", model_name="test-model")
|
||||||
|
# Mock engine with _abort_events dict
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
abort_event = threading.Event()
|
||||||
|
mock_engine._abort_events = {"T-001": abort_event}
|
||||||
|
|
||||||
|
# Set abort event
|
||||||
|
abort_event.set()
|
||||||
|
|
||||||
|
# Call run_worker_lifecycle
|
||||||
|
# md_content is expected to be passed if called like in ConductorEngine
|
||||||
|
run_worker_lifecycle(ticket, context, engine=mock_engine, md_content="test context")
|
||||||
|
|
||||||
|
# Assert ticket status is 'killed'
|
||||||
|
self.assertEqual(ticket.status, "killed")
|
||||||
|
|
||||||
|
# Also assert ai_client.send was NOT called
|
||||||
|
mock_send.assert_not_called()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user