diff --git a/scripts/audit_main_thread_imports.py b/scripts/audit_main_thread_imports.py index 5775648d..4db81db6 100644 --- a/scripts/audit_main_thread_imports.py +++ b/scripts/audit_main_thread_imports.py @@ -42,6 +42,7 @@ LEAN_ALLOWLIST: set[str] = { "src.models", "src.events", "src.config", + "src.module_loader", } diff --git a/src/api_hooks.py b/src/api_hooks.py index fee232d3..68eac12e 100644 --- a/src/api_hooks.py +++ b/src/api_hooks.py @@ -6,15 +6,12 @@ import logging import sys import threading import uuid -import websockets # TODO(Ed): Eliminate these? from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler from typing import Any -from websockets.asyncio.server import serve -from src import cost_tracker -from src import session_logger +from src.module_loader import _require_warmed """ @@ -104,7 +101,7 @@ class HookHandler(BaseHTTPRequestHandler): try: app = self.server.app print(f'[HOOKS] GET {self.path}') - session_logger.log_api_hook("GET", self.path, "") + _require_warmed("src.session_logger").log_api_hook("GET", self.path, "") if self.path == "/status": self.send_response(200) self.send_header("Content-Type", "application/json") @@ -308,7 +305,7 @@ class HookHandler(BaseHTTPRequestHandler): model = data.get("model", "") in_t = data.get("input", 0) out_t = data.get("output", 0) - cost = cost_tracker.estimate_cost(model, in_t, out_t) + cost = _require_warmed("src.cost_tracker").estimate_cost(model, in_t, out_t) metrics[tier] = {**data, "estimated_cost": cost} self.wfile.write(json.dumps({"financial": metrics}).encode("utf-8")) elif self.path == "/api/system/telemetry": @@ -852,7 +849,7 @@ class WebSocketServer: await websocket.send(json.dumps({"type": "subscription_confirmed", "channel": channel})) except json.JSONDecodeError: pass - except websockets.exceptions.ConnectionClosed: + except _require_warmed("websockets").exceptions.ConnectionClosed: pass finally: for channel in self.clients: @@ -868,7 +865,7 @@ class WebSocketServer: 
current_port = self.port for attempt in range(max_retries): try: - async with serve(self._handler, "127.0.0.1", current_port) as server: + async with _require_warmed("websockets.asyncio.server").serve(self._handler, "127.0.0.1", current_port) as server: self.port = current_port self.server = server logging.info(f"WebSocketServer successfully bound to port {self.port}") diff --git a/tests/test_api_hooks_no_top_level_heavy.py b/tests/test_api_hooks_no_top_level_heavy.py new file mode 100644 index 00000000..00d2c102 --- /dev/null +++ b/tests/test_api_hooks_no_top_level_heavy.py @@ -0,0 +1,82 @@ +"""Tests that src/api_hooks.py has NO top-level heavy imports (websockets, cost_tracker, session_logger). + +Per the Main Thread Purity Invariant, the 3 heavy modules are loaded lazily +via _require_warmed() at use sites. The file already has +'from __future__ import annotations' so type hints are strings. + +These tests run in a fresh subprocess to ensure no warmup state leaks. +""" + +import subprocess +import sys +import textwrap +from pathlib import Path + +ROOT = Path(__file__).resolve().parent.parent + + +def _run_in_subprocess(snippet: str) -> subprocess.CompletedProcess: + script = textwrap.dedent(snippet) + return subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + text=True, + cwd=str(ROOT), + timeout=30, + ) + + +def test_api_hooks_does_not_import_heavy_at_module_level() -> None: + res = _run_in_subprocess(""" + import sys + import src.api_hooks + for mod in ('websockets', 'src.cost_tracker', 'src.session_logger'): + print(mod, mod in sys.modules) + """) + assert res.returncode == 0, f"stderr: {res.stderr}" + for line in res.stdout.strip().splitlines(): + name, present = line.split() + assert present == "False", f"src.api_hooks triggered {name} import: {res.stdout}" + + +def test_api_hooks_loads_heavy_module_only_on_require_warmed() -> None: + res = _run_in_subprocess(""" + import sys + import src.api_hooks + pre = ('websockets' in 
sys.modules, 'src.cost_tracker' in sys.modules, 'src.session_logger' in sys.modules) + print('PRE', pre) + _require_warmed = src.api_hooks._require_warmed + ct = _require_warmed('src.cost_tracker') + post1 = ('src.cost_tracker' in sys.modules, ct is _require_warmed('src.cost_tracker')) + print('POST1', post1) + """) + assert res.returncode == 0, f"stderr: {res.stderr}" + lines = res.stdout.strip().splitlines() + assert "PRE (False, False, False)" in lines[0], f"heavy modules leaked at import: {res.stdout}" + assert "POST1 (True, True)" in lines[1], f"_require_warmed did not load/cache src.cost_tracker: {res.stdout}" + + +def test_audit_sees_no_violation_in_api_hooks() -> None: + """Statically check that src/api_hooks.py has no top-level import of a heavy module.""" + res = _run_in_subprocess(""" + import ast + from pathlib import Path + tree = ast.parse(Path('src/api_hooks.py').read_text(encoding='utf-8')) + heavy = {'websockets', 'src.cost_tracker', 'src.session_logger'} + # Match a heavy module itself or anything below it, in every import spelling. + for node in tree.body: + if isinstance(node, ast.Import): + for alias in node.names: + if any(alias.name == h or alias.name.startswith(h + '.') for h in heavy): + print('VIOLATION:', alias.name) + elif isinstance(node, ast.ImportFrom) and node.module: + for alias in node.names: + # 'from src import cost_tracker' binds src.cost_tracker, so qualify the name. + full = f'{node.module}.{alias.name}' + if any(full == h or full.startswith(h + '.') for h in heavy): + print('VIOLATION:', full) + print('OK') + """) + assert res.returncode == 0, f"stderr: {res.stderr}" + assert "VIOLATION" not in res.stdout + assert "OK" in res.stdout