feat(rag): Implement RAG engine, configuration schema, and vector store integration
This commit is contained in:
@@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
from src import models
|
||||
from src.rag_engine import RAGEngine, BaseEmbeddingProvider, LocalEmbeddingProvider, GeminiEmbeddingProvider
|
||||
|
||||
class MockEmbeddingProvider(BaseEmbeddingProvider):
|
||||
def embed(self, texts):
|
||||
return [[0.1] * 384 for _ in texts]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_config():
|
||||
vs_config = models.VectorStoreConfig(provider='mock', collection_name='test')
|
||||
return models.RAGConfig(enabled=True, vector_store=vs_config, embedding_provider='local')
|
||||
|
||||
def test_rag_engine_init_mock(mock_rag_config):
|
||||
engine = RAGEngine(mock_rag_config)
|
||||
assert engine.config.enabled is True
|
||||
assert engine.collection == "mock"
|
||||
|
||||
@patch('src.rag_engine.LocalEmbeddingProvider.embed')
|
||||
@patch('src.rag_engine.chromadb.PersistentClient')
|
||||
def test_rag_engine_chroma(mock_chroma, mock_embed):
|
||||
mock_embed.return_value = [[0.1, 0.2, 0.3]]
|
||||
|
||||
mock_collection = MagicMock()
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_or_create_collection.return_value = mock_collection
|
||||
mock_chroma.return_value = mock_client
|
||||
|
||||
vs_config = models.VectorStoreConfig(provider='chroma', collection_name='test')
|
||||
config = models.RAGConfig(enabled=True, vector_store=vs_config, embedding_provider='local')
|
||||
|
||||
with patch('src.rag_engine.SentenceTransformer') as mock_st:
|
||||
engine = RAGEngine(config)
|
||||
assert engine.collection == mock_collection
|
||||
|
||||
engine.add_documents(["doc1"], ["hello world"])
|
||||
mock_collection.upsert.assert_called_once()
|
||||
|
||||
mock_collection.query.return_value = {
|
||||
"ids": [["doc1"]],
|
||||
"documents": [["hello world"]],
|
||||
"metadatas": [[{}]],
|
||||
"distances": [[0.0]]
|
||||
}
|
||||
|
||||
results = engine.search("hello", top_k=1)
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "doc1"
|
||||
assert results[0]["document"] == "hello world"
|
||||
|
||||
engine.delete_documents(["doc1"])
|
||||
mock_collection.delete.assert_called_once_with(ids=["doc1"])
|
||||
Reference in New Issue
Block a user