fix(mma): Replace token stats stub with real comms log extraction in run_worker_lifecycle

Task 2.1 of mma_pipeline_fix_20260301: capture comms baseline before send(), then sum input_tokens/output_tokens from IN/response entries to populate engine.tier_usage['Tier 3'].
This commit is contained in:
2026-03-01 13:22:15 -05:00
parent d5eb3f472e
commit 3eefdfd29d
2 changed files with 35 additions and 3 deletions

View File

@@ -268,6 +268,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
if not event_queue: if not event_queue:
return True return True
return confirm_execution(payload, event_queue, ticket.id, loop=loop) return confirm_execution(payload, event_queue, ticket.id, loop=loop)
comms_baseline = len(ai_client.get_comms_log())
response = ai_client.send( response = ai_client.send(
md_content=md_content, md_content=md_content,
user_message=user_message, user_message=user_message,
@@ -295,9 +296,12 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
# Update usage in engine if provided # Update usage in engine if provided
if engine: if engine:
stats = {} # ai_client.get_token_stats() is not available _new_comms = ai_client.get_comms_log()[comms_baseline:]
engine.tier_usage["Tier 3"]["input"] += stats.get("prompt_tokens", 0) _resp_entries = [e for e in _new_comms if e.get("direction") == "IN" and e.get("kind") == "response"]
engine.tier_usage["Tier 3"]["output"] += stats.get("candidates_tokens", 0) _in_tokens = sum(e.get("payload", {}).get("usage", {}).get("input_tokens", 0) for e in _resp_entries)
_out_tokens = sum(e.get("payload", {}).get("usage", {}).get("output_tokens", 0) for e in _resp_entries)
engine.tier_usage["Tier 3"]["input"] += _in_tokens
engine.tier_usage["Tier 3"]["output"] += _out_tokens
if "BLOCKED" in response.upper(): if "BLOCKED" in response.upper():
ticket.mark_blocked(response) ticket.mark_blocked(response)
else: else:

View File

@@ -293,3 +293,31 @@ def test_run_worker_lifecycle_pushes_response_via_queue(monkeypatch: pytest.Monk
assert call_args[3]["status"] == "done" assert call_args[3]["status"] == "done"
assert ticket.status == "completed" assert ticket.status == "completed"
def test_run_worker_lifecycle_token_usage_from_comms_log(monkeypatch: pytest.MonkeyPatch) -> None:
"""
Test that run_worker_lifecycle reads token usage from the comms log and
updates engine.tier_usage['Tier 3'] with real input/output token counts.
"""
ticket = Ticket(id="T1", description="Task 1", status="todo", assigned_to="worker1")
context = WorkerContext(ticket_id="T1", model_name="test-model", messages=[])
fake_comms = [
{"direction": "OUT", "kind": "request", "payload": {"message": "hello"}},
{"direction": "IN", "kind": "response", "payload": {"usage": {"input_tokens": 120, "output_tokens": 45}}},
]
monkeypatch.setattr(ai_client, 'send', MagicMock(return_value="Done."))
monkeypatch.setattr(ai_client, 'reset_session', MagicMock())
monkeypatch.setattr(ai_client, 'get_comms_log', MagicMock(side_effect=[
[], # baseline call (before send)
fake_comms, # after-send call
]))
from multi_agent_conductor import run_worker_lifecycle, ConductorEngine
from models import Track
track = Track(id="test_track", description="Test")
engine = ConductorEngine(track=track, auto_queue=True)
with patch("multi_agent_conductor.confirm_spawn") as mock_spawn, \
patch("multi_agent_conductor._queue_put"):
mock_spawn.return_value = (True, "prompt", "ctx")
run_worker_lifecycle(ticket, context, event_queue=MagicMock(), loop=MagicMock(), engine=engine)
assert engine.tier_usage["Tier 3"]["input"] == 120
assert engine.tier_usage["Tier 3"]["output"] == 45