feat(api): implement websocket gateway and event streaming for phase 1

This commit is contained in:
2026-03-11 23:01:09 -04:00
parent 00a390ffab
commit 02e0fce548
8 changed files with 158 additions and 26 deletions

View File

@@ -22,7 +22,7 @@ This file tracks all major tracks for the project. Each track has its own detail
*Link: [./tracks/tool_bias_tuning_20260308/](./tracks/tool_bias_tuning_20260308/)* *Link: [./tracks/tool_bias_tuning_20260308/](./tracks/tool_bias_tuning_20260308/)*
*Goal: Influence agent tool selection via a weighting system. Implement semantic nudges in tool descriptions and a dynamic "Tooling Strategy" section in the system prompt. Includes GUI badges and sliders for weight adjustment.* *Goal: Influence agent tool selection via a weighting system. Implement semantic nudges in tool descriptions and a dynamic "Tooling Strategy" section in the system prompt. Includes GUI badges and sliders for weight adjustment.*
4. [ ] **Track: Expanded Hook API & Headless Orchestration** 4. [~] **Track: Expanded Hook API & Headless Orchestration**
*Link: [./tracks/hook_api_expansion_20260308/](./tracks/hook_api_expansion_20260308/)* *Link: [./tracks/hook_api_expansion_20260308/](./tracks/hook_api_expansion_20260308/)*
*Goal: Maximize internal state exposure and provide comprehensive control endpoints (worker spawn/kill, pipeline pause/resume, DAG mutation) via the Hook API. Implement WebSocket-based real-time event streaming.* *Goal: Maximize internal state exposure and provide comprehensive control endpoints (worker spawn/kill, pipeline pause/resume, DAG mutation) via the Hook API. Implement WebSocket-based real-time event streaming.*

View File

@@ -1,14 +1,14 @@
# Implementation Plan: Expanded Hook API & Headless Orchestration # Implementation Plan: Expanded Hook API & Headless Orchestration
## Phase 1: WebSocket Infrastructure & Event Streaming ## Phase 1: WebSocket Infrastructure & Event Streaming
- [ ] Task: Implement the WebSocket gateway. - [x] Task: Implement the WebSocket gateway.
- [ ] Integrate a lightweight WebSocket library (e.g., `websockets` or `simple-websocket`). - [x] Integrate a lightweight WebSocket library (e.g., `websockets` or `simple-websocket`).
- [ ] Create a dedicated `WebSocketServer` class in `src/api_hooks.py` that runs on a separate port (e.g., 9000). - [x] Create a dedicated `WebSocketServer` class in `src/api_hooks.py` that runs on a separate port (e.g., 9000).
- [ ] Implement a basic subscription mechanism for different event channels. - [x] Implement a basic subscription mechanism for different event channels.
- [ ] Task: Connect the event queue to the WebSocket stream. - [x] Task: Connect the event queue to the WebSocket stream.
- [ ] Update `AsyncEventQueue` to broadcast events to connected WebSocket clients. - [x] Update `AsyncEventQueue` to broadcast events to connected WebSocket clients.
- [ ] Add high-frequency telemetry (FPS, CPU) to the event stream. - [x] Add high-frequency telemetry (FPS, CPU) to the event stream.
- [ ] Task: Write unit tests for WebSocket connection and event broadcasting. - [x] Task: Write unit tests for WebSocket connection and event broadcasting.
- [ ] Task: Conductor - User Manual Verification 'Phase 1: WebSocket Infrastructure' (Protocol in workflow.md) - [ ] Task: Conductor - User Manual Verification 'Phase 1: WebSocket Infrastructure' (Protocol in workflow.md)
## Phase 2: Expanded Read Endpoints (GET) ## Phase 2: Expanded Read Endpoints (GET)

View File

