fix(rag): detect ChromaDB dim mismatch and recreate collection on provider switch
This commit is contained in:
@@ -122,12 +122,49 @@ class RAGEngine:
|
|||||||
chromadb, Settings = chroma_module
|
chromadb, Settings = chroma_module
|
||||||
self.client = chromadb.PersistentClient(path=db_path)
|
self.client = chromadb.PersistentClient(path=db_path)
|
||||||
self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
|
self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
|
||||||
|
self._validate_collection_dim()
|
||||||
elif vs_config.provider == 'mock':
|
elif vs_config.provider == 'mock':
|
||||||
self.client = "mock"
|
self.client = "mock"
|
||||||
self.collection = "mock"
|
self.collection = "mock"
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown vector store provider: {vs_config.provider}")
|
raise ValueError(f"Unknown vector store provider: {vs_config.provider}")
|
||||||
|
|
||||||
|
def _validate_collection_dim(self) -> None:
    """
    Recreate the collection when its stored vectors no longer match the
    dimension produced by the current embedding provider.

    One stored embedding is read from ``self.collection`` and its length is
    compared with the length of a probe vector returned by
    ``self.embedding_provider.embed``. On a mismatch (e.g. the user switched
    from Gemini 3072-dim to local 384-dim, or vice versa) the collection is
    deleted and recreated empty through ``self.client`` so the next index
    pass populates it with the correct dim. This prevents the later search
    error "Collection expecting embedding with dimension of X, got Y".

    Nothing happens when there is no collection, the mock store is in use,
    no embedding provider is configured, or no stored embedding is returned.
    Validation is best-effort: any exception is reported on stderr and is
    not re-raised.

    [C: tests/test_rag_engine.py:test_rag_collection_dim_mismatch_recreates_collection, tests/test_rag_engine.py:test_rag_collection_dim_match_preserves_collection]
    """
    if self.collection is None or self.collection == "mock" or self.embedding_provider is None:
        return
    try:
        res = self.collection.get(limit=1, include=["embeddings"])
        embeddings = res.get("embeddings") if isinstance(res, dict) else None
        # Explicit None / length checks instead of `not embeddings`: the
        # store may return embeddings as a numpy array, whose truth value is
        # ambiguous and raises ValueError, which would silently abort the
        # whole validation in the handler below.
        if embeddings is None or len(embeddings) == 0:
            return
        existing_dim = len(embeddings[0])
        # Embed a throwaway string purely to learn the provider's output dim.
        probe = self.embedding_provider.embed(["__rag_dim_check__"])
        expected_dim = len(probe[0])
        if existing_dim == expected_dim:
            return
        # Capture the name once so the recreate step does not depend on the
        # handle of a collection that has just been deleted.
        collection_name = self.collection.name
        sys.stderr.write(
            f"RAG: Collection '{collection_name}' dim mismatch "
            f"(existing={existing_dim}, expected={expected_dim}). "
            f"Recreating collection to prevent silent corruption.\n"
        )
        sys.stderr.flush()
        self.client.delete_collection(collection_name)
        self.collection = self.client.get_or_create_collection(name=collection_name)
    except Exception as e:
        # Deliberate best-effort: a failed check must not block engine startup.
        sys.stderr.write(f"RAG: Failed to validate collection dim: {e}\n")
        sys.stderr.flush()
