feat(rag): Implement RAG engine, configuration schema, and vector store integration
This commit is contained in:
@@ -18,6 +18,8 @@ dependencies = [
|
|||||||
"mcp>=1.0.0",
|
"mcp>=1.0.0",
|
||||||
"pytest-timeout>=2.4.0",
|
"pytest-timeout>=2.4.0",
|
||||||
"pyopengl>=3.1.10",
|
"pyopengl>=3.1.10",
|
||||||
|
"chromadb>=1.5.8",
|
||||||
|
"sentence-transformers>=5.4.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
|
|||||||
@@ -201,6 +201,7 @@ class AppController:
|
|||||||
self._pending_actions: Dict[str, ConfirmDialog] = {}
|
self._pending_actions: Dict[str, ConfirmDialog] = {}
|
||||||
self._pending_ask_dialog: bool = False
|
self._pending_ask_dialog: bool = False
|
||||||
self.mcp_config: models.MCPConfiguration = models.MCPConfiguration()
|
self.mcp_config: models.MCPConfiguration = models.MCPConfiguration()
|
||||||
|
self.rag_config: Optional[models.RAGConfig] = None
|
||||||
# AI settings state
|
# AI settings state
|
||||||
self._current_provider: str = "gemini"
|
self._current_provider: str = "gemini"
|
||||||
self._current_model: str = "gemini-2.5-flash-lite"
|
self._current_model: str = "gemini-2.5-flash-lite"
|
||||||
@@ -948,6 +949,12 @@ class AppController:
|
|||||||
else:
|
else:
|
||||||
self.mcp_config = models.MCPConfiguration()
|
self.mcp_config = models.MCPConfiguration()
|
||||||
|
|
||||||
|
rag_data = self.config.get('rag')
|
||||||
|
if rag_data:
|
||||||
|
self.rag_config = models.RAGConfig.from_dict(rag_data)
|
||||||
|
else:
|
||||||
|
self.rag_config = models.RAGConfig()
|
||||||
|
|
||||||
from src.personas import PersonaManager
|
from src.personas import PersonaManager
|
||||||
self.persona_manager = PersonaManager(Path(self.active_project_path).parent if self.active_project_path else None)
|
self.persona_manager = PersonaManager(Path(self.active_project_path).parent if self.active_project_path else None)
|
||||||
self.personas = self.persona_manager.load_all()
|
self.personas = self.persona_manager.load_all()
|
||||||
@@ -2526,6 +2533,10 @@ class AppController:
|
|||||||
self.config["ai"]["system_prompt"] = self.ui_global_system_prompt
|
self.config["ai"]["system_prompt"] = self.ui_global_system_prompt
|
||||||
self.config["ai"]["base_system_prompt"] = self.ui_base_system_prompt
|
self.config["ai"]["base_system_prompt"] = self.ui_base_system_prompt
|
||||||
self.config["ai"]["use_default_base_prompt"] = self.ui_use_default_base_prompt
|
self.config["ai"]["use_default_base_prompt"] = self.ui_use_default_base_prompt
|
||||||
|
|
||||||
|
if self.rag_config:
|
||||||
|
self.config["rag"] = self.rag_config.to_dict()
|
||||||
|
|
||||||
self.config["projects"] = {"paths": self.project_paths, "active": self.active_project_path}
|
self.config["projects"] = {"paths": self.project_paths, "active": self.active_project_path}
|
||||||
from src import bg_shader
|
from src import bg_shader
|
||||||
# Update gui section while preserving other keys like bg_shader_enabled
|
# Update gui section while preserving other keys like bg_shader_enabled
|
||||||
|
|||||||
@@ -596,6 +596,57 @@ class MCPConfiguration:
|
|||||||
}
|
}
|
||||||
return cls(mcpServers=parsed_servers)
|
return cls(mcpServers=parsed_servers)
|
||||||
|
|
||||||
|
@dataclass
class VectorStoreConfig:
    """Connection settings for the vector store backing the RAG feature."""

    # Backend identifier; the original annotation lists 'chroma', 'qdrant', 'mock'.
    provider: str
    # Optional endpoint of the store.
    url: Optional[str] = None
    # Optional credential for the store.
    api_key: Optional[str] = None
    # Name of the collection documents are kept in.
    collection_name: str = 'manual_slop'

    def to_dict(self) -> Dict[str, Any]:
        """Return the settings as a plain dict, keys in field order."""
        payload: Dict[str, Any] = {}
        payload["provider"] = self.provider
        payload["url"] = self.url
        payload["api_key"] = self.api_key
        payload["collection_name"] = self.collection_name
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "VectorStoreConfig":
        """Build a config from a dict.

        ``provider`` is mandatory (a missing key raises ``KeyError``); the
        remaining keys fall back to the dataclass defaults.
        """
        provider = data["provider"]
        url = data.get("url")
        api_key = data.get("api_key")
        collection_name = data.get("collection_name", "manual_slop")
        return cls(provider, url, api_key, collection_name)
