feat(api): implement websocket gateway and event streaming for phase 1

This commit is contained in:
2026-03-11 23:01:09 -04:00
parent 00a390ffab
commit 02e0fce548
8 changed files with 158 additions and 26 deletions

View File

@@ -3,9 +3,12 @@ import json
import threading
import uuid
import sys
import asyncio
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
from typing import Any
import logging
import websockets
from websockets.asyncio.server import serve
from src import session_logger
"""
API Hooks - REST API for external automation and state inspection.
@@ -498,6 +501,7 @@ class HookServer:
self.port = port
self.server = None
self.thread = None
self.websocket_server: WebSocketServer | None = None
def start(self) -> None:
if self.thread and self.thread.is_alive():
@@ -511,12 +515,22 @@ class HookServer:
if not _has_app_attr(self.app, '_ask_responses'): _set_app_attr(self.app, '_ask_responses', {})
if not _has_app_attr(self.app, '_api_event_queue'): _set_app_attr(self.app, '_api_event_queue', [])
if not _has_app_attr(self.app, '_api_event_queue_lock'): _set_app_attr(self.app, '_api_event_queue_lock', threading.Lock())
self.websocket_server = WebSocketServer(self.app)
self.websocket_server.start()
eq = _get_app_attr(self.app, 'event_queue')
if eq:
eq.websocket_server = self.websocket_server
self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app)
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.thread.start()
logging.info(f"Hook server started on port {self.port}")
def stop(self) -> None:
if self.websocket_server:
self.websocket_server.stop()
if self.server:
self.server.shutdown()
self.server.server_close()
@@ -524,3 +538,62 @@ class HookServer:
self.thread.join()
logging.info("Hook server stopped")
class WebSocketServer:
"""WebSocket gateway for real-time event streaming."""
def __init__(self, app: Any, port: int = 9000) -> None:
self.app = app
self.port = port
self.clients: dict[str, set] = {"events": set(), "telemetry": set()}
self.loop: asyncio.AbstractEventLoop | None = None
self.thread: threading.Thread | None = None
self.server = None
self._stop_event: asyncio.Event | None = None
async def _handler(self, websocket) -> None:
try:
async for message in websocket:
try:
data = json.loads(message)
if data.get("action") == "subscribe":
channel = data.get("channel")
if channel in self.clients:
self.clients[channel].add(websocket)
await websocket.send(json.dumps({"type": "subscription_confirmed", "channel": channel}))
except json.JSONDecodeError:
pass
except websockets.exceptions.ConnectionClosed:
pass
finally:
for channel in self.clients:
if websocket in self.clients[channel]:
self.clients[channel].remove(websocket)
def _run_loop(self) -> None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self._stop_event = asyncio.Event()
async def main():
async with serve(self._handler, "127.0.0.1", self.port) as server:
self.server = server
await self._stop_event.wait()
self.loop.run_until_complete(main())
def start(self) -> None:
if self.thread and self.thread.is_alive():
return
self.thread = threading.Thread(target=self._run_loop, daemon=True)
self.thread.start()
def stop(self) -> None:
if self.loop and self._stop_event:
self.loop.call_soon_threadsafe(self._stop_event.set)
if self.thread:
self.thread.join(timeout=2.0)
def broadcast(self, channel: str, payload: dict[str, Any]) -> None:
if not self.loop or channel not in self.clients:
return
message = json.dumps({"channel": channel, "payload": payload})
for ws in list(self.clients[channel]):
asyncio.run_coroutine_threadsafe(ws.send(message), self.loop)