fix(mma): Replace token stats stub with real comms log extraction in run_worker_lifecycle
Task 2.1 of mma_pipeline_fix_20260301: record the comms log length as a baseline before send(), then sum input_tokens/output_tokens from the IN/response entries appended after that baseline to populate engine.tier_usage['Tier 3'].
This commit is contained in:
@@ -268,6 +268,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
||||
if not event_queue:
|
||||
return True
|
||||
return confirm_execution(payload, event_queue, ticket.id, loop=loop)
|
||||
comms_baseline = len(ai_client.get_comms_log())
|
||||
response = ai_client.send(
|
||||
md_content=md_content,
|
||||
user_message=user_message,
|
||||
@@ -295,9 +296,12 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
||||
|
||||
# Update usage in engine if provided
|
||||
if engine:
|
||||
stats = {} # ai_client.get_token_stats() is not available
|
||||
engine.tier_usage["Tier 3"]["input"] += stats.get("prompt_tokens", 0)
|
||||
engine.tier_usage["Tier 3"]["output"] += stats.get("candidates_tokens", 0)
|
||||
_new_comms = ai_client.get_comms_log()[comms_baseline:]
|
||||
_resp_entries = [e for e in _new_comms if e.get("direction") == "IN" and e.get("kind") == "response"]
|
||||
_in_tokens = sum(e.get("payload", {}).get("usage", {}).get("input_tokens", 0) for e in _resp_entries)
|
||||
_out_tokens = sum(e.get("payload", {}).get("usage", {}).get("output_tokens", 0) for e in _resp_entries)
|
||||
engine.tier_usage["Tier 3"]["input"] += _in_tokens
|
||||
engine.tier_usage["Tier 3"]["output"] += _out_tokens
|
||||
if "BLOCKED" in response.upper():
|
||||
ticket.mark_blocked(response)
|
||||
else:
|
||||
|
||||
@@ -293,3 +293,31 @@ def test_run_worker_lifecycle_pushes_response_via_queue(monkeypatch: pytest.Monk
|
||||
assert call_args[3]["status"] == "done"
|
||||
assert ticket.status == "completed"
|
||||
|
||||
def test_run_worker_lifecycle_token_usage_from_comms_log(monkeypatch: pytest.MonkeyPatch) -> None:
    """
    Test that run_worker_lifecycle reads token usage from the comms log and
    updates engine.tier_usage['Tier 3'] with real input/output token counts.

    The comms log already holds one IN/response entry with its own usage
    numbers before send() is called. Only entries appended after that
    baseline may be counted, so the stale entry's tokens must not appear in
    the totals asserted at the end.
    """
    ticket = Ticket(id="T1", description="Task 1", status="todo", assigned_to="worker1")
    context = WorkerContext(ticket_id="T1", model_name="test-model", messages=[])

    # Entry that exists before this worker's send(); it would inflate the
    # totals if the baseline slice were ignored.
    stale_entry = {
        "direction": "IN",
        "kind": "response",
        "payload": {"usage": {"input_tokens": 999, "output_tokens": 777}},
    }
    # Entries produced by this worker's send(): only the IN/response entry
    # carries usage data.
    fake_comms = [
        {"direction": "OUT", "kind": "request", "payload": {"message": "hello"}},
        {"direction": "IN", "kind": "response", "payload": {"usage": {"input_tokens": 120, "output_tokens": 45}}},
    ]

    monkeypatch.setattr(ai_client, 'send', MagicMock(return_value="Done."))
    monkeypatch.setattr(ai_client, 'reset_session', MagicMock())
    # side_effect supplies one return value per call, in order: the first
    # call is the baseline taken before send(), the second is the read
    # performed after send().
    monkeypatch.setattr(ai_client, 'get_comms_log', MagicMock(side_effect=[
        [stale_entry],               # baseline call (before send)
        [stale_entry, *fake_comms],  # after-send call
    ]))

    from multi_agent_conductor import run_worker_lifecycle, ConductorEngine
    from models import Track

    track = Track(id="test_track", description="Test")
    engine = ConductorEngine(track=track, auto_queue=True)

    with patch("multi_agent_conductor.confirm_spawn") as mock_spawn, \
         patch("multi_agent_conductor._queue_put"):
        mock_spawn.return_value = (True, "prompt", "ctx")
        run_worker_lifecycle(ticket, context, event_queue=MagicMock(), loop=MagicMock(), engine=engine)

    # Only the post-baseline response entry contributes to the totals.
    assert engine.tier_usage["Tier 3"]["input"] == 120
    assert engine.tier_usage["Tier 3"]["output"] == 45
|
||||
|
||||
|
||||
Reference in New Issue
Block a user