feat(rag): Implement RAG engine, configuration schema, and vector store integration
This commit is contained in:
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from src import models
|
||||
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
except ImportError:
|
||||
SentenceTransformer = None
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from src import ai_client
|
||||
|
||||
class BaseEmbeddingProvider:
|
||||
def embed(self, texts: List[str]) -> List[List[float]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
class LocalEmbeddingProvider(BaseEmbeddingProvider):
|
||||
def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
|
||||
if SentenceTransformer is None:
|
||||
raise ImportError("sentence-transformers is not installed")
|
||||
self.model = SentenceTransformer(model_name)
|
||||
|
||||
def embed(self, texts: List[str]) -> List[List[float]]:
|
||||
embeddings = self.model.encode(texts)
|
||||
return embeddings.tolist()
|
||||
|
||||
class GeminiEmbeddingProvider(BaseEmbeddingProvider):
|
||||
def __init__(self, model_name: str = 'text-embedding-004'):
|
||||
self.model_name = model_name
|
||||
|
||||
def embed(self, texts: List[str]) -> List[List[float]]:
|
||||
ai_client._ensure_gemini_client()
|
||||
client = ai_client._gemini_client
|
||||
if not client:
|
||||
raise ValueError("Gemini client not initialized")
|
||||
|
||||
# For text-embedding-004, we can embed a batch
|
||||
res = client.models.embed_content(
|
||||
model=self.model_name,
|
||||
contents=texts,
|
||||
config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
|
||||
)
|
||||
return [e.values for e in res.embeddings]
|
||||
|
||||
class RAGEngine:
|
||||
def __init__(self, config: models.RAGConfig, base_dir: str = "."):
|
||||
self.config = config
|
||||
self.base_dir = base_dir
|
||||
self.client = None
|
||||
self.collection = None
|
||||
self.embedding_provider = None
|
||||
|
||||
if not self.config.enabled:
|
||||
return
|
||||
|
||||
self._init_embedding_provider()
|
||||
self._init_vector_store()
|
||||
|
||||
def _init_embedding_provider(self):
|
||||
if self.config.embedding_provider == 'gemini':
|
||||
self.embedding_provider = GeminiEmbeddingProvider()
|
||||
elif self.config.embedding_provider == 'local':
|
||||
self.embedding_provider = LocalEmbeddingProvider()
|
||||
else:
|
||||
raise ValueError(f"Unknown embedding provider: {self.config.embedding_provider}")
|
||||
|
||||
def _init_vector_store(self):
|
||||
vs_config = self.config.vector_store
|
||||
if vs_config.provider == 'chroma':
|
||||
db_path = os.path.join(self.base_dir, ".slop_cache", "chroma_db")
|
||||
os.makedirs(db_path, exist_ok=True)
|
||||
self.client = chromadb.PersistentClient(path=db_path)
|
||||
self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
|
||||
elif vs_config.provider == 'mock':
|
||||
self.client = "mock"
|
||||
self.collection = "mock"
|
||||
else:
|
||||
raise ValueError(f"Unknown vector store provider: {vs_config.provider}")
|
||||
|
||||
def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None):
|
||||
if not self.config.enabled or self.collection == "mock":
|
||||
return
|
||||
embeddings = self.embedding_provider.embed(texts)
|
||||
self.collection.upsert(
|
||||
ids=ids,
|
||||
embeddings=embeddings,
|
||||
documents=texts,
|
||||
metadatas=metadatas
|
||||
)
|
||||
|
||||
def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
if not self.config.enabled or self.collection == "mock":
|
||||
return []
|
||||
|
||||
query_embedding = self.embedding_provider.embed([query])[0]
|
||||
results = self.collection.query(
|
||||
query_embeddings=[query_embedding],
|
||||
n_results=top_k
|
||||
)
|
||||
|
||||
ret = []
|
||||
if results and results["ids"] and results["ids"][0]:
|
||||
for i in range(len(results["ids"][0])):
|
||||
ret.append({
|
||||
"id": results["ids"][0][i],
|
||||
"document": results["documents"][0][i],
|
||||
"metadata": results["metadatas"][0][i] if results["metadatas"] else {},
|
||||
"distance": results["distances"][0][i] if "distances" in results and results["distances"] else 0.0
|
||||
})
|
||||
return ret
|
||||
|
||||
def delete_documents(self, ids: List[str]):
|
||||
if not self.config.enabled or self.collection == "mock":
|
||||
return
|
||||
self.collection.delete(ids=ids)
|
||||
Reference in New Issue
Block a user