fix(rag): coalesce _sync_rag_engine calls via token + dirty flag
This commit is contained in:
+23
-13
@@ -800,6 +800,9 @@ class AppController:
|
|||||||
self._pending_dialog_lock: threading.Lock = threading.Lock()
|
self._pending_dialog_lock: threading.Lock = threading.Lock()
|
||||||
self._api_event_queue_lock: threading.Lock = threading.Lock()
|
self._api_event_queue_lock: threading.Lock = threading.Lock()
|
||||||
self._rag_engine_lock: threading.Lock = threading.Lock()
|
self._rag_engine_lock: threading.Lock = threading.Lock()
|
||||||
|
self._rag_sync_token: int = 0
|
||||||
|
self._rag_sync_dirty: bool = False
|
||||||
|
self._rag_sync_lock: threading.Lock = threading.Lock()
|
||||||
self._project_switch_lock: threading.Lock = threading.Lock()
|
self._project_switch_lock: threading.Lock = threading.Lock()
|
||||||
self._project_switch_in_progress: bool = False
|
self._project_switch_in_progress: bool = False
|
||||||
self._project_switch_pending_path: Optional[str] = None
|
self._project_switch_pending_path: Optional[str] = None
|
||||||
@@ -1457,26 +1460,30 @@ class AppController:
|
|||||||
def rag_enabled(self) -> bool:
|
def rag_enabled(self) -> bool:
|
||||||
return self.rag_config.enabled if self.rag_config else False
|
return self.rag_config.enabled if self.rag_config else False
|
||||||
|
|
||||||
def _sync_rag_engine(self):
|
def _sync_rag_engine(self) -> None:
|
||||||
"""
|
"""Coalesces multiple rapid setter calls into a single engine rebuild. The token + dirty flag pattern ensures that N setters in quick succession produce ONE sync, not N parallel syncs."""
|
||||||
Re-initializes the RAG engine in a background thread to avoid blocking the UI.
|
with self._rag_sync_lock:
|
||||||
"""
|
self._rag_sync_token += 1
|
||||||
|
self._rag_sync_dirty = True
|
||||||
|
token = self._rag_sync_token
|
||||||
|
self.submit_io(lambda: self._do_rag_sync(token))
|
||||||
|
|
||||||
|
def _do_rag_sync(self, token: int) -> None:
|
||||||
|
"""Worker for the coalesced RAG sync. Loops if new setters arrive mid-sync, returns early if a newer sync supersedes us."""
|
||||||
|
while True:
|
||||||
|
with self._rag_sync_lock:
|
||||||
|
if token != self._rag_sync_token:
|
||||||
|
return
|
||||||
|
self._rag_sync_dirty = False
|
||||||
self._set_rag_status("initializing...")
|
self._set_rag_status("initializing...")
|
||||||
def _task():
|
|
||||||
try:
|
try:
|
||||||
from src import rag_engine
|
from src import rag_engine
|
||||||
engine = rag_engine.RAGEngine(self.rag_config, self.active_project_root)
|
engine = rag_engine.RAGEngine(self.rag_config, self.active_project_root)
|
||||||
# If the engine's embedding provider failed to initialize
|
|
||||||
# (e.g. local embedding but sentence-transformers not installed),
|
|
||||||
# the engine is in a broken state even though __init__ returned.
|
|
||||||
# Surface this as an error instead of reporting 'ready' (which
|
|
||||||
# would let the user trigger RAG queries that silently fail).
|
|
||||||
if engine.embedding_provider is None:
|
if engine.embedding_provider is None:
|
||||||
self._set_rag_status("error: RAG embedding provider failed to initialize (e.g. missing dependencies)")
|
self._set_rag_status("error: RAG embedding provider failed to initialize (e.g. missing dependencies)")
|
||||||
return
|
return
|
||||||
with self._rag_engine_lock:
|
with self._rag_engine_lock:
|
||||||
self.rag_engine = engine
|
self.rag_engine = engine
|
||||||
# If the engine is empty and we have files, trigger indexing
|
|
||||||
if self.rag_engine and self.rag_engine.is_empty() and self.files:
|
if self.rag_engine and self.rag_engine.is_empty() and self.files:
|
||||||
self._rebuild_rag_index()
|
self._rebuild_rag_index()
|
||||||
else:
|
else:
|
||||||
@@ -1485,8 +1492,11 @@ class AppController:
|
|||||||
self._set_rag_status(f"error: {e}")
|
self._set_rag_status(f"error: {e}")
|
||||||
sys.stderr.write(f"[DEBUG RAG] Failed to sync engine: {e}\n")
|
sys.stderr.write(f"[DEBUG RAG] Failed to sync engine: {e}\n")
|
||||||
sys.stderr.flush()
|
sys.stderr.flush()
|
||||||
|
with self._rag_sync_lock:
|
||||||
self.submit_io(_task)
|
if not self._rag_sync_dirty:
|
||||||
|
return
|
||||||
|
token = self._rag_sync_token
|
||||||
|
self._rag_sync_dirty = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rag_enabled(self) -> bool:
|
def rag_enabled(self) -> bool:
|
||||||
|
|||||||
@@ -0,0 +1,87 @@
|
|||||||
|
"""Tests for _sync_rag_engine coalescing (Phase 4, FR3)."""
|
||||||
|
import threading
import time
from typing import Generator
from unittest.mock import patch, MagicMock

import pytest

from src.app_controller import AppController
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def isolated_workspace(tmp_path, monkeypatch):
    """Switch the working directory to pytest's per-test temp dir and return it.

    ``monkeypatch.chdir`` is used for the directory change; the returned path
    is the same ``tmp_path`` object pytest handed in.
    """
    workspace = tmp_path
    monkeypatch.chdir(workspace)
    return workspace
