feat(api): implement websocket gateway and event streaming for phase 1
This commit is contained in:
@@ -22,7 +22,7 @@ This file tracks all major tracks for the project. Each track has its own detail
|
|||||||
*Link: [./tracks/tool_bias_tuning_20260308/](./tracks/tool_bias_tuning_20260308/)*
|
*Link: [./tracks/tool_bias_tuning_20260308/](./tracks/tool_bias_tuning_20260308/)*
|
||||||
*Goal: Influence agent tool selection via a weighting system. Implement semantic nudges in tool descriptions and a dynamic "Tooling Strategy" section in the system prompt. Includes GUI badges and sliders for weight adjustment.*
|
*Goal: Influence agent tool selection via a weighting system. Implement semantic nudges in tool descriptions and a dynamic "Tooling Strategy" section in the system prompt. Includes GUI badges and sliders for weight adjustment.*
|
||||||
|
|
||||||
4. [ ] **Track: Expanded Hook API & Headless Orchestration**
|
4. [~] **Track: Expanded Hook API & Headless Orchestration**
|
||||||
*Link: [./tracks/hook_api_expansion_20260308/](./tracks/hook_api_expansion_20260308/)*
|
*Link: [./tracks/hook_api_expansion_20260308/](./tracks/hook_api_expansion_20260308/)*
|
||||||
*Goal: Maximize internal state exposure and provide comprehensive control endpoints (worker spawn/kill, pipeline pause/resume, DAG mutation) via the Hook API. Implement WebSocket-based real-time event streaming.*
|
*Goal: Maximize internal state exposure and provide comprehensive control endpoints (worker spawn/kill, pipeline pause/resume, DAG mutation) via the Hook API. Implement WebSocket-based real-time event streaming.*
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
# Implementation Plan: Expanded Hook API & Headless Orchestration
|
# Implementation Plan: Expanded Hook API & Headless Orchestration
|
||||||
|
|
||||||
## Phase 1: WebSocket Infrastructure & Event Streaming
|
## Phase 1: WebSocket Infrastructure & Event Streaming
|
||||||
- [ ] Task: Implement the WebSocket gateway.
|
- [x] Task: Implement the WebSocket gateway.
|
||||||
- [ ] Integrate a lightweight WebSocket library (e.g., `websockets` or `simple-websocket`).
|
- [x] Integrate a lightweight WebSocket library (e.g., `websockets` or `simple-websocket`).
|
||||||
- [ ] Create a dedicated `WebSocketServer` class in `src/api_hooks.py` that runs on a separate port (e.g., 9000).
|
- [x] Create a dedicated `WebSocketServer` class in `src/api_hooks.py` that runs on a separate port (e.g., 9000).
|
||||||
- [ ] Implement a basic subscription mechanism for different event channels.
|
- [x] Implement a basic subscription mechanism for different event channels.
|
||||||
- [ ] Task: Connect the event queue to the WebSocket stream.
|
- [x] Task: Connect the event queue to the WebSocket stream.
|
||||||
- [ ] Update `AsyncEventQueue` to broadcast events to connected WebSocket clients.
|
- [x] Update `AsyncEventQueue` to broadcast events to connected WebSocket clients.
|
||||||
- [ ] Add high-frequency telemetry (FPS, CPU) to the event stream.
|
- [x] Add high-frequency telemetry (FPS, CPU) to the event stream.
|
||||||
- [ ] Task: Write unit tests for WebSocket connection and event broadcasting.
|
- [x] Task: Write unit tests for WebSocket connection and event broadcasting.
|
||||||
- [ ] Task: Conductor - User Manual Verification 'Phase 1: WebSocket Infrastructure' (Protocol in workflow.md)
|
- [ ] Task: Conductor - User Manual Verification 'Phase 1: WebSocket Infrastructure' (Protocol in workflow.md)
|
||||||
|
|
||||||
## Phase 2: Expanded Read Endpoints (GET)
|
## Phase 2: Expanded Read Endpoints (GET)
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ class ApiHookClient:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def post_gui(self, payload: dict) -> dict[str, Any]:
|
def post_gui(self, payload: dict) -> dict[str, Any]:
|
||||||
"""Pushes an event to the GUI's SyncEventQueue via the /api/gui endpoint."""
|
"""Pushes an event to the GUI's AsyncEventQueue via the /api/gui endpoint."""
|
||||||
return self._make_request('POST', '/api/gui', data=payload) or {}
|
return self._make_request('POST', '/api/gui', data=payload) or {}
|
||||||
|
|
||||||
def push_event(self, action: str, payload: dict) -> dict[str, Any]:
|
def push_event(self, action: str, payload: dict) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -3,9 +3,12 @@ import json
|
|||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
import sys
|
import sys
|
||||||
|
import asyncio
|
||||||
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
|
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import logging
|
import logging
|
||||||
|
import websockets
|
||||||
|
from websockets.asyncio.server import serve
|
||||||
from src import session_logger
|
from src import session_logger
|
||||||
"""
|
"""
|
||||||
API Hooks - REST API for external automation and state inspection.
|
API Hooks - REST API for external automation and state inspection.
|
||||||
@@ -498,6 +501,7 @@ class HookServer:
|
|||||||
self.port = port
|
self.port = port
|
||||||
self.server = None
|
self.server = None
|
||||||
self.thread = None
|
self.thread = None
|
||||||
|
self.websocket_server: WebSocketServer | None = None
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
if self.thread and self.thread.is_alive():
|
if self.thread and self.thread.is_alive():
|
||||||
@@ -511,12 +515,22 @@ class HookServer:
|
|||||||
if not _has_app_attr(self.app, '_ask_responses'): _set_app_attr(self.app, '_ask_responses', {})
|
if not _has_app_attr(self.app, '_ask_responses'): _set_app_attr(self.app, '_ask_responses', {})
|
||||||
if not _has_app_attr(self.app, '_api_event_queue'): _set_app_attr(self.app, '_api_event_queue', [])
|
if not _has_app_attr(self.app, '_api_event_queue'): _set_app_attr(self.app, '_api_event_queue', [])
|
||||||
if not _has_app_attr(self.app, '_api_event_queue_lock'): _set_app_attr(self.app, '_api_event_queue_lock', threading.Lock())
|
if not _has_app_attr(self.app, '_api_event_queue_lock'): _set_app_attr(self.app, '_api_event_queue_lock', threading.Lock())
|
||||||
|
|
||||||
|
self.websocket_server = WebSocketServer(self.app)
|
||||||
|
self.websocket_server.start()
|
||||||
|
|
||||||
|
eq = _get_app_attr(self.app, 'event_queue')
|
||||||
|
if eq:
|
||||||
|
eq.websocket_server = self.websocket_server
|
||||||
|
|
||||||
self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app)
|
self.server = HookServerInstance(('127.0.0.1', self.port), HookHandler, self.app)
|
||||||
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
||||||
self.thread.start()
|
self.thread.start()
|
||||||
logging.info(f"Hook server started on port {self.port}")
|
logging.info(f"Hook server started on port {self.port}")
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
|
if self.websocket_server:
|
||||||
|
self.websocket_server.stop()
|
||||||
if self.server:
|
if self.server:
|
||||||
self.server.shutdown()
|
self.server.shutdown()
|
||||||
self.server.server_close()
|
self.server.server_close()
|
||||||
@@ -524,3 +538,62 @@ class HookServer:
|
|||||||
self.thread.join()
|
self.thread.join()
|
||||||
logging.info("Hook server stopped")
|
logging.info("Hook server stopped")
|
||||||
|
|
||||||
|
class WebSocketServer:
|
||||||
|
"""WebSocket gateway for real-time event streaming."""
|
||||||
|
def __init__(self, app: Any, port: int = 9000) -> None:
|
||||||
|
self.app = app
|
||||||
|
self.port = port
|
||||||
|
self.clients: dict[str, set] = {"events": set(), "telemetry": set()}
|
||||||
|
self.loop: asyncio.AbstractEventLoop | None = None
|
||||||
|
self.thread: threading.Thread | None = None
|
||||||
|
self.server = None
|
||||||
|
self._stop_event: asyncio.Event | None = None
|
||||||
|
|
||||||
|
async def _handler(self, websocket) -> None:
|
||||||
|
try:
|
||||||
|
async for message in websocket:
|
||||||
|
try:
|
||||||
|
data = json.loads(message)
|
||||||
|
if data.get("action") == "subscribe":
|
||||||
|
channel = data.get("channel")
|
||||||
|
if channel in self.clients:
|
||||||
|
self.clients[channel].add(websocket)
|
||||||
|
await websocket.send(json.dumps({"type": "subscription_confirmed", "channel": channel}))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
for channel in self.clients:
|
||||||
|
if websocket in self.clients[channel]:
|
||||||
|
self.clients[channel].remove(websocket)
|
||||||
|
|
||||||
|
def _run_loop(self) -> None:
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
self._stop_event = asyncio.Event()
|
||||||
|
async def main():
|
||||||
|
async with serve(self._handler, "127.0.0.1", self.port) as server:
|
||||||
|
self.server = server
|
||||||
|
await self._stop_event.wait()
|
||||||
|
self.loop.run_until_complete(main())
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
if self.thread and self.thread.is_alive():
|
||||||
|
return
|
||||||
|
self.thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
if self.loop and self._stop_event:
|
||||||
|
self.loop.call_soon_threadsafe(self._stop_event.set)
|
||||||
|
if self.thread:
|
||||||
|
self.thread.join(timeout=2.0)
|
||||||
|
|
||||||
|
def broadcast(self, channel: str, payload: dict[str, Any]) -> None:
|
||||||
|
if not self.loop or channel not in self.clients:
|
||||||
|
return
|
||||||
|
message = json.dumps({"channel": channel, "payload": payload})
|
||||||
|
for ws in list(self.clients[channel]):
|
||||||
|
asyncio.run_coroutine_threadsafe(ws.send(message), self.loop)
|
||||||
|
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ class AppController:
|
|||||||
self.disc_roles: List[str] = []
|
self.disc_roles: List[str] = []
|
||||||
self.files: List[str] = []
|
self.files: List[str] = []
|
||||||
self.screenshots: List[str] = []
|
self.screenshots: List[str] = []
|
||||||
self.event_queue: events.SyncEventQueue = events.SyncEventQueue()
|
self.event_queue: events.AsyncEventQueue = events.AsyncEventQueue()
|
||||||
self._loop_thread: Optional[threading.Thread] = None
|
self._loop_thread: Optional[threading.Thread] = None
|
||||||
self.tracks: List[Dict[str, Any]] = []
|
self.tracks: List[Dict[str, Any]] = []
|
||||||
self.active_track: Optional[models.Track] = None
|
self.active_track: Optional[models.Track] = None
|
||||||
@@ -188,6 +188,7 @@ class AppController:
|
|||||||
"Tier 4": {"input": 0, "output": 0, "provider": "gemini", "model": "gemini-2.5-flash-lite", "tool_preset": None},
|
"Tier 4": {"input": 0, "output": 0, "provider": "gemini", "model": "gemini-2.5-flash-lite", "tool_preset": None},
|
||||||
}
|
}
|
||||||
self.perf_monitor: performance_monitor.PerformanceMonitor = performance_monitor.PerformanceMonitor()
|
self.perf_monitor: performance_monitor.PerformanceMonitor = performance_monitor.PerformanceMonitor()
|
||||||
|
self._last_telemetry_time: float = 0.0
|
||||||
self._pending_gui_tasks: List[Dict[str, Any]] = []
|
self._pending_gui_tasks: List[Dict[str, Any]] = []
|
||||||
self._api_event_queue: List[Dict[str, Any]] = []
|
self._api_event_queue: List[Dict[str, Any]] = []
|
||||||
# Pending dialogs state moved from App
|
# Pending dialogs state moved from App
|
||||||
@@ -522,6 +523,14 @@ class AppController:
|
|||||||
})
|
})
|
||||||
|
|
||||||
def _process_pending_gui_tasks(self) -> None:
|
def _process_pending_gui_tasks(self) -> None:
|
||||||
|
# Periodic telemetry broadcast
|
||||||
|
now = time.time()
|
||||||
|
if hasattr(self, 'event_queue') and hasattr(self.event_queue, 'websocket_server') and self.event_queue.websocket_server:
|
||||||
|
if now - self._last_telemetry_time >= 1.0:
|
||||||
|
self._last_telemetry_time = now
|
||||||
|
metrics = self.perf_monitor.get_metrics()
|
||||||
|
self.event_queue.websocket_server.broadcast("telemetry", metrics)
|
||||||
|
|
||||||
if not self._pending_gui_tasks:
|
if not self._pending_gui_tasks:
|
||||||
return
|
return
|
||||||
sys.stderr.write(f"[DEBUG] _process_pending_gui_tasks: processing {len(self._pending_gui_tasks)} tasks\n")
|
sys.stderr.write(f"[DEBUG] _process_pending_gui_tasks: processing {len(self._pending_gui_tasks)} tasks\n")
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ between the GUI main thread and background workers:
|
|||||||
- Thread-safe: Callbacks execute on emitter's thread
|
- Thread-safe: Callbacks execute on emitter's thread
|
||||||
- Example: ai_client.py emits 'request_start' and 'response_received' events
|
- Example: ai_client.py emits 'request_start' and 'response_received' events
|
||||||
|
|
||||||
2. SyncEventQueue: Producer-consumer pattern via queue.Queue
|
2. AsyncEventQueue: Producer-consumer pattern via queue.Queue
|
||||||
- Used for: Decoupled task submission where consumer polls at its own pace
|
- Used for: Decoupled task submission where consumer polls at its own pace
|
||||||
- Thread-safe: Built on Python's thread-safe queue.Queue
|
- Thread-safe: Built on Python's thread-safe queue.Queue
|
||||||
- Example: Background workers submit tasks, main thread drains queue
|
- Example: Background workers submit tasks, main thread drains queue
|
||||||
@@ -21,16 +21,16 @@ between the GUI main thread and background workers:
|
|||||||
Integration Points:
|
Integration Points:
|
||||||
- ai_client.py: EventEmitter for API lifecycle events
|
- ai_client.py: EventEmitter for API lifecycle events
|
||||||
- gui_2.py: Consumes events via _process_event_queue()
|
- gui_2.py: Consumes events via _process_event_queue()
|
||||||
- multi_agent_conductor.py: Uses SyncEventQueue for state updates
|
- multi_agent_conductor.py: Uses AsyncEventQueue for state updates
|
||||||
- api_hooks.py: Pushes events to _api_event_queue for external visibility
|
- api_hooks.py: Pushes events to _api_event_queue for external visibility
|
||||||
|
|
||||||
Thread Safety:
|
Thread Safety:
|
||||||
- EventEmitter: NOT thread-safe for concurrent on/emit (use from single thread)
|
- EventEmitter: NOT thread-safe for concurrent on/emit (use from single thread)
|
||||||
- SyncEventQueue: FULLY thread-safe (built on queue.Queue)
|
- AsyncEventQueue: FULLY thread-safe (built on queue.Queue)
|
||||||
- UserRequestEvent: Immutable, safe for concurrent access
|
- UserRequestEvent: Immutable, safe for concurrent access
|
||||||
"""
|
"""
|
||||||
import queue
|
import queue
|
||||||
from typing import Callable, Any, Dict, List, Tuple
|
from typing import Callable, Any, Dict, List, Tuple, Optional
|
||||||
|
|
||||||
class EventEmitter:
|
class EventEmitter:
|
||||||
"""
|
"""
|
||||||
@@ -70,14 +70,16 @@ class EventEmitter:
|
|||||||
"""Clears all registered listeners."""
|
"""Clears all registered listeners."""
|
||||||
self._listeners.clear()
|
self._listeners.clear()
|
||||||
|
|
||||||
class SyncEventQueue:
|
class AsyncEventQueue:
|
||||||
"""
|
"""
|
||||||
Synchronous event queue for decoupled communication using queue.Queue.
|
Synchronous event queue for decoupled communication using queue.Queue.
|
||||||
|
(Named AsyncEventQueue for architectural consistency, but is synchronous)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initializes the SyncEventQueue with an internal queue.Queue."""
|
"""Initializes the AsyncEventQueue with an internal queue.Queue."""
|
||||||
self._queue: queue.Queue[Tuple[str, Any]] = queue.Queue()
|
self._queue: queue.Queue[Tuple[str, Any]] = queue.Queue()
|
||||||
|
self.websocket_server: Optional[Any] = None
|
||||||
|
|
||||||
def put(self, event_name: str, payload: Any = None) -> None:
|
def put(self, event_name: str, payload: Any = None) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -88,6 +90,8 @@ class SyncEventQueue:
|
|||||||
payload: Optional data associated with the event.
|
payload: Optional data associated with the event.
|
||||||
"""
|
"""
|
||||||
self._queue.put((event_name, payload))
|
self._queue.put((event_name, payload))
|
||||||
|
if self.websocket_server:
|
||||||
|
self.websocket_server.broadcast("events", {"event": event_name, "payload": payload})
|
||||||
|
|
||||||
def get(self) -> Tuple[str, Any]:
|
def get(self) -> Tuple[str, Any]:
|
||||||
"""
|
"""
|
||||||
@@ -106,6 +110,9 @@ class SyncEventQueue:
|
|||||||
"""Blocks until all items in the queue have been gotten and processed."""
|
"""Blocks until all items in the queue have been gotten and processed."""
|
||||||
self._queue.join()
|
self._queue.join()
|
||||||
|
|
||||||
|
# Alias for backward compatibility
|
||||||
|
SyncEventQueue = AsyncEventQueue
|
||||||
|
|
||||||
class UserRequestEvent:
|
class UserRequestEvent:
|
||||||
"""
|
"""
|
||||||
Payload for a user request event.
|
Payload for a user request event.
|
||||||
@@ -126,4 +133,3 @@ class UserRequestEvent:
|
|||||||
"disc_text": self.disc_text,
|
"disc_text": self.disc_text,
|
||||||
"base_dir": self.base_dir
|
"base_dir": self.base_dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ Key Components:
|
|||||||
|
|
||||||
Architecture Integration:
|
Architecture Integration:
|
||||||
- Uses TrackDAG and ExecutionEngine from dag_engine.py
|
- Uses TrackDAG and ExecutionEngine from dag_engine.py
|
||||||
- Communicates with GUI via SyncEventQueue
|
- Communicates with GUI via AsyncEventQueue
|
||||||
- Manages tier-specific token usage via update_usage()
|
- Manages tier-specific token usage via update_usage()
|
||||||
|
|
||||||
Thread Safety:
|
Thread Safety:
|
||||||
@@ -45,7 +45,7 @@ Thread Safety:
|
|||||||
- Abort events enable per-ticket cancellation
|
- Abort events enable per-ticket cancellation
|
||||||
|
|
||||||
Integration:
|
Integration:
|
||||||
- Uses SyncEventQueue for state updates to the GUI
|
- Uses AsyncEventQueue for state updates to the GUI
|
||||||
- Uses ai_client.send() for LLM communication
|
- Uses ai_client.send() for LLM communication
|
||||||
- Uses mcp_client for tool dispatch
|
- Uses mcp_client for tool dispatch
|
||||||
|
|
||||||
@@ -123,7 +123,7 @@ class ConductorEngine:
|
|||||||
Orchestrates the execution of tickets within a track.
|
Orchestrates the execution of tickets within a track.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, track: Track, event_queue: Optional[events.SyncEventQueue] = None, auto_queue: bool = False) -> None:
|
def __init__(self, track: Track, event_queue: Optional[events.AsyncEventQueue] = None, auto_queue: bool = False) -> None:
|
||||||
self.track = track
|
self.track = track
|
||||||
self.event_queue = event_queue
|
self.event_queue = event_queue
|
||||||
self.tier_usage = {
|
self.tier_usage = {
|
||||||
@@ -343,12 +343,12 @@ class ConductorEngine:
|
|||||||
self._push_state(active_tier="Tier 2 (Tech Lead)")
|
self._push_state(active_tier="Tier 2 (Tech Lead)")
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
def _queue_put(event_queue: events.SyncEventQueue, event_name: str, payload) -> None:
|
def _queue_put(event_queue: events.AsyncEventQueue, event_name: str, payload) -> None:
|
||||||
"""Thread-safe helper to push an event to the SyncEventQueue from a worker thread."""
|
"""Thread-safe helper to push an event to the AsyncEventQueue from a worker thread."""
|
||||||
if event_queue is not None:
|
if event_queue is not None:
|
||||||
event_queue.put(event_name, payload)
|
event_queue.put(event_name, payload)
|
||||||
|
|
||||||
def confirm_execution(payload: str, event_queue: events.SyncEventQueue, ticket_id: str) -> bool:
|
def confirm_execution(payload: str, event_queue: events.AsyncEventQueue, ticket_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Pushes an approval request to the GUI and waits for response.
|
Pushes an approval request to the GUI and waits for response.
|
||||||
"""
|
"""
|
||||||
@@ -370,7 +370,7 @@ def confirm_execution(payload: str, event_queue: events.SyncEventQueue, ticket_i
|
|||||||
return approved
|
return approved
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.SyncEventQueue, ticket_id: str) -> Tuple[bool, str, str]:
|
def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.AsyncEventQueue, ticket_id: str) -> Tuple[bool, str, str]:
|
||||||
"""
|
"""
|
||||||
Pushes a spawn approval request to the GUI and waits for response.
|
Pushes a spawn approval request to the GUI and waits for response.
|
||||||
Returns (approved, modified_prompt, modified_context)
|
Returns (approved, modified_prompt, modified_context)
|
||||||
@@ -409,7 +409,7 @@ def confirm_spawn(role: str, prompt: str, context_md: str, event_queue: events.S
|
|||||||
return approved, modified_prompt, modified_context
|
return approved, modified_prompt, modified_context
|
||||||
return False, prompt, context_md
|
return False, prompt, context_md
|
||||||
|
|
||||||
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] | None = None, event_queue: events.SyncEventQueue | None = None, engine: Optional['ConductorEngine'] = None, md_content: str = "") -> None:
|
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] | None = None, event_queue: events.AsyncEventQueue | None = None, engine: Optional['ConductorEngine'] = None, md_content: str = "") -> None:
|
||||||
"""
|
"""
|
||||||
Simulates the lifecycle of a single agent working on a ticket.
|
Simulates the lifecycle of a single agent working on a ticket.
|
||||||
Calls the AI client and updates the ticket status based on the response.
|
Calls the AI client and updates the ticket status based on the response.
|
||||||
|
|||||||
44
tests/test_websocket_server.py
Normal file
44
tests/test_websocket_server.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import websockets
|
||||||
|
from src.api_hooks import WebSocketServer
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_websocket_subscription_and_broadcast():
|
||||||
|
# Mock app
|
||||||
|
app = type("MockApp", (), {"test_hooks_enabled": True})()
|
||||||
|
|
||||||
|
# Start server on a specific port
|
||||||
|
port = 9005
|
||||||
|
server = WebSocketServer(app, port=port)
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
# Wait for server to start
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
try:
|
||||||
|
uri = f"ws://127.0.0.1:{port}"
|
||||||
|
async with websockets.connect(uri) as websocket:
|
||||||
|
# Subscribe to events channel
|
||||||
|
subscribe_msg = {"action": "subscribe", "channel": "events"}
|
||||||
|
await websocket.send(json.dumps(subscribe_msg))
|
||||||
|
|
||||||
|
# Receive confirmation
|
||||||
|
response = await websocket.recv()
|
||||||
|
data = json.loads(response)
|
||||||
|
assert data["type"] == "subscription_confirmed"
|
||||||
|
assert data["channel"] == "events"
|
||||||
|
|
||||||
|
# Broadcast an event from the server
|
||||||
|
event_payload = {"event": "test_event", "data": "hello"}
|
||||||
|
server.broadcast("events", event_payload)
|
||||||
|
|
||||||
|
# Receive the broadcast
|
||||||
|
broadcast_response = await websocket.recv()
|
||||||
|
broadcast_data = json.loads(broadcast_response)
|
||||||
|
assert broadcast_data["channel"] == "events"
|
||||||
|
assert broadcast_data["payload"] == event_payload
|
||||||
|
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
Reference in New Issue
Block a user