From e80cd6bd3f50ae9f613ad6764cfd0de9c4e31f9d Mon Sep 17 00:00:00 2001 From: Ed_ Date: Mon, 4 May 2026 05:38:23 -0400 Subject: [PATCH] feat(rag): Implement RAG engine, configuration schema, and vector store integration --- pyproject.toml | 2 + src/app_controller.py | 11 ++++ src/models.py | 51 +++++++++++++++++ src/rag_engine.py | 118 +++++++++++++++++++++++++++++++++++++++ tests/test_rag_engine.py | 54 ++++++++++++++++++ 5 files changed, 236 insertions(+) create mode 100644 src/rag_engine.py create mode 100644 tests/test_rag_engine.py diff --git a/pyproject.toml b/pyproject.toml index 040e529..5c483db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,8 @@ dependencies = [ "mcp>=1.0.0", "pytest-timeout>=2.4.0", "pyopengl>=3.1.10", + "chromadb>=1.5.8", + "sentence-transformers>=5.4.1", ] [dependency-groups] diff --git a/src/app_controller.py b/src/app_controller.py index 9073c6c..c144737 100644 --- a/src/app_controller.py +++ b/src/app_controller.py @@ -201,6 +201,7 @@ class AppController: self._pending_actions: Dict[str, ConfirmDialog] = {} self._pending_ask_dialog: bool = False self.mcp_config: models.MCPConfiguration = models.MCPConfiguration() + self.rag_config: Optional[models.RAGConfig] = None # AI settings state self._current_provider: str = "gemini" self._current_model: str = "gemini-2.5-flash-lite" @@ -948,6 +949,12 @@ class AppController: else: self.mcp_config = models.MCPConfiguration() + rag_data = self.config.get('rag') + if rag_data: + self.rag_config = models.RAGConfig.from_dict(rag_data) + else: + self.rag_config = models.RAGConfig() + from src.personas import PersonaManager self.persona_manager = PersonaManager(Path(self.active_project_path).parent if self.active_project_path else None) self.personas = self.persona_manager.load_all() @@ -2526,6 +2533,10 @@ class AppController: self.config["ai"]["system_prompt"] = self.ui_global_system_prompt self.config["ai"]["base_system_prompt"] = self.ui_base_system_prompt self.config["ai"]["use_default_base_prompt"] = self.ui_use_default_base_prompt + + if self.rag_config: + self.config["rag"] = self.rag_config.to_dict() + self.config["projects"] = {"paths": self.project_paths, "active": self.active_project_path} from src import bg_shader # Update gui section while preserving other keys like bg_shader_enabled diff --git a/src/models.py b/src/models.py index a05bf26..885c49b 100644 --- a/src/models.py +++ b/src/models.py @@ -596,6 +596,57 @@ class MCPConfiguration: } return cls(mcpServers=parsed_servers) +@dataclass +class VectorStoreConfig: + provider: str # 'chroma', 'qdrant', 'mock' + url: Optional[str] = None + api_key: Optional[str] = None + collection_name: str = 'manual_slop' + + def to_dict(self) -> Dict[str, Any]: + return { + "provider": self.provider, + "url": self.url, + "api_key": self.api_key, + "collection_name": self.collection_name, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "VectorStoreConfig": + return cls( + provider=data["provider"], + url=data.get("url"), + api_key=data.get("api_key"), + collection_name=data.get("collection_name", "manual_slop"), + ) + +@dataclass +class RAGConfig: + enabled: bool = False + vector_store: VectorStoreConfig = field(default_factory=lambda: VectorStoreConfig(provider='mock')) + embedding_provider: str = 'gemini' + chunk_size: int = 1000 + chunk_overlap: int = 200 + + def to_dict(self) -> Dict[str, Any]: + return { + "enabled": self.enabled, + "vector_store": self.vector_store.to_dict(), + "embedding_provider": self.embedding_provider, + "chunk_size": self.chunk_size, + "chunk_overlap": self.chunk_overlap, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "RAGConfig": + return cls( + enabled=data.get("enabled", False), + vector_store=VectorStoreConfig.from_dict(data.get("vector_store", {"provider": "mock"})), + embedding_provider=data.get("embedding_provider", "gemini"), + chunk_size=data.get("chunk_size", 1000), + chunk_overlap=data.get("chunk_overlap", 200), + ) + def load_mcp_config(path: str) -> MCPConfiguration: if not os.path.exists(path): return MCPConfiguration() diff --git a/src/rag_engine.py b/src/rag_engine.py new file mode 100644 index 0000000..c486b27 --- /dev/null +++ b/src/rag_engine.py @@ -0,0 +1,118 @@ +import os +from typing import List, Dict, Any, Optional +import chromadb +from chromadb.config import Settings +from src import models + +try: + from sentence_transformers import SentenceTransformer +except ImportError: + SentenceTransformer = None + +from google import genai +from google.genai import types +from src import ai_client + +class BaseEmbeddingProvider: + def embed(self, texts: List[str]) -> List[List[float]]: + raise NotImplementedError() + +class LocalEmbeddingProvider(BaseEmbeddingProvider): + def __init__(self, model_name: str = 'all-MiniLM-L6-v2'): + if SentenceTransformer is None: + raise ImportError("sentence-transformers is not installed") + self.model = SentenceTransformer(model_name) + + def embed(self, texts: List[str]) -> List[List[float]]: + embeddings = self.model.encode(texts) + return embeddings.tolist() + +class GeminiEmbeddingProvider(BaseEmbeddingProvider): + def __init__(self, model_name: str = 'text-embedding-004'): + self.model_name = model_name + + def embed(self, texts: List[str]) -> List[List[float]]: + ai_client._ensure_gemini_client() + client = ai_client._gemini_client + if not client: + raise ValueError("Gemini client not initialized") + + # For text-embedding-004, we can embed a batch + res = client.models.embed_content( + model=self.model_name, + contents=texts, + config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT") + ) + return [e.values for e in res.embeddings] + +class RAGEngine: + def __init__(self, config: models.RAGConfig, base_dir: str = "."): + self.config = config + self.base_dir = base_dir + self.client = None + self.collection = None + self.embedding_provider = None + + if not self.config.enabled: + return + + self._init_embedding_provider() + self._init_vector_store() + + def _init_embedding_provider(self): + if self.config.embedding_provider == 'gemini': + self.embedding_provider = GeminiEmbeddingProvider() + elif self.config.embedding_provider == 'local': + self.embedding_provider = LocalEmbeddingProvider() + else: + raise ValueError(f"Unknown embedding provider: {self.config.embedding_provider}") + + def _init_vector_store(self): + vs_config = self.config.vector_store + if vs_config.provider == 'chroma': + db_path = os.path.join(self.base_dir, ".slop_cache", "chroma_db") + os.makedirs(db_path, exist_ok=True) + self.client = chromadb.PersistentClient(path=db_path) + self.collection = self.client.get_or_create_collection(name=vs_config.collection_name) + elif vs_config.provider == 'mock': + self.client = "mock" + self.collection = "mock" + else: + raise ValueError(f"Unknown vector store provider: {vs_config.provider}") + + def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None): + if not self.config.enabled or self.collection == "mock": + return + embeddings = self.embedding_provider.embed(texts) + self.collection.upsert( + ids=ids, + embeddings=embeddings, + documents=texts, + metadatas=metadatas + ) + + def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + if not self.config.enabled or self.collection == "mock": + return [] + + query_embedding = self.embedding_provider.embed([query])[0] + results = self.collection.query( + query_embeddings=[query_embedding], + n_results=top_k + ) + + ret = [] + if results and results["ids"] and results["ids"][0]: + for i in range(len(results["ids"][0])): + ret.append({ + "id": results["ids"][0][i], + "document": results["documents"][0][i], + "metadata": results["metadatas"][0][i] if results["metadatas"] else {}, + "distance": results["distances"][0][i] if "distances" in results and results["distances"] else 0.0 + }) + return ret + + def delete_documents(self, ids: List[str]): + if not self.config.enabled or self.collection == "mock": + return + self.collection.delete(ids=ids) diff --git a/tests/test_rag_engine.py b/tests/test_rag_engine.py new file mode 100644 index 0000000..b5a1f34 --- /dev/null +++ b/tests/test_rag_engine.py @@ -0,0 +1,54 @@ +import pytest +import os +from unittest.mock import MagicMock, patch +from src import models +from src.rag_engine import RAGEngine, BaseEmbeddingProvider, LocalEmbeddingProvider, GeminiEmbeddingProvider + +class MockEmbeddingProvider(BaseEmbeddingProvider): + def embed(self, texts): + return [[0.1] * 384 for _ in texts] + +@pytest.fixture +def mock_rag_config(): + vs_config = models.VectorStoreConfig(provider='mock', collection_name='test') + return models.RAGConfig(enabled=True, vector_store=vs_config, embedding_provider='local') + +def test_rag_engine_init_mock(mock_rag_config): + engine = RAGEngine(mock_rag_config) + assert engine.config.enabled is True + assert engine.collection == "mock" + +@patch('src.rag_engine.LocalEmbeddingProvider.embed') +@patch('src.rag_engine.chromadb.PersistentClient') +def test_rag_engine_chroma(mock_chroma, mock_embed): + mock_embed.return_value = [[0.1, 0.2, 0.3]] + + mock_collection = MagicMock() + mock_client = MagicMock() + mock_client.get_or_create_collection.return_value = mock_collection + mock_chroma.return_value = mock_client + + vs_config = models.VectorStoreConfig(provider='chroma', collection_name='test') + config = models.RAGConfig(enabled=True, vector_store=vs_config, embedding_provider='local') + + with patch('src.rag_engine.SentenceTransformer') as mock_st: + engine = RAGEngine(config) + assert engine.collection == mock_collection + + engine.add_documents(["doc1"], ["hello world"]) + mock_collection.upsert.assert_called_once() + + mock_collection.query.return_value = { + "ids": [["doc1"]], + "documents": [["hello world"]], + "metadatas": [[{}]], + "distances": [[0.0]] + } + + results = engine.search("hello", top_k=1) + assert len(results) == 1 + assert results[0]["id"] == "doc1" + assert results[0]["document"] == "hello world" + + engine.delete_documents(["doc1"]) + mock_collection.delete.assert_called_once_with(ids=["doc1"])