|
||||||
|
|
||||||
|
|
||||||
|
def _make_minimal_controller(isolated_workspace) -> Generator[AppController, None, None]:
    """Yield a minimal AppController and shut it down when the generator ends.

    This is a plain generator, not a pytest fixture: callers obtain the
    controller with ``next(gen)`` and tear it down with ``gen.close()``.
    While the controller is constructed, ``load_config`` returns a small stub
    configuration and ``save_config``, ``_load_active_project``,
    ``_fetch_models``, ``_prune_old_logs``, ``start_services`` and
    ``_init_ai_and_hooks`` are replaced by mocks.

    Args:
        isolated_workspace: The per-test working directory. It is not read
            here; callers pass it to make the dependency on the isolated
            directory explicit.

    Yields:
        The constructed ``AppController``.
    """
    stub_config = {
        "ai": {"provider": "gemini", "model": "gemini-2.5-flash-lite"},
        "projects": {"paths": [], "active": ""},
        "gui": {"show_windows": {}},
        "rag": {"enabled": False, "collection_name": "test", "embedding_provider": "gemini"},
    }
    # NOTE(review): the patches are kept active for the controller's whole
    # lifetime (until the generator is closed) -- confirm this is the
    # intended scope rather than construction only.
    with patch("src.app_controller.AppController.load_config", return_value=stub_config), \
            patch("src.app_controller.AppController.save_config"), \
            patch("src.app_controller.AppController._load_active_project"), \
            patch("src.app_controller.AppController._fetch_models"), \
            patch("src.app_controller.AppController._prune_old_logs"), \
            patch("src.app_controller.AppController.start_services"), \
            patch("src.app_controller.AppController._init_ai_and_hooks"):
        ctrl = AppController()
        try:
            yield ctrl
        finally:
            # gen.close() raises GeneratorExit at the yield above, so the
            # teardown has to live in a finally block or it is skipped.
            try:
                ctrl.shutdown()
            except Exception:
                # Best-effort teardown: a failing shutdown must not mask the
                # outcome of the test that used the controller.
                pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_sync_state_initialized(isolated_workspace) -> None:
    """The controller has _rag_sync_token, _rag_sync_dirty, _rag_sync_lock."""
    # The attributes are assigned in __init__, so the controller must be
    # really constructed; AppController.__new__ alone would skip __init__.
    gen = _make_minimal_controller(isolated_workspace)
    ctrl = next(gen)
    try:
        for attr in ("_rag_sync_token", "_rag_sync_dirty", "_rag_sync_lock"):
            assert hasattr(ctrl, attr), f"AppController.__init__ did not set {attr}"
    finally:
        try:
            gen.close()
        except Exception:
            # Teardown problems must not hide the assertion result.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_sync_token_starts_at_zero(isolated_workspace) -> None:
    """A newly built controller has token 0 and a cleared dirty flag."""
    controller_gen = _make_minimal_controller(isolated_workspace)
    controller = next(controller_gen)
    try:
        token = controller._rag_sync_token
        assert token == 0
        dirty = controller._rag_sync_dirty
        assert dirty is False
    finally:
        # Closing the generator ends the helper; errors raised while closing
        # are ignored so they cannot replace the assertion outcome.
        try:
            controller_gen.close()
        except Exception:
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_sync_increments_token(isolated_workspace) -> None:
    """Each call to _sync_rag_engine increments the token by exactly one."""
    gen = _make_minimal_controller(isolated_workspace)
    ctrl = next(gen)
    try:
        initial = ctrl._rag_sync_token
        # The token is bumped synchronously before the work is handed to
        # submit_io, so the submission can be stubbed out: no background sync
        # runs and no sleep is needed to observe the new value.
        with patch.object(ctrl, "submit_io"):
            ctrl._sync_rag_engine()
            assert ctrl._rag_sync_token == initial + 1
            ctrl._sync_rag_engine()
            assert ctrl._rag_sync_token == initial + 2
    finally:
        try:
            gen.close()
        except Exception:
            # Teardown problems must not hide the assertion result.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_sync_submits_to_io_pool(isolated_workspace) -> None:
    """Calling _sync_rag_engine hands exactly one callable to submit_io."""
    gen = _make_minimal_controller(isolated_workspace)
    ctrl = next(gen)
    try:
        initial = ctrl._rag_sync_token
        # Replace submit_io with a mock so the submission itself can be
        # inspected and the real sync worker never runs.
        with patch.object(ctrl, "submit_io") as submit_io:
            ctrl._sync_rag_engine()
        submit_io.assert_called_once()
        submitted = submit_io.call_args[0][0]
        assert callable(submitted), "submit_io should receive the sync worker callable"
        assert ctrl._rag_sync_token > initial, "Token should increment after _sync_rag_engine call"
        # The worker was never executed, so the dirty flag set by the call
        # must still be raised.
        assert ctrl._rag_sync_dirty is True
    finally:
        try:
            gen.close()
        except Exception:
            # Teardown problems must not hide the assertion result.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_sync_lock_is_a_lock(isolated_workspace) -> None:
    """_rag_sync_lock has the same type as an object built by threading.Lock()."""
    controller_gen = _make_minimal_controller(isolated_workspace)
    controller = next(controller_gen)
    try:
        candidate = controller._rag_sync_lock
        # The concrete lock type is taken from a fresh instance instead of
        # using threading.Lock itself as the isinstance target.
        lock_type = type(threading.Lock())
        assert isinstance(candidate, lock_type)
    finally:
        # Errors raised while closing the helper generator are ignored so
        # they cannot replace the assertion outcome.
        try:
            controller_gen.close()
        except Exception:
            pass
|
||||||
Reference in New Issue
Block a user