"""WarmupManager: import heavy modules on a background thread pool.

Per spec.md:2.2 Layer 3, the AppController's ``__init__`` submits a warmup
job to the shared ``_io_pool`` for each heavy module (provider SDKs,
feature-gated GUI modules, etc.). After warmup completes, those modules are
in ``sys.modules`` and any function that calls ``_require_warmed(name)`` gets
an instant lookup instead of a multi-hundred-ms import.

Public API on the manager (and exposed on AppController via delegation):

    mgr.submit(modules)       - start warmup jobs (call once at init)
    mgr.status()              - {pending, completed, failed}
    mgr.is_done()             - bool
    mgr.wait(timeout)         - block until done
    mgr.on_complete(callback) - register completion callback
    mgr.canaries()            - list[dict] of per-module canary records
                                (observability)
    mgr.reset()               - clear state (for re-warmup, e.g. in tests)

Canary records (one per submitted module) carry:

    canary_id:   monotonic numeric ID (continues across resets)
    module:      module name
    thread_name: name of the worker thread that did the import
                 (e.g. "controller-io-0")
    thread_id:   ident of that worker thread
    submit_ts:   wall-clock when submit() was called
    start_ts:    wall-clock when the worker started the import
    end_ts:      wall-clock when the import finished
    elapsed_ms:  (end_ts - start_ts) * 1000
    status:      "running" | "completed" | "failed" | "cancelled"
    error:       error message string if status == "failed", else None
"""
import importlib
import sys
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Optional

CompletionCallback = Callable[[dict], None]