|
||||||
|
|
||||||
def is_empty(self) -> bool:
|
def is_empty(self) -> bool:
|
||||||
if not self.config.enabled:
|
if not self.config.enabled:
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -52,7 +52,88 @@ def test_rag_engine_chroma(mock_get_chroma, mock_embed):
|
|||||||
results = engine.search("hello", top_k=1)
|
results = engine.search("hello", top_k=1)
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0]["id"] == "doc1"
|
assert results[0]["id"] == "doc1"
|
||||||
assert results[0]["document"] == "hello world"
|
engine.delete_documents(["doc1"])
|
||||||
|
mock_collection.delete.assert_called_once_with(ids=["doc1"])
|
||||||
|
|
||||||
|
@patch('src.rag_engine.LocalEmbeddingProvider.embed')
@patch('src.rag_engine._get_chromadb')
def test_rag_collection_dim_mismatch_recreates_collection(mock_get_chroma, mock_embed):
    """
    Regression test for the live_gui_test_hardening_v2 followup
    (RAG dimension-mismatch flake in test_rag_phase4_stress).

    Setup: the persisted ChromaDB collection already holds vectors written
    by an earlier embedding provider (e.g. Gemini, 3072-dim) while the
    active config selects a different one (e.g. local SentenceTransformers,
    384-dim). Without the dim check, upsert silently corrupts the
    collection and search() later fails with
    "Collection expecting embedding with dimension of 3072, got 384".

    Expectation: constructing RAGEngine notices the mismatch, deletes the
    stale collection and recreates it empty so subsequent indexing uses the
    correct dim.
    """
    # The active (patched) provider emits 384-dim vectors.
    mock_embed.return_value = [[0.1] * 384]

    # The collection found on disk reports a 3072-dim stored vector.
    stale_collection = MagicMock()
    stale_collection.name = "test"
    stale_collection.get.return_value = {
        "embeddings": [[0.1] * 3072],
        "metadatas": [{}],
        "ids": ["stale_doc_1"],
    }

    chroma_client = MagicMock()
    chroma_client.get_or_create_collection.return_value = stale_collection

    chroma_module = MagicMock()
    chroma_module.PersistentClient.return_value = chroma_client
    chroma_settings = MagicMock()
    mock_get_chroma.return_value = (chroma_module, chroma_settings)

    config = models.RAGConfig(
        enabled=True,
        vector_store=models.VectorStoreConfig(provider='chroma', collection_name='test'),
        embedding_provider='local',
    )

    with patch('src.rag_engine._get_sentence_transformers') as mock_st:
        mock_st.return_value = MagicMock()
        engine = RAGEngine(config)
        assert engine.collection == stale_collection
        # One delete for the stale collection, two get_or_create calls:
        # the initial open plus the recreate.
        chroma_client.delete_collection.assert_called_once_with("test")
        assert chroma_client.get_or_create_collection.call_count == 2
|
||||||
|
|
||||||
|
@patch('src.rag_engine.LocalEmbeddingProvider.embed')
@patch('src.rag_engine._get_chromadb')
def test_rag_collection_dim_match_preserves_collection(mock_get_chroma, mock_embed):
    """
    Companion to the mismatch test: when the stored vectors already have the
    same dim as the current provider's output, the engine must NOT delete
    the collection (doing so would discard indexed data).
    """
    # Provider output and stored vector are both 384-dim.
    mock_embed.return_value = [[0.1] * 384]

    indexed_collection = MagicMock()
    indexed_collection.name = "test"
    indexed_collection.get.return_value = {
        "embeddings": [[0.1] * 384],
        "metadatas": [{"path": "file_25.txt"}],
        "ids": ["doc_25_0"],
    }

    chroma_client = MagicMock()
    chroma_client.get_or_create_collection.return_value = indexed_collection

    chroma_module = MagicMock()
    chroma_module.PersistentClient.return_value = chroma_client
    chroma_settings = MagicMock()
    mock_get_chroma.return_value = (chroma_module, chroma_settings)

    config = models.RAGConfig(
        enabled=True,
        vector_store=models.VectorStoreConfig(provider='chroma', collection_name='test'),
        embedding_provider='local',
    )

    with patch('src.rag_engine._get_sentence_transformers') as mock_st:
        mock_st.return_value = MagicMock()
        engine = RAGEngine(config)
        assert engine.collection == indexed_collection
        # No delete, and only the initial get_or_create call.
        chroma_client.delete_collection.assert_not_called()
        assert chroma_client.get_or_create_collection.call_count == 1
|
||||||
|
|
||||||
engine.delete_documents(["doc1"])
|
engine.delete_documents(["doc1"])
|
||||||
mock_collection.delete.assert_called_once_with(ids=["doc1"])
|
mock_collection.delete.assert_called_once_with(ids=["doc1"])
|
||||||
|
|||||||
Reference in New Issue
Block a user