@@ -116,7 +116,7 @@ class ApiHookClient:
return None return None
def post_gui(self, payload: dict) -> dict[str, Any]: def post_gui(self, payload: dict) -> dict[str, Any]:
"""Pushes an event to the GUI's SyncEventQueue via the /api/gui endpoint.""" """Pushes an event to the GUI's AsyncEventQueue via the /api/gui endpoint."""
return self._make_request('POST', '/api/gui', data=payload) or {} return self._make_request('POST', '/api/gui', data=payload) or {}
def push_event(self, action: str, payload: dict) -> dict[str, Any]: def push_event(self, action: str, payload: dict) -> dict[str, Any]:

View File

@@ -3,9 +3,12 @@ import json
import threading import threading
import uuid import uuid
import sys import sys
import asyncio
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from typing import Any from typing import Any
import logging import logging
import websockets
from websockets.asyncio.server import serve
from src import session_logger from src import session_logger
""" """
API Hooks - REST API for external automation and state inspection. API Hooks - REST API for external automation and state inspection.
@@ -498,6 +501,7 @@ class HookServer:
self.port = port self.port = port
self.server = None self.server = None
self.thread = None self.thread = None
self.websocket_server: WebSocketServer | None = None
def start(self) -> None: def start(self) -> None:
if self.thread and self.thread.is_alive(): if self.thread and self.thread.is_alive():
@@ -511,12 +515,22 @@ class HookServer:
if not _has_app_attr(self.app, '_ask_responses'): _set_app_attr(self.app, '_ask_responses', {}) if not _has_app_attr(self.app, '_ask_responses'): _set_app_attr(self.app, '_ask_responses', {})
if not _has_app_attr(self.app, '_api_event_queue'): _set_app_attr(self.app, '_api_event_queue', []) if not _has_app_attr(self.app, '_api_event_queue'): _set_app_attr(self.app, '_api_event_queue', [])
if not _has_app_attr(self.app, '_api_event_queue_lock'): _set_app_attr(self.app, '_api_event_queue_lock', threading.Lock()) if not _has_app_attr(self.app, '_api_event_queue_lock'): _set_app_attr(self.app, '_api_event_queue_lock', threading.Lock())
self.websocket_server = WebSocketServer(self.app)
self.websocket_server.start()
eq = _get_app_attr(self.app, 'event_queue')
if eq:
eq.websocket_server = self.websocket_server
self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app) self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app)
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.thread.start() self.thread.start()
logging.info(f"Hook server started on port {self.port}") logging.info(f"Hook server started on port {self.port}")
def stop(self) -> None: def stop(self) -> None:
if self.websocket_server:
self.websocket_server.stop()
if self.server: if self.server:
self.server.shutdown() self.server.shutdown()
self.server.server_close() self.server.server_close()
@@ -524,3 +538,62 @@ class HookServer:
self.thread.join() self.thread.join()
logging.info("Hook server stopped") logging.info("Hook server stopped")
class WebSocketServer:
    """WebSocket gateway for real-time event streaming.

    The server runs its own asyncio event loop on a daemon thread and binds
    to 127.0.0.1 on ``port``. Clients subscribe to a channel by sending
    ``{"action": "subscribe", "channel": "<name>"}``; ``broadcast`` may then
    be called from any thread to push a JSON message to every subscriber of
    that channel.
    """

    def __init__(self, app: Any, port: int = 9000) -> None:
        """Store configuration; nothing is bound until ``start`` is called.

        Args:
            app: Application object this gateway belongs to (kept as-is).
            port: TCP port to listen on.
        """
        self.app = app
        self.port = port
        # Channel name -> set of subscribed connections. Only mutated from
        # the event-loop thread (handler and broadcast fan-out).
        self.clients: dict[str, set] = {"events": set(), "telemetry": set()}
        self.loop: asyncio.AbstractEventLoop | None = None
        self.thread: threading.Thread | None = None
        self.server = None
        self._stop_event: asyncio.Event | None = None
        # Remembers a stop() that arrives before the loop thread has
        # published self.loop, so the request is not silently lost.
        self._stop_requested = threading.Event()

    async def _handler(self, websocket) -> None:
        """Serve one client connection until it closes.

        Malformed or unsupported messages are ignored rather than allowed to
        tear down the connection: client input is untrusted.
        """
        try:
            async for message in websocket:
                try:
                    data = json.loads(message)
                except ValueError:
                    # Covers json.JSONDecodeError and the UnicodeDecodeError
                    # raised for non-UTF-8 binary frames.
                    continue
                # Valid JSON is not necessarily an object (could be a list,
                # string, number...), so guard before calling .get().
                if not isinstance(data, dict) or data.get("action") != "subscribe":
                    continue
                channel = data.get("channel")
                # isinstance check also protects the dict lookup from
                # unhashable values such as lists.
                if isinstance(channel, str) and channel in self.clients:
                    self.clients[channel].add(websocket)
                    await websocket.send(
                        json.dumps({"type": "subscription_confirmed", "channel": channel})
                    )
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            for subscribers in self.clients.values():
                subscribers.discard(websocket)

    def _run_loop(self) -> None:
        """Thread target: create the event loop, serve until stopped, clean up."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        stop_event = asyncio.Event()
        self._stop_event = stop_event
        self.loop = loop

        async def main() -> None:
            async with serve(self._handler, "127.0.0.1", self.port) as server:
                self.server = server
                if self._stop_requested.is_set():
                    # stop() ran before self.loop was visible to it.
                    return
                await stop_event.wait()

        try:
            loop.run_until_complete(main())
        except OSError:
            # Most commonly: the port is already in use.
            logging.exception("WebSocket server failed on port %s", self.port)
        finally:
            # Drop references first so broadcast()/stop() see "not running",
            # then release the loop's resources.
            self.server = None
            self.loop = None
            self._stop_event = None
            loop.close()

    def start(self) -> None:
        """Start the server thread; a no-op if it is already running."""
        if self.thread and self.thread.is_alive():
            return
        self._stop_requested.clear()
        self.thread = threading.Thread(target=self._run_loop, daemon=True)
        self.thread.start()

    def stop(self) -> None:
        """Ask the server to shut down and wait up to two seconds for it."""
        self._stop_requested.set()
        # Snapshot: the loop thread clears these attributes when it exits.
        loop, stop_event = self.loop, self._stop_event
        if loop is not None and stop_event is not None:
            try:
                loop.call_soon_threadsafe(stop_event.set)
            except RuntimeError:
                # Loop was closed between the snapshot and the call.
                pass
        if self.thread:
            self.thread.join(timeout=2.0)

    async def _fan_out(self, channel: str, message: str) -> None:
        """Send ``message`` to every subscriber of ``channel`` (loop thread only)."""
        recipients = tuple(self.clients[channel])
        outcomes = await asyncio.gather(
            *(ws.send(message) for ws in recipients), return_exceptions=True
        )
        for ws, outcome in zip(recipients, outcomes):
            if isinstance(outcome, Exception):
                # A failed send means the connection is unusable; stop
                # retrying it on every subsequent broadcast.
                self.clients[channel].discard(ws)

    def broadcast(self, channel: str, payload: dict[str, Any]) -> None:
        """Queue ``payload`` for delivery to all subscribers of ``channel``.

        Safe to call from any thread. Broadcasting is best-effort: it never
        raises into the caller, so event producers are not disrupted by an
        un-serializable payload or a server that is shutting down.

        Args:
            channel: Target channel name; unknown channels are ignored.
            payload: Data to send. Values JSON cannot encode are stringified.
        """
        loop = self.loop
        if not loop or channel not in self.clients:
            return
        try:
            # default=str is deliberate: this is an observability stream and
            # event payloads may carry arbitrary project objects.
            message = json.dumps({"channel": channel, "payload": payload}, default=str)
        except (TypeError, ValueError):
            logging.warning("Dropping un-serializable WebSocket payload for channel %r", channel)
            return
        # Run the fan-out on the loop thread so the subscriber sets are never
        # iterated concurrently with handler mutations.
        coro = self._fan_out(channel, message)
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except RuntimeError:
            # Loop already closed: discard the coroutine without a warning.
            coro.close()

View File

@@ -150,7 +150,7 @@ class AppController:
self.disc_roles: List[str] = [] self.disc_roles: List[str] = []
self.files: List[str] = [] self.files: List[str] = []
self.screenshots: List[str] = [] self.screenshots: List[str] = []
self.event_queue: events.SyncEventQueue = events.SyncEventQueue() self.event_queue: events.AsyncEventQueue = events.AsyncEventQueue()
self._loop_thread: Optional[threading.Thread] = None self._loop_thread: Optional[threading.Thread] = None
self.tracks: List[Dict[str, Any]] = [] self.tracks: List[Dict[str, Any]] = []
self.active_track: Optional[models.Track] = None self.active_track: Optional[models.Track] = None
@@ -188,6 +188,7 @@ class AppController:
"Tier 4": {"input": 0, "output": 0, "provider": "gemini", "model": "gemini-2.5-flash-lite", "tool_preset": None}, "Tier 4": {"input": 0, "output": 0, "provider": "gemini", "model": "gemini-2.5-flash-lite", "tool_preset": None},
} }
self.perf_monitor: performance_monitor.PerformanceMonitor = performance_monitor.PerformanceMonitor() self.perf_monitor: performance_monitor.PerformanceMonitor = performance_monitor.PerformanceMonitor()
self._last_telemetry_time: float = 0.0
self._pending_gui_tasks: List[Dict[str, Any]] = [] self._pending_gui_tasks: List[Dict[str, Any]] = []
self._api_event_queue: List[Dict[str, Any]] = [] self._api_event_queue: List[Dict[str, Any]] = []
# Pending dialogs state moved from App # Pending dialogs state moved from App
@@ -522,6 +523,14 @@ class AppController:
}) })
def _process_pending_gui_tasks(self) -> None: def _process_pending_gui_tasks(self) -> None:
# Periodic telemetry broadcast
now = time.time()
if hasattr(self, 'event_queue') and hasattr(self.event_queue, 'websocket_server') and self.event_queue.websocket_server:
if now - self._last_telemetry_time >= 1.0:
self._last_telemetry_time = now
metrics = self.perf_monitor.get_metrics()
self.event_queue.websocket_server.broadcast("telemetry", metrics)
if not self._pending_gui_tasks: if not self._pending_gui_tasks:
return return
sys.stderr.write(f"[DEBUG] _process_pending_gui_tasks: processing {len(self._pending_gui_tasks)} tasks\n") sys.stderr.write(f"[DEBUG] _process_pending_gui_tasks: processing {len(self._pending_gui_tasks)} tasks\n")

View File

@@ -9,7 +9,7 @@ between the GUI main thread and background workers:
- Thread-safe: Callbacks execute on emitter's thread - Thread-safe: Callbacks execute on emitter's thread
- Example: ai_client.py emits 'request_start' and 'response_received' events - Example: ai_client.py emits 'request_start' and 'response_received' events
2. SyncEventQueue: Producer-consumer pattern via queue.Queue 2. AsyncEventQueue: Producer-consumer pattern via queue.Queue
- Used for: Decoupled task submission where consumer polls at its own pace - Used for: Decoupled task submission where consumer polls at its own pace
- Thread-safe: Built on Python's thread-safe queue.Queue - Thread-safe: Built on Python's thread-safe queue.Queue
- Example: Background workers submit tasks, main thread drains queue - Example: Background workers submit tasks, main thread drains queue
@@ -21,16 +21,16 @@ between the GUI main thread and background workers:
Integration Points: Integration Points:
- ai_client.py: EventEmitter for API lifecycle events - ai_client.py: EventEmitter for API lifecycle events
- gui_2.py: Consumes events via _process_event_queue() - gui_2.py: Consumes events via _process_event_queue()
- multi_agent_conductor.py: Uses SyncEventQueue for state updates - multi_agent_conductor.py: Uses AsyncEventQueue for state updates
- api_hooks.py: Pushes events to _api_event_queue for external visibility - api_hooks.py: Pushes events to _api_event_queue for external visibility
Thread Safety: Thread Safety:
- EventEmitter: NOT thread-safe for concurrent on/emit (use from single thread) - EventEmitter: NOT thread-safe for concurrent on/emit (use from single thread)
- SyncEventQueue: FULLY thread-safe (built on queue.Queue) - AsyncEventQueue: FULLY thread-safe (built on queue.Queue)
- UserRequestEvent: Immutable, safe for concurrent access - UserRequestEvent: Immutable, safe for concurrent access
""" """
import queue import queue
from typing import Callable, Any, Dict, List, Tuple from typing import Callable, Any, Dict, List, Tuple, Optional
class EventEmitter: class EventEmitter:
""" """
@@ -70,14 +70,16 @@ class EventEmitter:
"""Clears all registered listeners.""" """Clears all registered listeners."""
self._listeners.clear() self._listeners.clear()
class SyncEventQueue: class AsyncEventQueue:
""" """
Synchronous event queue for decoupled communication using queue.Queue. Synchronous event queue for decoupled communication using queue.Queue.
(Named AsyncEventQueue for architectural consistency, but is synchronous)
""" """
def __init__(self) -> None: def __init__(self) -> None:
"""Initializes the SyncEventQueue with an internal queue.Queue.""" """Initializes the AsyncEventQueue with an internal queue.Queue."""
self._queue: queue.Queue[Tuple[str, Any]] = queue.Queue() self._queue: queue.Queue[Tuple[str, Any]] = queue.Queue()
self.websocket_server: Optional[Any] = None
def put(self, event_name: str, payload: Any = None) -> None: def put(self, event_name: str, payload: Any = None) -> None:
""" """
@@ -88,6 +90,8 @@ class SyncEventQueue:
payload: Optional data associated with the event. payload: Optional data associated with the event.
""" """
self._queue.put((event_name, payload)) self._queue.put((event_name, payload))
if self.websocket_server:
self.websocket_server.broadcast("events", {"event": event_name, "payload": payload})
def get(self) -> Tuple[str, Any]: def get(self) -> Tuple[str, Any]:
""" """
@@ -106,6 +110,9 @@ class SyncEventQueue:
"""Blocks until all items in the queue have been gotten and processed.""" """Blocks until all items in the queue have been gotten and processed."""
self._queue.join() self._queue.join()
# Alias for backward compatibility
SyncEventQueue = AsyncEventQueue
class UserRequestEvent: class UserRequestEvent:
""" """
Payload for a user request event. Payload for a user request event.
@@ -126,4 +133,3 @@ class UserRequestEvent:
"disc_text": self.disc_text, "disc_text": self.disc_text,
"base_dir": self.base_dir "base_dir": self.base_dir
} }

View File

@@ -11,7 +11,7 @@ Key Components:
Architecture Integration: Architecture Integration:
- Uses TrackDAG and ExecutionEngine from dag_engine.py - Uses TrackDAG and ExecutionEngine from dag_engine.py
- Communicates with GUI via SyncEventQueue - Communicates with GUI via AsyncEventQueue
- Manages tier-specific token usage via update_usage() - Manages tier-specific token usage via update_usage()
Thread Safety: Thread Safety:
@@ -45,7 +45,7 @@ Thread Safety:
- Abort events enable per-ticket cancellation - Abort events enable per-ticket cancellation
Integration: Integration:
- Uses SyncEventQueue for state updates to the GUI - Uses AsyncEventQueue for state updates to the GUI
- Uses ai_client.send() for LLM communication - Uses ai_client.send() for LLM communication
- Uses mcp_client for tool dispatch - Uses mcp_client for tool dispatch
@@ -123,7 +123,7 @@ class ConductorEngine:
Orchestrates the execution of tickets within a track. Orchestrates the execution of tickets within a track.
""" """
def __init__(self, track: Track, event_queue: Optional[events.SyncEventQueue] = None, auto_queue: bool = False) -> None: def __init__(self, track: Track, event_queue: Optional[events.AsyncEventQueue] = None, auto_queue: bool = False) -> None:
self.track = track self.track = track
self.event_queue = event_queue self.event_queue = event_queue
self.tier_usage = { self.tier_usage = {
@@ -343,12 +343,12 @@ class ConductorEngine:
self._push_state(active_tier="Tier 2 (Tech Lead)") self._push_state(active_tier="Tier 2 (Tech Lead)")
time.sleep(1) time.sleep(1)
def _queue_put(event_queue: events.SyncEventQueue, event_name: str, payload) -> None: def _queue_put(event_queue: events.AsyncEventQueue, event_name: str, payload) -> None:
"""Thread-safe helper to push an event to the SyncEventQueue from a worker thread.""" """Thread-safe helper to push an event to the AsyncEventQueue from a worker thread."""
if event_queue is not None: if event_queue is not None:
event_queue.put(event_name, payload) event_queue.put(event_name, payload)
def confirm_execution(payload: str, event_queue: events.SyncEventQueue, ticket_id: str) -> bool: def confirm_execution(payload: str, event_queue: events.AsyncEventQueue, ticket_id: str) -> bool:
""" """
Pushes an approval request to the GUI and waits for response. Pushes an approval request to the GUI and waits for response.
""" """
@@ -370,7 +370,7 @@ def confirm_execution(payload: str, event_queue: events.SyncEventQueue, ticket_i
return approved return approved
return False return False
def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.SyncEventQueue, ticket_id: str) -> Tuple[bool, str, str]: def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.AsyncEventQueue, ticket_id: str) -> Tuple[bool, str, str]:
""" """
Pushes a spawn approval request to the GUI and waits for response. Pushes a spawn approval request to the GUI and waits for response.
Returns (approved, modified_prompt, modified_context) Returns (approved, modified_prompt, modified_context)
@@ -409,7 +409,7 @@ def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.S
return approved, modified_prompt, modified_context return approved, modified_prompt, modified_context
return False, prompt, context_md return False, prompt, context_md
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] | None = None, event_queue: events.SyncEventQueue | None = None, engine: Optional['ConductorEngine'] = None, md_content: str = "") -> None: def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] | None = None, event_queue: events.AsyncEventQueue | None = None, engine: Optional['ConductorEngine'] = None, md_content: str = "") -> None:
""" """
Simulates the lifecycle of a single agent working on a ticket. Simulates the lifecycle of a single agent working on a ticket.
Calls the AI client and updates the ticket status based on the response. Calls the AI client and updates the ticket status based on the response.

View File

@@ -0,0 +1,44 @@
import pytest
import asyncio
import json
import websockets
from src.api_hooks import WebSocketServer
@pytest.mark.asyncio
async def test_websocket_subscription_and_broadcast():
    """Subscribe to the "events" channel and verify a broadcast is delivered.

    Every receive is bounded by a timeout so that a server which never
    answers fails this test instead of hanging the whole test run.
    """
    recv_timeout = 2.0
    # Mock app
    app = type("MockApp", (), {"test_hooks_enabled": True})()
    # Start server on a specific port
    port = 9005
    server = WebSocketServer(app, port=port)
    server.start()
    # Wait for server to start
    await asyncio.sleep(0.5)
    try:
        uri = f"ws://127.0.0.1:{port}"
        async with websockets.connect(uri) as websocket:
            # Subscribe to events channel
            subscribe_msg = {"action": "subscribe", "channel": "events"}
            await websocket.send(json.dumps(subscribe_msg))
            # Receive confirmation
            response = await asyncio.wait_for(websocket.recv(), timeout=recv_timeout)
            data = json.loads(response)
            assert data["type"] == "subscription_confirmed"
            assert data["channel"] == "events"
            # Broadcast an event from the server
            event_payload = {"event": "test_event", "data": "hello"}
            server.broadcast("events", event_payload)
            # Receive the broadcast
            broadcast_response = await asyncio.wait_for(websocket.recv(), timeout=recv_timeout)
            broadcast_data = json.loads(broadcast_response)
            assert broadcast_data["channel"] == "events"
            assert broadcast_data["payload"] == event_payload
    finally:
        server.stop()