diff --git a/src/warmup.py b/src/warmup.py
index 9f45f450..ea49819d 100644
--- a/src/warmup.py
+++ b/src/warmup.py
@@ -29,6 +29,7 @@ Canary records (one per submitted module) carry:
 """
 
 import importlib
+import sys
 import threading
 import time
 from concurrent.futures import Future, ThreadPoolExecutor
@@ -39,9 +40,10 @@ CompletionCallback = Callable[[dict], None]
 
 
 class WarmupManager:
-    def __init__(self, pool: ThreadPoolExecutor) -> None:
+    def __init__(self, pool: ThreadPoolExecutor, log_to_stderr: bool = True) -> None:
         self._pool = pool
         self._lock = threading.Lock()
+        self._log_lock = threading.Lock()
         self._done_event = threading.Event()
         self._pending: list[str] = []
         self._completed: list[str] = []
@@ -51,6 +53,15 @@ class WarmupManager:
         # Canary observability state (per-module import tracking).
         self._canaries: list[dict] = []
         self._next_canary_id: int = 1
+        # Stderr logging: when True, the manager prints a one-line summary
+        # of each completed/failed canary to stderr, plus a final aggregate
+        # line when the entire warmup finishes. Default True so production
+        # runs get observability for free. Tests can opt out.
+        self._log_to_stderr: bool = log_to_stderr
+        # Capture the constructing thread's ident so we can flag any canary that
+        # ran on it (a main-thread-purity violation). This assumes the manager
+        # is constructed on the main thread.
+        self._main_thread_ident: int = threading.get_ident()
 
     def submit(self, modules: list[str]) -> None:
         submit_ts = time.time()
@@ -156,6 +167,8 @@ class WarmupManager:
     def _record_success(self, name: str, end_ts: Optional[float] = None) -> None:
         if end_ts is None: end_ts = time.time()
         callbacks: list[CompletionCallback] = []
+        canary_snapshot: Optional[dict] = None
+        all_done = False
         with self._lock:
             if name in self._pending:
                 self._pending.remove(name)
@@ -166,11 +179,17 @@ class WarmupManager:
                     c["end_ts"] = end_ts
                     if c["start_ts"] is not None:
                         c["elapsed_ms"] = (end_ts - c["start_ts"]) * 1000
+                    canary_snapshot = dict(c)
                     break
             done = self._started and not self._pending
             if done:
                 self._done_event.set()
                 callbacks = list(self._callbacks)
+                all_done = True
+        if canary_snapshot is not None:
+            self._log_canary(canary_snapshot)
+        if all_done:
+            self._log_summary()
         for cb in callbacks:
             try:
                 cb(self._snapshot())
@@ -180,6 +199,8 @@ class WarmupManager:
     def _record_failure(self, name: str, _err: BaseException, end_ts: Optional[float] = None) -> None:
         if end_ts is None: end_ts = time.time()
         callbacks: list[CompletionCallback] = []
+        canary_snapshot: Optional[dict] = None
+        all_done = False
         with self._lock:
             if name in self._pending:
                 self._pending.remove(name)
@@ -191,17 +212,89 @@ class WarmupManager:
                     c["error"] = f"{type(_err).__name__}: {_err}"
                     if c["start_ts"] is not None:
                         c["elapsed_ms"] = (end_ts - c["start_ts"]) * 1000
+                    canary_snapshot = dict(c)
                     break
             done = self._started and not self._pending
             if done:
                 self._done_event.set()
                 callbacks = list(self._callbacks)
+                all_done = True
+        if canary_snapshot is not None:
+            self._log_canary(canary_snapshot)
+        if all_done:
+            self._log_summary()
         for cb in callbacks:
             try:
                 cb(self._snapshot())
             except Exception:
                 pass
 
+    def _log_canary(self, canary: dict) -> None:
+        """Write a one-line summary of a finished canary to stderr.
+
+        Does nothing when stderr logging is disabled. Failures while writing
+        are swallowed on purpose: logging is best-effort and must never break
+        warmup.
+        """
+        if not self._log_to_stderr:
+            return
+        cid = canary["canary_id"]
+        module = canary["module"]
+        thread_name = canary.get("thread_name") or "?"
+        thread_id = canary.get("thread_id")
+        elapsed = canary.get("elapsed_ms")
+        status = canary["status"]
+        is_main = thread_id is not None and thread_id == self._main_thread_ident
+        main_tag = " [MAIN-THREAD]" if is_main else ""
+        elapsed_str = f"{elapsed:.1f}ms" if elapsed is not None else "?ms"
+        if status == "completed":
+            line = f"[warmup {cid}] {module} on {thread_name} (id={thread_id}): {elapsed_str}{main_tag}\n"
+        elif status == "failed":
+            err = canary.get("error") or "?"
+            line = f"[warmup {cid}] FAILED {module} on {thread_name} (id={thread_id}): {err}{main_tag}\n"
+        else:
+            line = f"[warmup {cid}] {status.upper()} {module} on {thread_name} (id={thread_id}){main_tag}\n"
+        # Only the write itself is serialized; formatting needs no lock.
+        with self._log_lock:
+            try:
+                sys.stderr.write(line)
+                sys.stderr.flush()
+            except Exception:
+                pass
+
+    def _log_summary(self) -> None:
+        """Write the final aggregate warmup line to stderr.
+
+        Also emits a warning line listing modules whose canary ran on the
+        thread recorded as the main thread. Does nothing when stderr logging
+        is disabled or no canaries were recorded.
+        """
+        if not self._log_to_stderr:
+            return
+        with self._lock:
+            canaries = list(self._canaries)
+        if not canaries:
+            return
+        total = len(canaries)
+        completed = sum(1 for c in canaries if c["status"] == "completed")
+        failed = sum(1 for c in canaries if c["status"] == "failed")
+        cancelled = sum(1 for c in canaries if c["status"] == "cancelled")
+        main_thread_violations = [c["module"] for c in canaries if c.get("thread_id") == self._main_thread_ident]
+        total_ms = sum(c.get("elapsed_ms") or 0.0 for c in canaries)
+        parts = [f"{completed} completed"]
+        if failed:
+            parts.append(f"{failed} failed")
+        if cancelled:
+            parts.append(f"{cancelled} cancelled")
+        with self._log_lock:
+            try:
+                sys.stderr.write(f"[warmup done] {total} modules: {', '.join(parts)} (sum of per-module elapsed: {total_ms:.1f}ms)\n")
+                if main_thread_violations:
+                    sys.stderr.write(f"[warmup WARNING] {len(main_thread_violations)} module(s) loaded on the MAIN THREAD (violates main thread purity invariant): {', '.join(main_thread_violations)}\n")
+                sys.stderr.flush()
+            except Exception:
+                pass
+
     def _snapshot(self) -> dict[str, list[str]]:
         return {
             "pending": list(self._pending),
diff --git a/tests/test_warmup_canaries.py b/tests/test_warmup_canaries.py
index b27b03e5..8f3568c5 100644
--- a/tests/test_warmup_canaries.py
+++ b/tests/test_warmup_canaries.py
@@ -27,9 +27,9 @@ from src.io_pool import make_io_pool
 
 
 def _build_warmup() -> tuple[WarmupManager, object]:
-    """Build a fresh WarmupManager + pool for testing."""
+    """Build a fresh WarmupManager + pool for testing (silent by default)."""
     pool = make_io_pool()
-    mgr = WarmupManager(pool)
+    mgr = WarmupManager(pool, log_to_stderr=False)
     return mgr, pool
 
 
@@ -164,3 +164,84 @@ def test_canary_canary_id_increments_across_resets() -> None:
         f"canary_ids should be [first=1, second=2]; got {second_ids}"
     )
     pool.shutdown(wait=True)
+
+
+def test_warmup_logs_to_stderr_on_completion(capsys: pytest.CaptureFixture) -> None:
+    """Successful canaries print a one-line summary to stderr."""
+    pool = make_io_pool()
+    mgr = WarmupManager(pool, log_to_stderr=True)
+    mgr.submit(["os", "json"])
+    assert mgr.wait(timeout=10.0)
+    # Log lines are written after the done event is set, so drain the pool
+    # before reading stderr to avoid racing the worker threads.
+    pool.shutdown(wait=True)
+    captured = capsys.readouterr()
+    # Each completed module should have a log line
+    assert "[warmup" in captured.err
+    assert " os " in captured.err
+    assert " json " in captured.err
+    # Format: "[warmup N] module on thread (id=IDENT): ELAPSEDms"
+    assert "controller-io" in captured.err
+    assert "ms" in captured.err
+
+
+def test_warmup_can_be_quiet(capsys: pytest.CaptureFixture) -> None:
+    """log_to_stderr=False suppresses the per-module log lines."""
+    pool = make_io_pool()
+    mgr = WarmupManager(pool, log_to_stderr=False)
+    mgr.submit(["os", "json"])
+    assert mgr.wait(timeout=10.0)
+    # Drain the pool so any unexpected log output would already be written.
+    pool.shutdown(wait=True)
+    captured = capsys.readouterr()
+    # No warmup log lines at all (every log line starts with "[warmup")
+    assert "[warmup" not in captured.err
+    # But the structured canary records still exist
+    canaries = mgr.canaries()
+    assert len(canaries) == 2
+    assert all(c["status"] == "completed" for c in canaries)
+
+
+def test_warmup_logs_total_time_at_completion(capsys: pytest.CaptureFixture) -> None:
+    """A summary line is printed when the entire warmup completes."""
+    pool = make_io_pool()
+    mgr = WarmupManager(pool, log_to_stderr=True)
+    mgr.submit(["os", "json"])
+    assert mgr.wait(timeout=10.0)
+    # Drain the pool so every log line is written before stderr is read.
+    pool.shutdown(wait=True)
+    captured = capsys.readouterr()
+    err_lines = [line for line in captured.err.splitlines() if line.strip()]
+    assert len(err_lines) >= 3  # 2 per-module + 1 summary
+    # Worker threads may interleave their writes, so look for the aggregate
+    # line anywhere rather than assuming it is last.
+    assert any(line.startswith("[warmup done]") for line in err_lines)
+
+
+def test_warmup_logs_failure_to_stderr(capsys: pytest.CaptureFixture) -> None:
+    """A failed import prints a FAILED log line to stderr."""
+    pool = make_io_pool()
+    mgr = WarmupManager(pool, log_to_stderr=True)
+    mgr.submit(["definitely_does_not_exist_xyz_12345"])
+    assert mgr.wait(timeout=10.0)
+    # Drain the pool so every log line is written before stderr is read.
+    pool.shutdown(wait=True)
+    captured = capsys.readouterr()
+    # Should contain a FAILED marker
+    assert "FAILED" in captured.err
+    assert "definitely_does_not_exist_xyz_12345" in captured.err
+
+
+def test_warmup_log_line_includes_thread_id(capsys: pytest.CaptureFixture) -> None:
+    """The log line includes the thread_id (matching the canary record)."""
+    pool = make_io_pool()
+    mgr = WarmupManager(pool, log_to_stderr=True)
+    mgr.submit(["os"])
+    assert mgr.wait(timeout=10.0)
+    # Drain the pool so every log line is written before stderr is read.
+    pool.shutdown(wait=True)
+    canaries = mgr.canaries()
+    captured = capsys.readouterr()
+    # The thread_id from the canary should appear in the log
+    thread_id = str(canaries[0]["thread_id"])
+    assert thread_id in captured.err, f"expected thread_id {thread_id} in stderr: {captured.err!r}"