|
||||||
|
|
||||||
|
@dataclass
class RAGConfig:
    """Top-level settings for retrieval-augmented generation."""

    # Master switch for the feature.
    enabled: bool = False
    # Vector store settings; a fresh 'mock' store config per instance.
    vector_store: VectorStoreConfig = field(default_factory=lambda: VectorStoreConfig('mock'))
    # Name of the embedding backend.
    embedding_provider: str = 'gemini'
    # Chunking parameters persisted with the config.
    chunk_size: int = 1000
    chunk_overlap: int = 200

    def to_dict(self) -> Dict[str, Any]:
        """Return the settings as a plain dict with the vector store nested."""
        store_payload = self.vector_store.to_dict()
        payload: Dict[str, Any] = {"enabled": self.enabled}
        payload["vector_store"] = store_payload
        payload["embedding_provider"] = self.embedding_provider
        payload["chunk_size"] = self.chunk_size
        payload["chunk_overlap"] = self.chunk_overlap
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "RAGConfig":
        """Build a config from a dict, using the defaults for absent keys.

        An absent ``vector_store`` entry is treated as a 'mock' store.
        """
        store_data = data.get("vector_store", {"provider": "mock"})
        store = VectorStoreConfig.from_dict(store_data)
        return cls(
            data.get("enabled", False),
            store,
            data.get("embedding_provider", "gemini"),
            data.get("chunk_size", 1000),
            data.get("chunk_overlap", 200),
        )
|
||||||
|
|
||||||
def load_mcp_config(path: str) -> MCPConfiguration:
|
def load_mcp_config(path: str) -> MCPConfiguration:
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
return MCPConfiguration()
|
return MCPConfiguration()
|
||||||
|
|||||||
@@ -0,0 +1,118 @@
|
|||||||
|
import os
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import chromadb
|
||||||
|
from chromadb.config import Settings
|
||||||
|
from src import models
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
except ImportError:
|
||||||
|
SentenceTransformer = None
|
||||||
|
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types
|
||||||
|
from src import ai_client
|
||||||
|
|
||||||
|
class BaseEmbeddingProvider:
    """Interface for objects that turn texts into embedding vectors."""

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Return one vector per entry of *texts*; subclasses must override."""
        raise NotImplementedError
|
||||||
|
|
||||||
|
class LocalEmbeddingProvider(BaseEmbeddingProvider):
    """Embedding provider backed by a sentence-transformers model."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load *model_name*.

        Raises:
            ImportError: when the optional sentence-transformers import at
                module load failed (``SentenceTransformer`` is ``None``).
        """
        transformer_cls = SentenceTransformer
        if transformer_cls is None:
            raise ImportError("sentence-transformers is not installed")
        self.model = transformer_cls(model_name)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* with the model and return the result as nested lists."""
        return self.model.encode(texts).tolist()
