feat(api): implement websocket gateway and event streaming for phase 1
This commit is contained in:
@@ -3,9 +3,12 @@ import json
|
||||
import threading
|
||||
import uuid
|
||||
import sys
|
||||
import asyncio
|
||||
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
|
||||
from typing import Any
|
||||
import logging
|
||||
import websockets
|
||||
from websockets.asyncio.server import serve
|
||||
from src import session_logger
|
||||
"""
|
||||
API Hooks - REST API for external automation and state inspection.
|
||||
@@ -498,6 +501,7 @@ class HookServer:
|
||||
self.port = port
|
||||
self.server = None
|
||||
self.thread = None
|
||||
self.websocket_server: WebSocketServer | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
if self.thread and self.thread.is_alive():
|
||||
@@ -511,12 +515,22 @@ class HookServer:
|
||||
if not _has_app_attr(self.app, '_ask_responses'): _set_app_attr(self.app, '_ask_responses', {})
|
||||
if not _has_app_attr(self.app, '_api_event_queue'): _set_app_attr(self.app, '_api_event_queue', [])
|
||||
if not _has_app_attr(self.app, '_api_event_queue_lock'): _set_app_attr(self.app, '_api_event_queue_lock', threading.Lock())
|
||||
|
||||
self.websocket_server = WebSocketServer(self.app)
|
||||
self.websocket_server.start()
|
||||
|
||||
eq = _get_app_attr(self.app, 'event_queue')
|
||||
if eq:
|
||||
eq.websocket_server = self.websocket_server
|
||||
|
||||
self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app)
|
||||
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
||||
self.thread.start()
|
||||
logging.info(f"Hook server started on port {self.port}")
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.websocket_server:
|
||||
self.websocket_server.stop()
|
||||
if self.server:
|
||||
self.server.shutdown()
|
||||
self.server.server_close()
|
||||
@@ -524,3 +538,62 @@ class HookServer:
|
||||
self.thread.join()
|
||||
logging.info("Hook server stopped")
|
||||
|
||||
class WebSocketServer:
|
||||
"""WebSocket gateway for real-time event streaming."""
|
||||
def __init__(self, app: Any, port: int = 9000) -> None:
|
||||
self.app = app
|
||||
self.port = port
|
||||
self.clients: dict[str, set] = {"events": set(), "telemetry": set()}
|
||||
self.loop: asyncio.AbstractEventLoop | None = None
|
||||
self.thread: threading.Thread | None = None
|
||||
self.server = None
|
||||
self._stop_event: asyncio.Event | None = None
|
||||
|
||||
async def _handler(self, websocket) -> None:
|
||||
try:
|
||||
async for message in websocket:
|
||||
try:
|
||||
data = json.loads(message)
|
||||
if data.get("action") == "subscribe":
|
||||
channel = data.get("channel")
|
||||
if channel in self.clients:
|
||||
self.clients[channel].add(websocket)
|
||||
await websocket.send(json.dumps({"type": "subscription_confirmed", "channel": channel}))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
for channel in self.clients:
|
||||
if websocket in self.clients[channel]:
|
||||
self.clients[channel].remove(websocket)
|
||||
|
||||
def _run_loop(self) -> None:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
self._stop_event = asyncio.Event()
|
||||
async def main():
|
||||
async with serve(self._handler, "127.0.0.1", self.port) as server:
|
||||
self.server = server
|
||||
await self._stop_event.wait()
|
||||
self.loop.run_until_complete(main())
|
||||
|
||||
def start(self) -> None:
|
||||
if self.thread and self.thread.is_alive():
|
||||
return
|
||||
self.thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
if self.loop and self._stop_event:
|
||||
self.loop.call_soon_threadsafe(self._stop_event.set)
|
||||
if self.thread:
|
||||
self.thread.join(timeout=2.0)
|
||||
|
||||
def broadcast(self, channel: str, payload: dict[str, Any]) -> None:
|
||||
if not self.loop or channel not in self.clients:
|
||||
return
|
||||
message = json.dumps({"channel": channel, "payload": payload})
|
||||
for ws in list(self.clients[channel]):
|
||||
asyncio.run_coroutine_threadsafe(ws.send(message), self.loop)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user