fix(rag): detect ChromaDB dim mismatch and recreate collection on provider switch
This commit is contained in:
@@ -122,12 +122,49 @@ class RAGEngine:
|
||||
chromadb, Settings = chroma_module
|
||||
self.client = chromadb.PersistentClient(path=db_path)
|
||||
self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
|
||||
self._validate_collection_dim()
|
||||
elif vs_config.provider == 'mock':
|
||||
self.client = "mock"
|
||||
self.collection = "mock"
|
||||
else:
|
||||
raise ValueError(f"Unknown vector store provider: {vs_config.provider}")
|
||||
|
||||
def _validate_collection_dim(self) -> None:
    """
    Recreate the collection when its stored vectors no longer match the
    dimension produced by the current embedding provider.

    One stored embedding is sampled and its length is compared with the
    length of a freshly computed probe embedding. On mismatch (e.g. the user
    switched from Gemini 3072-dim to local 384-dim, or vice versa), the
    collection is deleted and recreated empty so the next index pass
    populates it with the correct dim. Prevents silent corruption that
    would later surface as a search error ("Collection expecting
    embedding with dimension of X, got Y") and hang live_gui tests.

    Nothing happens when there is no collection, the collection is the
    "mock" placeholder, no embedding provider is set, or the collection
    holds no embeddings yet. Any exception raised while validating is
    reported on stderr and otherwise ignored (best-effort check).

    [C: tests/test_rag_engine.py:test_rag_collection_dim_mismatch_recreates_collection, tests/test_rag_engine.py:test_rag_collection_dim_match_preserves_collection]
    """
    if self.collection is None or self.collection == "mock" or self.embedding_provider is None:
        return
    try:
        res = self.collection.get(limit=1, include=["embeddings"])
        embeddings = res.get("embeddings") if isinstance(res, dict) else None
        # Test for None / zero length explicitly. Truth-testing the container
        # (`not embeddings`) raises ValueError when it is a multi-element
        # numpy array, which would land in the handler below and skip the
        # repair. NOTE(review): recent ChromaDB releases appear to return
        # numpy arrays here -- confirm against the pinned version.
        if embeddings is None or len(embeddings) == 0:
            return
        existing_dim = len(embeddings[0])
        # Embedding a throwaway probe string is how the provider's current
        # output dimension is discovered.
        expected_dim = len(self.embedding_provider.embed(["__rag_dim_check__"])[0])
        if existing_dim == expected_dim:
            return
        # Read the name once, before the collection is deleted, and reuse it
        # for the log message, the delete and the recreate.
        name = self.collection.name
        sys.stderr.write(
            f"RAG: Collection '{name}' dim mismatch "
            f"(existing={existing_dim}, expected={expected_dim}). "
            "Recreating collection to prevent silent corruption.\n"
        )
        sys.stderr.flush()
        self.client.delete_collection(name)
        self.collection = self.client.get_or_create_collection(name=name)
    except Exception as e:
        # Deliberately broad: validation must never prevent the engine from
        # starting, so failures are only reported.
        sys.stderr.write(f"RAG: Failed to validate collection dim: {e}\n")
        sys.stderr.flush()
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
if not self.config.enabled:
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user