|
||||||
|
|
||||||
|
class GeminiEmbeddingProvider(BaseEmbeddingProvider):
    """Embedding provider that calls the Gemini embedding endpoint."""

    # Maximum number of texts sent per embed_content request.
    # NOTE(review): 100 is the batch cap the Gemini API is documented to
    # enforce; confirm against the current API reference.
    MAX_BATCH_SIZE = 100

    def __init__(self, model_name: str = 'text-embedding-004'):
        """Remember which embedding model to request."""
        self.model_name = model_name

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* and return one vector per input, in input order.

        An empty input returns ``[]`` without contacting the service. Larger
        inputs are sent in slices of at most ``MAX_BATCH_SIZE`` texts so a
        single oversized request cannot be rejected.

        Raises:
            ValueError: if the shared Gemini client is unavailable.
        """
        if not texts:
            return []

        # The client is owned by ai_client; make sure it exists before use.
        ai_client._ensure_gemini_client()
        client = ai_client._gemini_client
        if not client:
            raise ValueError("Gemini client not initialized")

        embed_config = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
        vectors: List[List[float]] = []
        for start in range(0, len(texts), self.MAX_BATCH_SIZE):
            batch = texts[start:start + self.MAX_BATCH_SIZE]
            res = client.models.embed_content(
                model=self.model_name,
                contents=batch,
                config=embed_config,
            )
            vectors.extend(e.values for e in res.embeddings)
        return vectors
|
||||||
|
|
||||||
|
class RAGEngine:
    """Couples an embedding provider with a vector store for document retrieval.

    When the config is disabled, or the vector store provider is 'mock',
    every public operation is a no-op (``search`` returns ``[]``).
    """

    def __init__(self, config: models.RAGConfig, base_dir: str = "."):
        """Create the engine and, if enabled, its provider and store.

        Args:
            config: RAG settings.
            base_dir: Directory under which the persistent Chroma database
                is created (``<base_dir>/.slop_cache/chroma_db``).

        Raises:
            ValueError: for an unknown embedding or vector store provider.
        """
        self.config = config
        self.base_dir = base_dir
        self.client = None
        self.collection = None
        self.embedding_provider = None

        if not self.config.enabled:
            return

        self._init_embedding_provider()
        self._init_vector_store()

    def _init_embedding_provider(self):
        """Instantiate the embedding provider named by the config."""
        if self.config.embedding_provider == 'gemini':
            self.embedding_provider = GeminiEmbeddingProvider()
        elif self.config.embedding_provider == 'local':
            self.embedding_provider = LocalEmbeddingProvider()
        else:
            raise ValueError(f"Unknown embedding provider: {self.config.embedding_provider}")

    def _init_vector_store(self):
        """Open the vector store named by the config.

        Only 'chroma' (local persistent client) and 'mock' are handled; the
        config's ``url`` and ``api_key`` are not consulted here.
        """
        vs_config = self.config.vector_store
        if vs_config.provider == 'chroma':
            db_path = os.path.join(self.base_dir, ".slop_cache", "chroma_db")
            os.makedirs(db_path, exist_ok=True)
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
        elif vs_config.provider == 'mock':
            # String sentinels mark the no-op store; _is_active() checks for them.
            self.client = "mock"
            self.collection = "mock"
        else:
            raise ValueError(f"Unknown vector store provider: {vs_config.provider}")

    def _is_active(self) -> bool:
        """Return True when operations should reach a real collection."""
        return bool(self.config.enabled) and self.collection != "mock"

    def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None):
        """Embed *texts* and upsert them into the collection under *ids*.

        Empty input is a no-op, so neither the embedder nor the store is
        called with an empty batch.

        Raises:
            ValueError: if *ids*, *texts* and (when given) *metadatas* differ
                in length; checked before any embedding work is done.
        """
        if not self._is_active():
            return
        if len(ids) != len(texts):
            raise ValueError(f"ids and texts must have the same length ({len(ids)} != {len(texts)})")
        if metadatas is not None and len(metadatas) != len(ids):
            raise ValueError(f"metadatas and ids must have the same length ({len(metadatas)} != {len(ids)})")
        if not ids:
            return

        embeddings = self.embedding_provider.embed(texts)
        self.collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
        )

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to *top_k* nearest documents for *query*.

        Each hit is a dict with ``id``, ``document``, ``metadata`` (always a
        dict) and ``distance`` (``0.0`` when the store reports none). An
        empty query or a non-positive *top_k* yields ``[]`` without calling
        the embedder or the store.
        """
        if not self._is_active() or not query or top_k <= 0:
            return []

        query_embedding = self.embedding_provider.embed([query])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
        )
        if not results or not results["ids"] or not results["ids"][0]:
            return []

        # Results are nested one list per query; only one query was sent.
        ids = results["ids"][0]

        def first_row(key: str) -> List[Any]:
            rows = results.get(key)
            return rows[0] if rows else [None] * len(ids)

        return [
            {
                "id": doc_id,
                "document": document,
                # Entries stored without metadata come back as None.
                "metadata": metadata or {},
                "distance": 0.0 if distance is None else distance,
            }
            for doc_id, document, metadata, distance in zip(
                ids, first_row("documents"), first_row("metadatas"), first_row("distances")
            )
        ]

    def delete_documents(self, ids: List[str]):
        """Delete the documents with the given *ids*; empty input is a no-op."""
        if not self._is_active() or not ids:
            return
        self.collection.delete(ids=ids)
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from src import models
|
||||||
|
from src.rag_engine import RAGEngine, BaseEmbeddingProvider, LocalEmbeddingProvider, GeminiEmbeddingProvider
|
||||||
|
|
||||||
|
class MockEmbeddingProvider(BaseEmbeddingProvider):
    """Test double that returns a constant 384-dimensional vector per text."""

    def embed(self, texts):
        """Return one ``[0.1] * 384`` vector for each entry of *texts*."""
        vectors = []
        for _ in texts:
            vectors.append([0.1] * 384)
        return vectors
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_rag_config():
    """Enabled RAG config using the 'mock' vector store and 'local' embeddings."""
    store = models.VectorStoreConfig(collection_name='test', provider='mock')
    rag = models.RAGConfig(
        enabled=True,
        vector_store=store,
        embedding_provider='local',
    )
    return rag