class WarmupManager:
    """Import heavy modules on a thread pool and track per-module progress.

    Shared state is guarded by ``self._lock``. Stderr output is serialized by
    a separate ``self._log_lock`` so lines from concurrent workers do not
    interleave. Logging and user callbacks always run outside ``self._lock``.
    """

    def __init__(self, pool: ThreadPoolExecutor, log_to_stderr: bool = True) -> None:
        """Create an idle manager.

        Args:
            pool: Executor that receives one import job per module.
            log_to_stderr: When True (the default, so production runs get
                observability for free), print a one-line summary of each
                completed/failed canary to stderr, plus a final aggregate line
                when the entire warmup finishes. Tests can opt out.
        """
        self._pool = pool
        self._lock = threading.Lock()
        self._log_lock = threading.Lock()
        self._done_event = threading.Event()
        self._pending: list[str] = []
        self._completed: list[str] = []
        self._failed: list[str] = []
        self._callbacks: list[CompletionCallback] = []
        self._started = False
        # Incremented by reset(). Every job carries the generation it was
        # submitted under, so a job from before a reset() that finishes late
        # cannot mutate the pending/completed/failed state of the current run.
        self._generation: int = 0
        # Futures of the current run; reset() cancels the ones still queued.
        self._futures: list[Future] = []
        # Canary observability state (per-module import tracking). Preserved
        # across resets as full history.
        self._canaries: list[dict] = []
        self._next_canary_id: int = 1
        self._log_to_stderr: bool = log_to_stderr
        # Ident of the constructing thread (assumed to be the main thread) so
        # any canary that ran on it can be flagged as a main-thread-purity
        # violation.
        self._main_thread_ident: int = threading.get_ident()

    def submit(self, modules: list[str]) -> None:
        """Start one import job per module on the pool.

        An empty ``modules`` completes immediately: the done event is set and
        any already-registered callbacks fire.

        Raises:
            RuntimeError: if called again without an intervening ``reset()``.
        """
        submit_ts = time.time()
        # Materialize once so a generator argument is not exhausted before
        # the jobs are submitted.
        names = list(modules)
        jobs: list[tuple[str, int]] = []
        fire_now: list[CompletionCallback] = []
        snap: dict[str, list[str]] = {}
        with self._lock:
            if self._started:
                raise RuntimeError("WarmupManager.submit() called twice; call reset() first")
            self._pending = list(names)
            self._completed = []
            self._failed = []
            self._done_event.clear()
            self._started = True
            generation = self._generation
            for name in names:
                canary_id = self._next_canary_id
                self._next_canary_id += 1
                self._canaries.append(
                    {
                        "canary_id": canary_id,
                        "module": name,
                        "thread_name": None,
                        "thread_id": None,
                        "submit_ts": submit_ts,
                        "start_ts": None,
                        "end_ts": None,
                        "elapsed_ms": None,
                        "status": "running",
                        "error": None,
                    }
                )
                jobs.append((name, canary_id))
            if not jobs:
                # Nothing will ever call _record_result, so complete here;
                # otherwise is_done() stays False and wait() blocks forever.
                self._done_event.set()
                fire_now = self._callbacks
                self._callbacks = []
                snap = self._snapshot()

        futures = [
            self._pool.submit(self._warmup_one, name, canary_id, generation)
            for name, canary_id in jobs
        ]
        with self._lock:
            stale = self._generation != generation
            if not stale:
                self._futures.extend(futures)
        if stale:
            # A reset() raced with this submit; drop whatever is still queued.
            for future in futures:
                future.cancel()

        self._fire_callbacks(fire_now, snap)

    def status(self) -> dict[str, list[str]]:
        """Return copies of the pending/completed/failed module lists."""
        with self._lock:
            return self._snapshot()

    def canaries(self) -> list[dict]:
        """Return a defensive copy of the canary records (per-module import tracking)."""
        with self._lock:
            return [dict(c) for c in self._canaries]

    def is_done(self) -> bool:
        """True once every module of the current submit() has finished."""
        return self._done_event.is_set()

    def wait(self, timeout: Optional[float] = None) -> bool:
        """Block until done (or ``timeout`` seconds); return ``is_done()``."""
        return self._done_event.wait(timeout=timeout)

    def on_complete(self, callback: CompletionCallback) -> None:
        """Register ``callback(status_snapshot)``; fire it now if already done.

        Each callback fires at most once per completed run. Exceptions raised
        by callbacks are swallowed (best-effort notification).
        """
        with self._lock:
            if not self._done_event.is_set():
                self._callbacks.append(callback)
                return
            snap = self._snapshot()
        self._fire_callbacks([callback], snap)

    def reset(self) -> None:
        """Clear run state so submit() may be called again.

        Canary history is kept. Still-running canaries are marked
        ``"cancelled"``. Queued jobs of the old run are cancelled; jobs that
        are already executing finish their import but no longer affect state.
        """
        with self._lock:
            self._generation += 1
            futures = self._futures
            self._futures = []
            self._pending = []
            self._completed = []
            self._failed = []
            self._done_event.clear()
            self._callbacks = []
            self._started = False
            now = time.time()
            for c in self._canaries:
                if c["status"] == "running":
                    c["status"] = "cancelled"
                    c["end_ts"] = c.get("end_ts") or now
                    if c.get("start_ts") and c["elapsed_ms"] is None:
                        c["elapsed_ms"] = (c["end_ts"] - c["start_ts"]) * 1000
        for future in futures:
            # Returns False (no-op) for jobs already running or finished.
            future.cancel()

    def _warmup_one(
        self,
        name: str,
        canary_id: Optional[int] = None,
        generation: Optional[int] = None,
    ) -> None:
        """Worker body: import ``name`` and record the outcome.

        ``canary_id``/``generation`` identify the exact canary and run this
        job belongs to. When omitted, the canary is matched by module name
        within the current run.
        """
        start_ts = time.time()
        thread = threading.current_thread()
        with self._lock:
            c = self._find_running_canary(name, canary_id, unstarted_only=True)
            if c is not None:
                c["thread_name"] = thread.name
                c["thread_id"] = thread.ident
                c["start_ts"] = start_ts
        try:
            importlib.import_module(name)
        except BaseException as e:
            # BaseException on purpose: e.g. a SystemExit raised by a module's
            # top-level code must still resolve the job, otherwise the pending
            # list never drains and wait() hangs.
            self._record_failure(name, e, time.time(), canary_id=canary_id, generation=generation)
        else:
            self._record_success(name, time.time(), canary_id=canary_id, generation=generation)

    def _record_success(
        self,
        name: str,
        end_ts: Optional[float] = None,
        *,
        canary_id: Optional[int] = None,
        generation: Optional[int] = None,
    ) -> None:
        """Record a successful import of ``name``."""
        self._record_result(name, "completed", end_ts, None, canary_id, generation)

    def _record_failure(
        self,
        name: str,
        _err: BaseException,
        end_ts: Optional[float] = None,
        *,
        canary_id: Optional[int] = None,
        generation: Optional[int] = None,
    ) -> None:
        """Record a failed import of ``name`` caused by ``_err``."""
        error = f"{type(_err).__name__}: {_err}"
        self._record_result(name, "failed", end_ts, error, canary_id, generation)

    def _record_result(
        self,
        name: str,
        status: str,
        end_ts: Optional[float],
        error: Optional[str],
        canary_id: Optional[int],
        generation: Optional[int],
    ) -> None:
        """Shared bookkeeping for a finished job; fires completion exactly once."""
        if end_ts is None:
            end_ts = time.time()
        callbacks: list[CompletionCallback] = []
        snap: dict[str, list[str]] = {}
        canary_snapshot: Optional[dict] = None
        all_done = False
        with self._lock:
            if generation is not None and generation != self._generation:
                # Job from before a reset(): its canary is already "cancelled"
                # and it must not touch the current run.
                return
            if name in self._pending:
                self._pending.remove(name)
                target = self._completed if status == "completed" else self._failed
                target.append(name)
            c = self._find_running_canary(name, canary_id)
            if c is not None:
                c["status"] = status
                c["end_ts"] = end_ts
                if error is not None:
                    c["error"] = error
                if c["start_ts"] is not None:
                    c["elapsed_ms"] = (end_ts - c["start_ts"]) * 1000
                canary_snapshot = dict(c)
            # Only the transition to "done" fires; later records (duplicates,
            # late arrivals) must not re-fire callbacks or re-log the summary.
            if self._started and not self._pending and not self._done_event.is_set():
                self._done_event.set()
                callbacks = self._callbacks
                self._callbacks = []
                snap = self._snapshot()
                all_done = True
        if canary_snapshot is not None:
            self._log_canary(canary_snapshot)
        if all_done:
            self._log_summary()
        self._fire_callbacks(callbacks, snap)

    def _find_running_canary(
        self,
        name: str,
        canary_id: Optional[int],
        unstarted_only: bool = False,
    ) -> Optional[dict]:
        """Return the running canary for this job, or None. Caller holds ``_lock``.

        Matches by ``canary_id`` when given, else by module name.
        """
        for c in self._canaries:
            if c["status"] != "running":
                continue
            if unstarted_only and c["start_ts"] is not None:
                continue
            if canary_id is not None:
                if c["canary_id"] == canary_id:
                    return c
            elif c["module"] == name:
                return c
        return None

    @staticmethod
    def _fire_callbacks(callbacks: list[CompletionCallback], snap: dict[str, list[str]]) -> None:
        """Invoke each callback with its own copy of ``snap``, swallowing errors."""
        for cb in callbacks:
            try:
                cb({key: list(value) for key, value in snap.items()})
            except Exception:
                # Best-effort: a misbehaving callback must not break the
                # worker thread or prevent the remaining callbacks.
                pass

    def _log_canary(self, canary: dict) -> None:
        """Write a one-line summary of a finished canary to stderr (if enabled)."""
        if not self._log_to_stderr:
            return
        cid = canary["canary_id"]
        module = canary["module"]
        thread_name = canary.get("thread_name") or "?"
        thread_id = canary.get("thread_id")
        elapsed = canary.get("elapsed_ms")
        status = canary["status"]
        is_main = thread_id is not None and thread_id == self._main_thread_ident
        main_tag = " [MAIN-THREAD]" if is_main else ""
        elapsed_str = f"{elapsed:.1f}ms" if elapsed is not None else "?ms"
        if status == "completed":
            line = f"[warmup {cid}] {module} on {thread_name} (id={thread_id}): {elapsed_str}{main_tag}\n"
        elif status == "failed":
            err = canary.get("error") or "?"
            line = f"[warmup {cid}] FAILED {module} on {thread_name} (id={thread_id}): {err}{main_tag}\n"
        else:
            line = f"[warmup {cid}] {status.upper()} {module} on {thread_name} (id={thread_id}){main_tag}\n"
        with self._log_lock:
            try:
                sys.stderr.write(line)
                sys.stderr.flush()
            except Exception:
                # Logging is best-effort (stderr may be closed at shutdown).
                pass

    def _log_summary(self) -> None:
        """Write the aggregate line over all canary history to stderr (if enabled)."""
        if not self._log_to_stderr:
            return
        with self._lock:
            canaries = [dict(c) for c in self._canaries]
        if not canaries:
            return
        total = len(canaries)
        completed = sum(1 for c in canaries if c["status"] == "completed")
        failed = sum(1 for c in canaries if c["status"] == "failed")
        cancelled = sum(1 for c in canaries if c["status"] == "cancelled")
        main_thread_violations = [
            c["module"] for c in canaries if c.get("thread_id") == self._main_thread_ident
        ]
        total_ms = sum(c["elapsed_ms"] for c in canaries if c.get("elapsed_ms"))
        parts = [f"{completed} completed"]
        if failed:
            parts.append(f"{failed} failed")
        if cancelled:
            parts.append(f"{cancelled} cancelled")
        with self._log_lock:
            try:
                sys.stderr.write(f"[warmup done] {total} modules: {', '.join(parts)} (sum of per-module elapsed: {total_ms:.1f}ms)\n")
                if main_thread_violations:
                    sys.stderr.write(f"[warmup WARNING] {len(main_thread_violations)} module(s) loaded on the MAIN THREAD (violates main thread purity invariant): {', '.join(main_thread_violations)}\n")
                sys.stderr.flush()
            except Exception:
                # Logging is best-effort (stderr may be closed at shutdown).
                pass

    def _snapshot(self) -> dict[str, list[str]]:
        """Copy the status lists. Caller must hold ``_lock``."""
        return {
            "pending": list(self._pending),
            "completed": list(self._completed),
            "failed": list(self._failed),
        }