From 02e0fce54815c5128ed1fce758a73460df8e6a5c Mon Sep 17 00:00:00 2001 From: Ed_ Date: Wed, 11 Mar 2026 23:01:09 -0400 Subject: [PATCH] feat(api): implement websocket gateway and event streaming for phase 1 --- conductor/tracks.md | 2 +- .../hook_api_expansion_20260308/plan.md | 16 ++-- src/api_hook_client.py | 2 +- src/api_hooks.py | 73 +++++++++++++++++++ src/app_controller.py | 11 ++- src/events.py | 20 +++-- src/multi_agent_conductor.py | 16 ++-- tests/test_websocket_server.py | 44 +++++++++++ 8 files changed, 158 insertions(+), 26 deletions(-) create mode 100644 tests/test_websocket_server.py diff --git a/conductor/tracks.md b/conductor/tracks.md index e6a49d4..8b4615a 100644 --- a/conductor/tracks.md +++ b/conductor/tracks.md @@ -22,7 +22,7 @@ This file tracks all major tracks for the project. Each track has its own detail *Link: [./tracks/tool_bias_tuning_20260308/](./tracks/tool_bias_tuning_20260308/)* *Goal: Influence agent tool selection via a weighting system. Implement semantic nudges in tool descriptions and a dynamic "Tooling Strategy" section in the system prompt. Includes GUI badges and sliders for weight adjustment.* -4. [ ] **Track: Expanded Hook API & Headless Orchestration** +4. [~] **Track: Expanded Hook API & Headless Orchestration** *Link: [./tracks/hook_api_expansion_20260308/](./tracks/hook_api_expansion_20260308/)* *Goal: Maximize internal state exposure and provide comprehensive control endpoints (worker spawn/kill, pipeline pause/resume, DAG mutation) via the Hook API. Implement WebSocket-based real-time event streaming.* diff --git a/conductor/tracks/hook_api_expansion_20260308/plan.md b/conductor/tracks/hook_api_expansion_20260308/plan.md index 2ca5124..d7a3a15 100644 --- a/conductor/tracks/hook_api_expansion_20260308/plan.md +++ b/conductor/tracks/hook_api_expansion_20260308/plan.md @@ -1,14 +1,14 @@ # Implementation Plan: Expanded Hook API & Headless Orchestration ## Phase 1: WebSocket Infrastructure & Event Streaming -- [ ] Task: Implement the WebSocket gateway. - - [ ] Integrate a lightweight WebSocket library (e.g., `websockets` or `simple-websocket`). - - [ ] Create a dedicated `WebSocketServer` class in `src/api_hooks.py` that runs on a separate port (e.g., 9000). - - [ ] Implement a basic subscription mechanism for different event channels. -- [ ] Task: Connect the event queue to the WebSocket stream. - - [ ] Update `AsyncEventQueue` to broadcast events to connected WebSocket clients. - - [ ] Add high-frequency telemetry (FPS, CPU) to the event stream. -- [ ] Task: Write unit tests for WebSocket connection and event broadcasting. +- [x] Task: Implement the WebSocket gateway. + - [x] Integrate a lightweight WebSocket library (e.g., `websockets` or `simple-websocket`). + - [x] Create a dedicated `WebSocketServer` class in `src/api_hooks.py` that runs on a separate port (e.g., 9000). + - [x] Implement a basic subscription mechanism for different event channels. +- [x] Task: Connect the event queue to the WebSocket stream. + - [x] Update `AsyncEventQueue` to broadcast events to connected WebSocket clients. + - [x] Add high-frequency telemetry (FPS, CPU) to the event stream. +- [x] Task: Write unit tests for WebSocket connection and event broadcasting. - [ ] Task: Conductor - User Manual Verification 'Phase 1: WebSocket Infrastructure' (Protocol in workflow.md) ## Phase 2: Expanded Read Endpoints (GET) diff --git a/src/api_hook_client.py b/src/api_hook_client.py index 46e3995..f155597 100644 --- a/src/api_hook_client.py +++ b/src/api_hook_client.py @@ -116,7 +116,7 @@ class ApiHookClient: return None def post_gui(self, payload: dict) -> dict[str, Any]: - """Pushes an event to the GUI's SyncEventQueue via the /api/gui endpoint.""" + """Pushes an event to the GUI's AsyncEventQueue via the /api/gui endpoint.""" return self._make_request('POST', '/api/gui', data=payload) or {} def push_event(self, action: str, payload: dict) -> dict[str, Any]: diff --git a/src/api_hooks.py b/src/api_hooks.py index ab94c23..379f503 100644 --- a/src/api_hooks.py +++ b/src/api_hooks.py @@ -3,9 +3,12 @@ import json import threading import uuid import sys +import asyncio from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler from typing import Any import logging +import websockets +from websockets.asyncio.server import serve from src import session_logger """ API Hooks - REST API for external automation and state inspection. @@ -498,6 +501,7 @@ class HookServer: self.port = port self.server = None self.thread = None + self.websocket_server: WebSocketServer | None = None def start(self) -> None: if self.thread and self.thread.is_alive(): @@ -511,12 +515,22 @@ class HookServer: if not _has_app_attr(self.app, '_ask_responses'): _set_app_attr(self.app, '_ask_responses', {}) if not _has_app_attr(self.app, '_api_event_queue'): _set_app_attr(self.app, '_api_event_queue', []) if not _has_app_attr(self.app, '_api_event_queue_lock'): _set_app_attr(self.app, '_api_event_queue_lock', threading.Lock()) + + self.websocket_server = WebSocketServer(self.app) + self.websocket_server.start() + + eq = _get_app_attr(self.app, 'event_queue') + if eq: + eq.websocket_server = self.websocket_server + self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app) self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) self.thread.start() logging.info(f"Hook server started on port {self.port}") def stop(self) -> None: + if self.websocket_server: + self.websocket_server.stop() if self.server: self.server.shutdown() self.server.server_close() @@ -524,3 +538,62 @@ class HookServer: self.thread.join() logging.info("Hook server stopped") +class WebSocketServer: + """WebSocket gateway for real-time event streaming.""" + def __init__(self, app: Any, port: int = 9000) -> None: + self.app = app + self.port = port + self.clients: dict[str, set] = {"events": set(), "telemetry": set()} + self.loop: asyncio.AbstractEventLoop | None = None + self.thread: threading.Thread | None = None + self.server = None + self._stop_event: asyncio.Event | None = None + + async def _handler(self, websocket) -> None: + try: + async for message in websocket: + try: + data = json.loads(message) + if data.get("action") == "subscribe": + channel = data.get("channel") + if channel in self.clients: + self.clients[channel].add(websocket) + await websocket.send(json.dumps({"type": "subscription_confirmed", "channel": channel})) + except json.JSONDecodeError: + pass + except websockets.exceptions.ConnectionClosed: + pass + finally: + for channel in self.clients: + if websocket in self.clients[channel]: + self.clients[channel].remove(websocket) + + def _run_loop(self) -> None: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self._stop_event = asyncio.Event() + async def main(): + async with serve(self._handler, "127.0.0.1", self.port) as server: + self.server = server + await self._stop_event.wait() + self.loop.run_until_complete(main()) + + def start(self) -> None: + if self.thread and self.thread.is_alive(): + return + self.thread = threading.Thread(target=self._run_loop, daemon=True) + self.thread.start() + + def stop(self) -> None: + if self.loop and self._stop_event: + self.loop.call_soon_threadsafe(self._stop_event.set) + if self.thread: + self.thread.join(timeout=2.0) + + def broadcast(self, channel: str, payload: dict[str, Any]) -> None: + if not self.loop or channel not in self.clients: + return + message = json.dumps({"channel": channel, "payload": payload}) + for ws in list(self.clients[channel]): + asyncio.run_coroutine_threadsafe(ws.send(message), self.loop) + diff --git a/src/app_controller.py b/src/app_controller.py index 8a78f24..da0acee 100644 --- a/src/app_controller.py +++ b/src/app_controller.py @@ -150,7 +150,7 @@ class AppController: self.disc_roles: List[str] = [] self.files: List[str] = [] self.screenshots: List[str] = [] - self.event_queue: events.SyncEventQueue = events.SyncEventQueue() + self.event_queue: events.AsyncEventQueue = events.AsyncEventQueue() self._loop_thread: Optional[threading.Thread] = None self.tracks: List[Dict[str, Any]] = [] self.active_track: Optional[models.Track] = None @@ -188,6 +188,7 @@ class AppController: "Tier 4": {"input": 0, "output": 0, "provider": "gemini", "model": "gemini-2.5-flash-lite", "tool_preset": None}, } self.perf_monitor: performance_monitor.PerformanceMonitor = performance_monitor.PerformanceMonitor() + self._last_telemetry_time: float = 0.0 self._pending_gui_tasks: List[Dict[str, Any]] = [] self._api_event_queue: List[Dict[str, Any]] = [] # Pending dialogs state moved from App @@ -522,6 +523,14 @@ class AppController: }) def _process_pending_gui_tasks(self) -> None: + # Periodic telemetry broadcast + now = time.time() + if hasattr(self, 'event_queue') and hasattr(self.event_queue, 'websocket_server') and self.event_queue.websocket_server: + if now - self._last_telemetry_time >= 1.0: + self._last_telemetry_time = now + metrics = self.perf_monitor.get_metrics() + self.event_queue.websocket_server.broadcast("telemetry", metrics) + if not self._pending_gui_tasks: return sys.stderr.write(f"[DEBUG] _process_pending_gui_tasks: processing {len(self._pending_gui_tasks)} tasks\n") diff --git a/src/events.py b/src/events.py index 5b42206..4a53a4f 100644 --- a/src/events.py +++ b/src/events.py @@ -9,7 +9,7 @@ between the GUI main thread and background workers: - Thread-safe: Callbacks execute on emitter's thread - Example: ai_client.py emits 'request_start' and 'response_received' events -2. SyncEventQueue: Producer-consumer pattern via queue.Queue +2. AsyncEventQueue: Producer-consumer pattern via queue.Queue - Used for: Decoupled task submission where consumer polls at its own pace - Thread-safe: Built on Python's thread-safe queue.Queue - Example: Background workers submit tasks, main thread drains queue @@ -21,16 +21,16 @@ between the GUI main thread and background workers: Integration Points: - ai_client.py: EventEmitter for API lifecycle events - gui_2.py: Consumes events via _process_event_queue() - - multi_agent_conductor.py: Uses SyncEventQueue for state updates + - multi_agent_conductor.py: Uses AsyncEventQueue for state updates - api_hooks.py: Pushes events to _api_event_queue for external visibility Thread Safety: - EventEmitter: NOT thread-safe for concurrent on/emit (use from single thread) - - SyncEventQueue: FULLY thread-safe (built on queue.Queue) + - AsyncEventQueue: FULLY thread-safe (built on queue.Queue) - UserRequestEvent: Immutable, safe for concurrent access """ import queue -from typing import Callable, Any, Dict, List, Tuple +from typing import Callable, Any, Dict, List, Tuple, Optional class EventEmitter: """ @@ -70,14 +70,16 @@ class EventEmitter: """Clears all registered listeners.""" self._listeners.clear() -class SyncEventQueue: +class AsyncEventQueue: """ Synchronous event queue for decoupled communication using queue.Queue. + (Named AsyncEventQueue for architectural consistency, but is synchronous) """ def __init__(self) -> None: - """Initializes the SyncEventQueue with an internal queue.Queue.""" + """Initializes the AsyncEventQueue with an internal queue.Queue.""" self._queue: queue.Queue[Tuple[str, Any]] = queue.Queue() + self.websocket_server: Optional[Any] = None def put(self, event_name: str, payload: Any = None) -> None: """ @@ -88,6 +90,8 @@ class SyncEventQueue: payload: Optional data associated with the event. """ self._queue.put((event_name, payload)) + if self.websocket_server: + self.websocket_server.broadcast("events", {"event": event_name, "payload": payload}) def get(self) -> Tuple[str, Any]: """ @@ -106,6 +110,9 @@ class SyncEventQueue: """Blocks until all items in the queue have been gotten and processed.""" self._queue.join() +# Alias for backward compatibility +SyncEventQueue = AsyncEventQueue + class UserRequestEvent: """ Payload for a user request event. @@ -126,4 +133,3 @@ class UserRequestEvent: "disc_text": self.disc_text, "base_dir": self.base_dir } - diff --git a/src/multi_agent_conductor.py b/src/multi_agent_conductor.py index 650c05e..10545af 100644 --- a/src/multi_agent_conductor.py +++ b/src/multi_agent_conductor.py @@ -11,7 +11,7 @@ Key Components: Architecture Integration: - Uses TrackDAG and ExecutionEngine from dag_engine.py - - Communicates with GUI via SyncEventQueue + - Communicates with GUI via AsyncEventQueue - Manages tier-specific token usage via update_usage() Thread Safety: @@ -45,7 +45,7 @@ Thread Safety: - Abort events enable per-ticket cancellation Integration: - - Uses SyncEventQueue for state updates to the GUI + - Uses AsyncEventQueue for state updates to the GUI - Uses ai_client.send() for LLM communication - Uses mcp_client for tool dispatch @@ -123,7 +123,7 @@ class ConductorEngine: Orchestrates the execution of tickets within a track. """ - def __init__(self, track: Track, event_queue: Optional[events.SyncEventQueue] = None, auto_queue: bool = False) -> None: + def __init__(self, track: Track, event_queue: Optional[events.AsyncEventQueue] = None, auto_queue: bool = False) -> None: self.track = track self.event_queue = event_queue self.tier_usage = { @@ -343,12 +343,12 @@ class ConductorEngine: self._push_state(active_tier="Tier 2 (Tech Lead)") time.sleep(1) -def _queue_put(event_queue: events.SyncEventQueue, event_name: str, payload) -> None: - """Thread-safe helper to push an event to the SyncEventQueue from a worker thread.""" +def _queue_put(event_queue: events.AsyncEventQueue, event_name: str, payload) -> None: + """Thread-safe helper to push an event to the AsyncEventQueue from a worker thread.""" if event_queue is not None: event_queue.put(event_name, payload) -def confirm_execution(payload: str, event_queue: events.SyncEventQueue, ticket_id: str) -> bool: +def confirm_execution(payload: str, event_queue: events.AsyncEventQueue, ticket_id: str) -> bool: """ Pushes an approval request to the GUI and waits for response. """ @@ -370,7 +370,7 @@ def confirm_execution(payload: str, event_queue: events.SyncEventQueue, ticket_i return approved return False -def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.SyncEventQueue, ticket_id: str) -> Tuple[bool, str, str]: +def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.AsyncEventQueue, ticket_id: str) -> Tuple[bool, str, str]: """ Pushes a spawn approval request to the GUI and waits for response. Returns (approved, modified_prompt, modified_context) @@ -409,7 +409,7 @@ def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.S return approved, modified_prompt, modified_context return False, prompt, context_md -def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] | None = None, event_queue: events.SyncEventQueue | None = None, engine: Optional['ConductorEngine'] = None, md_content: str = "") -> None: +def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] | None = None, event_queue: events.AsyncEventQueue | None = None, engine: Optional['ConductorEngine'] = None, md_content: str = "") -> None: """ Simulates the lifecycle of a single agent working on a ticket. Calls the AI client and updates the ticket status based on the response. diff --git a/tests/test_websocket_server.py b/tests/test_websocket_server.py new file mode 100644 index 0000000..2ba3ef7 --- /dev/null +++ b/tests/test_websocket_server.py @@ -0,0 +1,44 @@ +import pytest +import asyncio +import json +import websockets +from src.api_hooks import WebSocketServer + +@pytest.mark.asyncio +async def test_websocket_subscription_and_broadcast(): + # Mock app + app = type("MockApp", (), {"test_hooks_enabled": True})() + + # Start server on a specific port + port = 9005 + server = WebSocketServer(app, port=port) + server.start() + + # Wait for server to start + await asyncio.sleep(0.5) + + try: + uri = f"ws://127.0.0.1:{port}" + async with websockets.connect(uri) as websocket: + # Subscribe to events channel + subscribe_msg = {"action": "subscribe", "channel": "events"} + await websocket.send(json.dumps(subscribe_msg)) + + # Receive confirmation + response = await websocket.recv() + data = json.loads(response) + assert data["type"] == "subscription_confirmed" + assert data["channel"] == "events" + + # Broadcast an event from the server + event_payload = {"event": "test_event", "data": "hello"} + server.broadcast("events", event_payload) + + # Receive the broadcast + broadcast_response = await websocket.recv() + broadcast_data = json.loads(broadcast_response) + assert broadcast_data["channel"] == "events" + assert broadcast_data["payload"] == event_payload + + finally: + server.stop()