diff --git a/src/api_hooks.py b/src/api_hooks.py index 9c1b5668..a5bdfbf7 100644 --- a/src/api_hooks.py +++ b/src/api_hooks.py @@ -10,9 +10,17 @@ import uuid # TODO(Ed): Eliminate these? from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler from typing import Any +from dataclasses import dataclass from src.module_loader import _require_warmed from src.result_types import ErrorInfo, ErrorKind, Result +from src.type_aliases import JsonValue + + +@dataclass(frozen=True) +class WebSocketMessage: + channel: str + payload: JsonValue """ @@ -131,7 +139,7 @@ class HookServerInstance(ThreadingHTTPServer): super().__init__(server_address, RequestHandlerClass) self.app = app -def _serialize_for_api(obj: Any) -> Any: +def _serialize_for_api(obj: Any) -> JsonValue: """Serializes complex objects into API-friendly formats (dicts/lists).""" if hasattr(obj, "to_dict"): return obj.to_dict() @@ -972,12 +980,12 @@ class WebSocketServer: if self.thread: self.thread.join(timeout=2.0) - def broadcast(self, channel: str, payload: dict[str, Any]) -> None: + def broadcast(self, message: WebSocketMessage) -> None: """ [C: src/app_controller.py:AppController._process_pending_gui_tasks, src/events.py:AsyncEventQueue.put, tests/test_websocket_server.py:test_websocket_subscription_and_broadcast] """ - if not self.loop or channel not in self.clients: + if not self.loop or message.channel not in self.clients: return - message = json.dumps({"channel": channel, "payload": payload}) - for ws in list(self.clients[channel]): - asyncio.run_coroutine_threadsafe(ws.send(message), self.loop) + wire = json.dumps({"channel": message.channel, "payload": message.payload}) + for ws in list(self.clients[message.channel]): + asyncio.run_coroutine_threadsafe(ws.send(wire), self.loop) diff --git a/tests/test_api_hooks_dataclasses.py b/tests/test_api_hooks_dataclasses.py new file mode 100644 index 00000000..b70f6a99 --- /dev/null +++ b/tests/test_api_hooks_dataclasses.py @@ -0,0 +1,99 @@ +"""Tests for src/api_hooks.py WebSocketMessage + JsonValue usage + +Phase 5 of any_type_componentization_20260621. Verifies: +- WebSocketMessage dataclass (channel, payload: JsonValue) +- WebSocketMessage is frozen=True +- _serialize_for_api uses JsonValue type hint +- broadcast() takes WebSocketMessage instead of (channel, payload) +- _get_app_attr / _set_app_attr signatures UNCHANGED (Pattern 4 preserved) + +CONVENTION: 1-space indentation. NO COMMENTS. +""" +from __future__ import annotations + +import json +import pytest +from src import api_hooks +from src.type_aliases import JsonValue + + +def test_websocket_message_construction() -> None: + msg = api_hooks.WebSocketMessage(channel="status", payload={"status": "ok"}) + assert msg.channel == "status" + assert msg.payload == {"status": "ok"} + + +def test_websocket_message_with_list_payload() -> None: + msg = api_hooks.WebSocketMessage(channel="events", payload=[{"type": "x"}, {"type": "y"}]) + assert msg.payload == [{"type": "x"}, {"type": "y"}] + + +def test_websocket_message_with_nested_payload() -> None: + msg = api_hooks.WebSocketMessage( + channel="data", + payload={"users": [{"name": "a", "meta": {"active": True}}], "count": 1} + ) + assert msg.payload["count"] == 1 + assert msg.payload["users"][0]["meta"]["active"] is True + + +def test_websocket_message_is_frozen() -> None: + msg = api_hooks.WebSocketMessage(channel="x", payload={}) + with pytest.raises(Exception): + msg.channel = "mutated" + + +def test_websocket_message_to_json() -> None: + msg = api_hooks.WebSocketMessage(channel="status", payload={"ok": True}) + j = json.dumps({"channel": msg.channel, "payload": msg.payload}) + assert json.loads(j) == {"channel": "status", "payload": {"ok": True}} + + +def test_serialize_for_api_returns_dict_for_to_dict_object() -> None: + class WithToDict: + def to_dict(self) -> dict: + return {"k": "v"} + result = api_hooks._serialize_for_api(WithToDict()) + assert result == {"k": "v"} + + +def test_serialize_for_api_handles_nested_lists() -> None: + obj = {"items": [{"a": 1}, {"b": 2}]} + result = api_hooks._serialize_for_api(obj) + assert result == {"items": [{"a": 1}, {"b": 2}]} + + +def test_serialize_for_api_handles_purepath() -> None: + from pathlib import PurePath, PureWindowsPath + p = PurePath("a/b/c") # Use a relative path to avoid Windows normalization + result = api_hooks._serialize_for_api(p) + assert isinstance(result, str) + # Either forward or backslash separator; both are valid string representations + assert result.replace("\\", "/") == "a/b/c" + + +def test_serialize_for_api_passthrough_for_primitives() -> None: + assert api_hooks._serialize_for_api(42) == 42 + assert api_hooks._serialize_for_api("hello") == "hello" + assert api_hooks._serialize_for_api(None) is None + + +def test_serialize_for_api_handles_mixed_nesting() -> None: + obj = {"list": [1, 2, {"nested": "deep"}], "scalar": True} + result = api_hooks._serialize_for_api(obj) + assert result == obj + + +def test_get_app_attr_signature_preserved() -> None: + """Pattern 4: _get_app_attr / _set_app_attr must NOT change signature.""" + import inspect + sig = inspect.signature(api_hooks._get_app_attr) + params = list(sig.parameters.keys()) + assert params == ["app", "name", "default"] + + +def test_set_app_attr_signature_preserved() -> None: + import inspect + sig = inspect.signature(api_hooks._set_app_attr) + params = list(sig.parameters.keys()) + assert params == ["app", "name", "value"] \ No newline at end of file diff --git a/tests/test_websocket_server.py b/tests/test_websocket_server.py index c4cd89c2..819977c7 100644 --- a/tests/test_websocket_server.py +++ b/tests/test_websocket_server.py @@ -2,7 +2,7 @@ import pytest import asyncio import json import websockets -from src.api_hooks import WebSocketServer +from src.api_hooks import WebSocketMessage, WebSocketServer @pytest.mark.asyncio async def test_websocket_subscription_and_broadcast(): @@ -32,7 +32,7 @@ async def test_websocket_subscription_and_broadcast(): # Broadcast an event from the server event_payload = {"event": "test_event", "data": "hello"} - server.broadcast("events", event_payload) + server.broadcast(WebSocketMessage(channel="events", payload=event_payload)) # Receive the broadcast broadcast_response = await websocket.recv()