diff --git a/src/warmup.py b/src/warmup.py index ea49819d..08922de4 100644 --- a/src/warmup.py +++ b/src/warmup.py @@ -198,9 +198,13 @@ class WarmupManager: break done = self._started and not self._pending if done: - self._done_event.set() callbacks = list(self._callbacks) all_done = True + # NOTE: do NOT set _done_event here. We set it AFTER callbacks + # fire (below) so that `wait()` does not return before user + # on_complete callbacks have run. This closes the race where + # a test thread calling `mgr.wait()` could observe `is_done()` + # and proceed before the on_complete side effects were visible. if canary_snapshot is not None: self._log_canary(canary_snapshot) if all_done: @@ -210,6 +214,8 @@ class WarmupManager: cb(self._snapshot()) except Exception: pass + if all_done: + self._done_event.set() def _record_failure(self, name: str, _err: BaseException, end_ts: Optional[float] = None) -> None: if end_ts is None: end_ts = time.time() @@ -231,7 +237,6 @@ class WarmupManager: break done = self._started and not self._pending if done: - self._done_event.set() callbacks = list(self._callbacks) all_done = True if canary_snapshot is not None: @@ -243,6 +248,8 @@ class WarmupManager: cb(self._snapshot()) except Exception: pass + if all_done: + self._done_event.set() def _log_canary(self, canary: dict) -> None: if not self._log_to_stderr: return diff --git a/tests/test_warmup.py b/tests/test_warmup.py index 39bf993f..1c9d6334 100644 --- a/tests/test_warmup.py +++ b/tests/test_warmup.py @@ -53,12 +53,10 @@ def test_warmup_status_reflects_failures() -> None: pool.shutdown(wait=False) -@pytest.mark.skip(reason="Pre-existing flaky test: warmup of stdlib modules 'os' and 'sys' completes synchronously on a fast machine before the test can assert is_done()==False. Test assumes async behavior that doesn't hold. Tracked as pre-existing in state.toml.") def test_warmup_done_event_set_after_all_complete() -> None: pool = _make_pool() mgr = WarmupManager(pool) mgr.submit(["os", "sys"]) - assert not mgr.is_done() mgr.wait(timeout=5) assert mgr.is_done() pool.shutdown(wait=False) @@ -73,7 +71,6 @@ def test_warmup_wait_blocks_until_done() -> None: pool.shutdown(wait=False) -@pytest.mark.skip(reason="Pre-existing flaky test: mgr.wait() returns when _done_event is set (under the lock in _record_success), but the on_complete callbacks fire AFTER the lock is released, in the worker thread. The test's main thread can be unblocked from wait() before the callback appends to 'received'. Race condition. Tracked as pre-existing.") def test_warmup_on_complete_callback_fires() -> None: pool = _make_pool() mgr = WarmupManager(pool)