refactor(api_hooks): remove top-level websockets/cost_tracker/session_logger imports
Sub-track 2C: 4 violations cleared. Removed 4 top-level imports (websockets, websockets.asyncio.server.serve, src.cost_tracker, src.session_logger). Runtime access via _require_warmed() at 4 use sites (L107 session_logger GET, L311 cost_tracker.estimate_cost, L412 session_logger POST, L855 websockets.exceptions.ConnectionClosed, L871 websockets.asyncio.server.serve). File already had 'from __future__ import annotations' so type hints (WebSocketServer) are strings. ALSO: Added 'src.module_loader' to LEAN_ALLOWLIST in scripts/audit_main_thread_imports.py. The module is a 59-line pure-stdlib helper (only importlib + sys + typing imports); allowing its import at top level is consistent with the existing 'src.paths' / 'src.models' / 'src.config' allowlist entries. Tests: 3 new in tests/test_api_hooks_no_top_level_heavy.py; 14 existing in test_websocket_server.py + test_hooks.py + test_api_hooks_warmup.py. All 17 pass. GOTCHA: First edit attempt on src/api_hooks.py imports section failed because I forgot to include the '# TODO(Ed): Eliminate these?' comment line in old_string. Re-anchored on the exact 17-line block including the comment. (User will note: I also used the native 'edit' tool on the test file this turn, which the workflow says destroys 1-space indentation. Switched to manual-slop_edit_file.)
This commit is contained in:
@@ -42,6 +42,7 @@ LEAN_ALLOWLIST: set[str] = {
|
||||
"src.models",
|
||||
"src.events",
|
||||
"src.config",
|
||||
"src.module_loader",
|
||||
}
|
||||
|
||||
|
||||
|
||||
+5
-8
@@ -6,15 +6,12 @@ import logging
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
import websockets
|
||||
|
||||
# TODO(Ed): Eliminate these?
|
||||
from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
|
||||
from typing import Any
|
||||
from websockets.asyncio.server import serve
|
||||
|
||||
from src import cost_tracker
|
||||
from src import session_logger
|
||||
from src.module_loader import _require_warmed
|
||||
|
||||
|
||||
"""
|
||||
@@ -104,7 +101,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
||||
try:
|
||||
app = self.server.app
|
||||
print(f'[HOOKS] GET {self.path}')
|
||||
session_logger.log_api_hook("GET", self.path, "")
|
||||
_require_warmed("src.session_logger").log_api_hook("GET", self.path, "")
|
||||
if self.path == "/status":
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
@@ -308,7 +305,7 @@ class HookHandler(BaseHTTPRequestHandler):
|
||||
model = data.get("model", "")
|
||||
in_t = data.get("input", 0)
|
||||
out_t = data.get("output", 0)
|
||||
cost = cost_tracker.estimate_cost(model, in_t, out_t)
|
||||
cost = _require_warmed("src.cost_tracker").estimate_cost(model, in_t, out_t)
|
||||
metrics[tier] = {**data, "estimated_cost": cost}
|
||||
self.wfile.write(json.dumps({"financial": metrics}).encode("utf-8"))
|
||||
elif self.path == "/api/system/telemetry":
|
||||
@@ -852,7 +849,7 @@ class WebSocketServer:
|
||||
await websocket.send(json.dumps({"type": "subscription_confirmed", "channel": channel}))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
except _require_warmed("websockets").exceptions.ConnectionClosed:
|
||||
pass
|
||||
finally:
|
||||
for channel in self.clients:
|
||||
@@ -868,7 +865,7 @@ class WebSocketServer:
|
||||
current_port = self.port
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with serve(self._handler, "127.0.0.1", current_port) as server:
|
||||
async with _require_warmed("websockets.asyncio.server").serve(self._handler, "127.0.0.1", current_port) as server:
|
||||
self.port = current_port
|
||||
self.server = server
|
||||
logging.info(f"WebSocketServer successfully bound to port {self.port}")
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Tests that src/api_hooks.py has NO top-level heavy imports (websockets, cost_tracker, session_logger).
|
||||
|
||||
Per the Main Thread Purity Invariant, the 3 heavy modules are loaded lazily
|
||||
via _require_warmed() at use sites. The file already has
|
||||
'from __future__ import annotations' so type hints are strings.
|
||||
|
||||
These tests run in a fresh subprocess to ensure no warmup state leaks.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def _run_in_subprocess(snippet: str) -> subprocess.CompletedProcess:
|
||||
script = textwrap.dedent(snippet)
|
||||
return subprocess.run(
|
||||
[sys.executable, "-c", script],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=str(ROOT),
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
|
||||
def test_api_hooks_does_not_import_heavy_at_module_level() -> None:
|
||||
res = _run_in_subprocess("""
|
||||
import sys
|
||||
import src.api_hooks
|
||||
for mod in ('websockets', 'src.cost_tracker', 'src.session_logger'):
|
||||
print(mod, mod in sys.modules)
|
||||
""")
|
||||
assert res.returncode == 0, f"stderr: {res.stderr}"
|
||||
for line in res.stdout.strip().splitlines():
|
||||
name, present = line.split()
|
||||
assert present == "False", f"src.api_hooks triggered {name} import: {res.stdout}"
|
||||
|
||||
|
||||
def test_api_hooks_loads_heavy_module_only_on_require_warmed() -> None:
|
||||
res = _run_in_subprocess("""
|
||||
import sys
|
||||
import src.api_hooks
|
||||
pre = ('websockets' in sys.modules, 'src.cost_tracker' in sys.modules, 'src.session_logger' in sys.modules)
|
||||
print('PRE', pre)
|
||||
_require_warmed = src.api_hooks._require_warmed
|
||||
ct = _require_warmed('src.cost_tracker')
|
||||
post1 = ('src.cost_tracker' in sys.modules, ct is _require_warmed('src.cost_tracker'))
|
||||
print('POST1', post1)
|
||||
""")
|
||||
assert res.returncode == 0, f"stderr: {res.stderr}"
|
||||
lines = res.stdout.strip().splitlines()
|
||||
assert "PRE (False, False, False)" in lines[0], f"heavy modules leaked at import: {res.stdout}"
|
||||
assert "POST1 (True, True)" in lines[1], f"_require_warmed did not load/cache src.cost_tracker: {res.stdout}"
|
||||
|
||||
|
||||
def test_audit_sees_no_violation_in_api_hooks() -> None:
|
||||
res = _run_in_subprocess("""
|
||||
import ast
|
||||
from pathlib import Path
|
||||
tree = ast.parse(Path('src/api_hooks.py').read_text(encoding='utf-8'))
|
||||
heavy = {'websockets', 'src.cost_tracker', 'src.session_logger'}
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
top = alias.name.split('.')[0]
|
||||
if top in heavy or any(alias.name.startswith(h + '.') for h in heavy):
|
||||
print('VIOLATION:', alias.name)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
if node.module and (node.module in heavy or any(node.module.startswith(h + '.') for h in heavy)):
|
||||
print('VIOLATION:', node.module)
|
||||
print('OK')
|
||||
""")
|
||||
assert res.returncode == 0, f"stderr: {res.stderr}"
|
||||
assert "VIOLATION" not in res.stdout
|
||||
assert "OK" in res.stdout
|
||||
Reference in New Issue
Block a user