|
||||||
|
|
||||||
|
def test_rag_engine_init_mock(mock_rag_config):
    """An enabled engine with the 'mock' store exposes the mock sentinels."""
    # The fixture selects the 'local' embedding provider, whose constructor
    # instantiates SentenceTransformer. Patch it so the test neither loads a
    # real model nor fails when sentence-transformers is not installed.
    with patch('src.rag_engine.SentenceTransformer') as mock_st:
        engine = RAGEngine(mock_rag_config)

    mock_st.assert_called_once()
    assert engine.config.enabled is True
    assert engine.client == "mock"
    assert engine.collection == "mock"
|
||||||
|
|
||||||
|
@patch('src.rag_engine.LocalEmbeddingProvider.embed')
@patch('src.rag_engine.chromadb.PersistentClient')
def test_rag_engine_chroma(mock_chroma, mock_embed, tmp_path):
    """Exercise add/search/delete against a mocked Chroma collection.

    ``tmp_path`` is used as the engine's base directory because the engine
    creates ``.slop_cache/chroma_db`` under it; with the default ``"."`` the
    test would leave that directory in the working tree.
    """
    mock_embed.return_value = [[0.1, 0.2, 0.3]]

    mock_collection = MagicMock()
    mock_client = MagicMock()
    mock_client.get_or_create_collection.return_value = mock_collection
    mock_chroma.return_value = mock_client

    vs_config = models.VectorStoreConfig(provider='chroma', collection_name='test')
    config = models.RAGConfig(enabled=True, vector_store=vs_config, embedding_provider='local')

    # Only construction touches SentenceTransformer, so the patch is scoped to it.
    with patch('src.rag_engine.SentenceTransformer'):
        engine = RAGEngine(config, base_dir=str(tmp_path))

    expected_db_path = os.path.join(str(tmp_path), ".slop_cache", "chroma_db")
    mock_chroma.assert_called_once_with(path=expected_db_path)
    mock_client.get_or_create_collection.assert_called_once_with(name='test')
    assert engine.collection == mock_collection

    engine.add_documents(["doc1"], ["hello world"])
    mock_collection.upsert.assert_called_once_with(
        ids=["doc1"],
        embeddings=[[0.1, 0.2, 0.3]],
        documents=["hello world"],
        metadatas=None,
    )

    mock_collection.query.return_value = {
        "ids": [["doc1"]],
        "documents": [["hello world"]],
        "metadatas": [[{}]],
        "distances": [[0.0]],
    }

    results = engine.search("hello", top_k=1)
    assert len(results) == 1
    assert results[0]["id"] == "doc1"
    assert results[0]["document"] == "hello world"

    engine.delete_documents(["doc1"])
    mock_collection.delete.assert_called_once_with(ids=["doc1"])
|
||||||
Reference in New Issue
Block a user