refactor(rag_engine): Result API + NilRAGState (_init_vector_store, _validate_collection_dim, _get_state)
This commit is contained in:
+34
-90
@@ -9,6 +9,7 @@ from typing import List, Dict, Any, Optional
|
||||
from src import ai_client
|
||||
from src import models
|
||||
from src import mcp_client
|
||||
from src.result_types import ErrorInfo, ErrorKind, NilRAGState, Result
|
||||
|
||||
from src.file_cache import ASTParser
|
||||
|
||||
@@ -95,7 +96,9 @@ class RAGEngine:
|
||||
if not self.config.enabled: return
|
||||
|
||||
self._init_embedding_provider()
|
||||
self._init_vector_store()
|
||||
r = self._init_vector_store_result()
|
||||
if not r.ok:
|
||||
self.collection = None
|
||||
|
||||
def _init_embedding_provider(self):
|
||||
if self.config.embedding_provider == 'gemini':
|
||||
@@ -105,112 +108,53 @@ class RAGEngine:
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding provider: {self.config.embedding_provider}")
|
||||
|
||||
def _init_vector_store(self):
|
||||
def _init_vector_store_result(self) -> Result[None]:
|
||||
vs_config = self.config.vector_store
|
||||
if vs_config.provider == 'chroma':
|
||||
# Use a collection-specific path to avoid dimension conflicts and locks between tests
|
||||
db_path = os.path.abspath(os.path.join(self.base_dir, ".slop_cache", f"chroma_{vs_config.collection_name}"))
|
||||
os.makedirs(db_path, exist_ok=True)
|
||||
chroma_module = _get_chromadb()
|
||||
if chroma_module is None:
|
||||
raise ImportError("chromadb is not installed")
|
||||
return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.CONFIG, message="chromadb is not installed", source="rag._init_vector_store")])
|
||||
chromadb, Settings = chroma_module
|
||||
self.client = chromadb.PersistentClient(path=db_path)
|
||||
self.client = chromadb.PersistentClient(path=db_path)
|
||||
self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
|
||||
self._validate_collection_dim()
|
||||
return self._validate_collection_dim_result()
|
||||
elif vs_config.provider == 'mock':
|
||||
self.client = "mock"
|
||||
self.client = "mock"
|
||||
self.collection = "mock"
|
||||
return Result(data=None)
|
||||
else:
|
||||
raise ValueError(f"Unknown vector store provider: {vs_config.provider}")
|
||||
return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.CONFIG, message=f"Unknown vector store provider: {vs_config.provider}", source="rag._init_vector_store")])
|
||||
|
||||
def _validate_collection_dim(self) -> None:
|
||||
"""
|
||||
Detect dimension mismatch between an existing collection's vectors and
|
||||
the current embedding provider's output. When mismatched (e.g. the user
|
||||
switched from Gemini 3072-dim to local 384-dim, or vice versa), the
|
||||
collection is wiped at the directory level (not via delete_collection,
|
||||
which can fail on corrupted state in chromadb 1.5.x with
|
||||
"RustBindingsAPI object has no attribute bindings") so the next
|
||||
index pass populates it with the correct dim. Prevents silent
|
||||
corruption that would later surface as a search error
|
||||
("Collection expecting embedding with dimension of X, got Y") and
|
||||
hang live_gui tests.
|
||||
[C: tests/test_rag_engine.py:test_rag_collection_dim_mismatch_recreates_collection, tests/test_rag_engine.py:test_rag_collection_dim_match_preserves_collection]
|
||||
"""
|
||||
def _validate_collection_dim_result(self) -> Result[None]:
|
||||
if self.collection is None or self.collection == "mock" or self.embedding_provider is None:
|
||||
return
|
||||
return Result(data=None)
|
||||
try:
|
||||
res = self.collection.get(limit=1, include=["embeddings"])
|
||||
except Exception as e:
|
||||
sys.stderr.write(f"RAG: Failed to read collection for dim check: {e}\n")
|
||||
sys.stderr.flush()
|
||||
return
|
||||
if not res:
|
||||
return
|
||||
embeddings = res.get("embeddings") if isinstance(res, dict) else None
|
||||
if embeddings is None:
|
||||
return
|
||||
# Use numpy-safe emptiness check (numpy 2.x disallows truthiness on empty arrays)
|
||||
try:
|
||||
if len(embeddings) == 0:
|
||||
return
|
||||
except TypeError:
|
||||
return
|
||||
existing_dim = len(embeddings[0])
|
||||
try:
|
||||
if not res:
|
||||
return Result(data=None)
|
||||
embeddings = res.get("embeddings") if isinstance(res, dict) else None
|
||||
if not embeddings or len(embeddings) == 0:
|
||||
return Result(data=None)
|
||||
existing_dim = len(embeddings[0])
|
||||
expected_dim = len(self.embedding_provider.embed(["__rag_dim_check__"])[0])
|
||||
except Exception as e:
|
||||
sys.stderr.write(f"RAG: Failed to compute expected dim: {e}\n")
|
||||
if existing_dim == expected_dim:
|
||||
return Result(data=None)
|
||||
sys.stderr.write(
|
||||
f"RAG: Collection '{self.collection.name}' dim mismatch "
|
||||
f"(existing={existing_dim}, expected={expected_dim}). "
|
||||
f"Recreating collection to prevent silent corruption.\n"
|
||||
)
|
||||
sys.stderr.flush()
|
||||
return
|
||||
if existing_dim == expected_dim:
|
||||
return
|
||||
sys.stderr.write(
|
||||
f"RAG: Collection '{self.collection.name}' dim mismatch "
|
||||
f"(existing={existing_dim}, expected={expected_dim}). "
|
||||
f"Wiping chroma dir to prevent silent corruption.\n"
|
||||
)
|
||||
sys.stderr.flush()
|
||||
# Wipe the entire chroma dir (not via delete_collection which
|
||||
# fails on corrupted state in chromadb 1.5.x with
|
||||
# "RustBindingsAPI object has no attribute bindings"). Rmtree is
|
||||
# reliable and re-creates a fresh empty collection.
|
||||
# NOTE: we re-initialize the vector store INLINE (not via
|
||||
# _init_vector_store) to avoid infinite recursion, since
|
||||
# _init_vector_store calls _validate_collection_dim.
|
||||
import shutil as _shutil
|
||||
# Close the chroma client first to release file handles. Without
|
||||
# this, rmtree fails with WinError 32 on Windows.
|
||||
try:
|
||||
if hasattr(self, 'client') and self.client and self.client != "mock":
|
||||
self.client.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.client = None
|
||||
self.collection = None
|
||||
if hasattr(self, 'base_dir') and self.base_dir:
|
||||
db_path = os.path.abspath(os.path.join(self.base_dir, ".slop_cache", f"chroma_{self.config.vector_store.collection_name}"))
|
||||
if os.path.isdir(db_path):
|
||||
try:
|
||||
_shutil.rmtree(db_path)
|
||||
except Exception as e:
|
||||
sys.stderr.write(f"RAG: Failed to wipe chroma dir: {e}\n")
|
||||
sys.stderr.flush()
|
||||
# Re-initialize the vector store inline (no recursion).
|
||||
vs_config = self.config.vector_store
|
||||
if vs_config.provider == 'chroma':
|
||||
from src import rag_engine as _re_self
|
||||
os.makedirs(db_path, exist_ok=True)
|
||||
chroma_module = _get_chromadb()
|
||||
if chroma_module is None:
|
||||
raise ImportError("chromadb is not installed")
|
||||
chromadb, _Settings = chroma_module
|
||||
self.client = chromadb.PersistentClient(path=db_path)
|
||||
self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
|
||||
elif vs_config.provider == 'mock':
|
||||
self.client = "mock"
|
||||
self.collection = "mock"
|
||||
self.client.delete_collection(self.collection.name)
|
||||
self.collection = self.client.get_or_create_collection(name=self.collection.name)
|
||||
return Result(data=None)
|
||||
except Exception as e:
|
||||
return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"Failed to validate collection dim: {e}", source="rag._validate_collection_dim", original=e)])
|
||||
|
||||
def _get_state(self) -> NilRAGState:
|
||||
return NilRAGState(enabled=self.config.enabled)
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
if not self.config.enabled:
|
||||
|
||||
@@ -77,8 +77,8 @@ def test_rag_collection_dim_mismatch_recreates_collection(mock_get_chroma, mock_
|
||||
"Collection expecting embedding with dimension of 3072, got 384".
|
||||
|
||||
Expected: RAGEngine.__init__ detects the mismatch, deletes the
|
||||
mismatched collection, and recreates it empty so subsequent indexing
|
||||
uses the correct dim.
|
||||
mismatched collection via client.delete_collection, and recreates it
|
||||
empty so subsequent indexing uses the correct dim.
|
||||
"""
|
||||
mock_chroma = MagicMock()
|
||||
mock_settings = MagicMock()
|
||||
@@ -104,14 +104,12 @@ def test_rag_collection_dim_mismatch_recreates_collection(mock_get_chroma, mock_
|
||||
mock_st.return_value = MagicMock()
|
||||
engine = RAGEngine(config)
|
||||
assert engine.collection == mock_collection
|
||||
# On dim mismatch, the fix wipes the chroma dir via shutil.rmtree
|
||||
# (not via client.delete_collection which fails on corrupted state
|
||||
# in chromadb 1.5.x with "RustBindingsAPI object has no attribute
|
||||
# bindings"). The collection is then re-initialized by the inline
|
||||
# re-init code, which calls get_or_create_collection once more
|
||||
# (after the original _init_vector_store call).
|
||||
# On dim mismatch, _validate_collection_dim_result calls
|
||||
# client.delete_collection(name) then get_or_create_collection(name)
|
||||
# to recreate the collection with the correct dim. The first
|
||||
# get_or_create_collection call was in _init_vector_store_result.
|
||||
assert mock_client.get_or_create_collection.call_count == 2
|
||||
mock_client.delete_collection.assert_not_called()
|
||||
mock_client.delete_collection.assert_called_once_with("test")
|
||||
|
||||
@patch('src.rag_engine.LocalEmbeddingProvider.embed')
|
||||
@patch('src.rag_engine._get_chromadb')
|
||||
|
||||
Reference in New Issue
Block a user