"""RAG engine: pluggable embedding providers plus a Chroma-backed vector index."""
import asyncio
import copy
import json
import os
import sys
from typing import Any, Dict, List, Optional

from src import ai_client
from src import models
from src import mcp_client
from src.file_cache import ASTParser

# Caches for lazily imported optional dependencies (filled on first use).
_SENTENCE_TRANSFORMERS = None
_GOOGLE_GENAI = None
_CHROMADB = None

LOCAL_RAG_INSTALL_HINT = (
    "Local RAG embeddings require sentence-transformers. "
    "Install with manual_slop[local-rag] to use local embeddings."
)

# Top-level syntax node types that become one chunk each in _chunk_code.
# "decorated_definition" is how the tree-sitter Python grammar wraps a
# decorated function/class; without it decorated definitions are dropped.
_CODE_CHUNK_NODE_TYPES = (
    "function_definition",
    "class_definition",
    "decorated_definition",
)


def _get_sentence_transformers():
    """Import ``SentenceTransformer`` on first use and cache the class.

    Raises:
        ImportError: with LOCAL_RAG_INSTALL_HINT when the
            ``sentence_transformers`` package itself is missing.
        Exception: any other import-time failure is logged to stderr and
            re-raised unchanged.
    """
    global _SENTENCE_TRANSFORMERS
    if _SENTENCE_TRANSFORMERS is None:
        try:
            from sentence_transformers import SentenceTransformer
        except ModuleNotFoundError as e:
            # Only translate the error when the package itself is missing;
            # a missing transitive dependency is re-raised as-is.
            if e.name == "sentence_transformers":
                raise ImportError(LOCAL_RAG_INSTALL_HINT) from e
            raise
        except Exception as e:
            sys.stderr.write(f"FAILED to import sentence_transformers: {e}\n")
            sys.stderr.flush()
            raise
        _SENTENCE_TRANSFORMERS = SentenceTransformer
    return _SENTENCE_TRANSFORMERS


def _get_google_genai():
    """Import google-genai on first use; return the cached ``(genai, types)`` pair."""
    global _GOOGLE_GENAI
    if _GOOGLE_GENAI is None:
        from google import genai
        from google.genai import types
        _GOOGLE_GENAI = (genai, types)
    return _GOOGLE_GENAI


def _get_chromadb():
    """Import chromadb on first use; return the cached ``(chromadb, Settings)`` pair."""
    global _CHROMADB
    if _CHROMADB is None:
        import chromadb
        from chromadb.config import Settings
        _CHROMADB = (chromadb, Settings)
    return _CHROMADB


class BaseEmbeddingProvider:
    """Interface for embedding backends: one vector per input text."""

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of ``texts``."""
        raise NotImplementedError()


class LocalEmbeddingProvider(BaseEmbeddingProvider):
    """Embeds text with a locally loaded sentence-transformers model."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        ST = _get_sentence_transformers()
        self.model = ST(model_name)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Encode ``texts`` with the local model and return plain lists."""
        embeddings = self.model.encode(texts)
        return embeddings.tolist()


class GeminiEmbeddingProvider(BaseEmbeddingProvider):
    """Embeds text through the Gemini client owned by ``ai_client``."""

    def __init__(self, model_name: str = 'gemini-embedding-001'):
        self.model_name = model_name

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with the configured Gemini embedding model.

        Raises:
            ImportError: if google-genai is unavailable.
            ValueError: if the shared Gemini client could not be initialized.
        """
        if not texts:
            # Nothing to embed; avoid a pointless remote call.
            return []
        google_module = _get_google_genai()
        if google_module is None:
            raise ImportError("google-genai is not installed")
        _, types = google_module
        ai_client._ensure_gemini_client()
        client = ai_client._gemini_client
        if not client:
            raise ValueError("Gemini client not initialized")
        res = client.models.embed_content(
            model=self.model_name,
            contents=texts,
            config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT"),
        )
        return [e.values for e in res.embeddings]


class RAGEngine:
    """Indexes files into a vector store and answers similarity queries.

    When ``config.enabled`` is false the engine is inert: no provider or
    store is created and every operation is a no-op / returns empty.
    """

    def __init__(self, config: models.RAGConfig, base_dir: str = "."):
        # Deep-copy so later mutation of the caller's config cannot affect us.
        self.config = copy.deepcopy(config)
        self.base_dir = base_dir
        self.client = None
        self.collection = None
        self.embedding_provider = None
        if not self.config.enabled:
            return
        self._init_embedding_provider()
        self._init_vector_store()

    def _init_embedding_provider(self):
        """Instantiate the embedding provider named in the config.

        Raises:
            ValueError: for an unrecognised ``embedding_provider`` value.
        """
        provider = self.config.embedding_provider
        if provider == 'gemini':
            self.embedding_provider = GeminiEmbeddingProvider()
        elif provider == 'local':
            self.embedding_provider = LocalEmbeddingProvider()
        else:
            raise ValueError(f"Unknown embedding provider: {provider}")

    def _init_vector_store(self):
        """Create/open the vector store named in the config.

        Raises:
            ImportError: if chromadb is required but unavailable.
            ValueError: for an unrecognised vector store provider.
        """
        vs_config = self.config.vector_store
        if vs_config.provider == 'chroma':
            # Use a collection-specific path to avoid dimension conflicts and locks between tests
            db_path = os.path.abspath(
                os.path.join(self.base_dir, ".slop_cache", f"chroma_{vs_config.collection_name}")
            )
            os.makedirs(db_path, exist_ok=True)
            chroma_module = _get_chromadb()
            if chroma_module is None:
                raise ImportError("chromadb is not installed")
            chromadb, _settings = chroma_module
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
            self._validate_collection_dim()
        elif vs_config.provider == 'mock':
            self.client = "mock"
            self.collection = "mock"
        elif vs_config.provider == 'mcp':
            # search() delegates 'mcp' queries to an MCP tool, so there is no
            # local store to open. Previously this provider fell through to
            # the ValueError below, making the 'mcp' search path unreachable.
            self.client = None
            self.collection = None
        else:
            raise ValueError(f"Unknown vector store provider: {vs_config.provider}")

    def _has_local_store(self) -> bool:
        """True when the engine is enabled and backed by a real collection."""
        if not self.config.enabled:
            return False
        return self.collection is not None and not (self.collection == "mock")

    def _validate_collection_dim(self) -> None:
        """
        Detect dimension mismatch between an existing collection's vectors and the
        current embedding provider's output. When mismatched (e.g. the user switched
        from Gemini 3072-dim to local 384-dim, or vice versa), the collection is
        deleted and recreated empty so the next index pass populates it with the
        correct dim. Prevents silent corruption that would later surface as a search
        error ("Collection expecting embedding with dimension of X, got Y") and hang
        live_gui tests.

        Any failure during validation is logged to stderr and otherwise ignored.
        [C: tests/test_rag_engine.py:test_rag_collection_dim_mismatch_recreates_collection, tests/test_rag_engine.py:test_rag_collection_dim_match_preserves_collection]
        """
        if self.collection is None or self.collection == "mock" or self.embedding_provider is None:
            return
        try:
            res = self.collection.get(limit=1, include=["embeddings"])
            embeddings = res.get("embeddings") if isinstance(res, dict) else None
            # Explicit None/len checks: the store may hand back a NumPy array,
            # for which `not embeddings` raises "truth value is ambiguous" and
            # would silently abort the validation via the except below.
            if embeddings is None or len(embeddings) == 0:
                return
            existing_dim = len(embeddings[0])
            expected_dim = len(self.embedding_provider.embed(["__rag_dim_check__"])[0])
            if existing_dim == expected_dim:
                return
            # Capture the name before the collection object is invalidated.
            name = self.collection.name
            sys.stderr.write(
                f"RAG: Collection '{name}' dim mismatch "
                f"(existing={existing_dim}, expected={expected_dim}). "
                f"Recreating collection to prevent silent corruption.\n"
            )
            sys.stderr.flush()
            self.client.delete_collection(name)
            self.collection = self.client.get_or_create_collection(name=name)
        except Exception as e:
            sys.stderr.write(f"RAG: Failed to validate collection dim: {e}\n")
            sys.stderr.flush()

    def is_empty(self) -> bool:
        """Return True when there is nothing indexed (or no real store)."""
        if not self.config.enabled:
            return True
        if self.config.vector_store.provider == 'mock' or self.collection == "mock":
            return True
        if self.collection is None:
            return True
        return self.collection.count() == 0

    def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None):
        """
        Embed ``texts`` and upsert them into the collection under ``ids``.
        No-op when the engine is disabled, has no real store, or ``ids`` is empty.
        [C: tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self._has_local_store():
            return
        if not ids:
            return
        embeddings = self.embedding_provider.embed(texts)
        self.collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
        )

    def _chunk_text(self, content: str) -> List[str]:
        """Character-based chunking with overlap.

        Uses ``config.chunk_size`` as the window and advances by
        ``chunk_size - chunk_overlap``. The advance is clamped to at least one
        character so an overlap >= chunk_size cannot loop forever, and a
        non-positive chunk_size yields the whole content as a single chunk.
        """
        if not content:
            return []
        chunk_size = self.config.chunk_size
        if chunk_size <= 0:
            return [content]
        step = max(1, chunk_size - self.config.chunk_overlap)
        total = len(content)
        chunks = []
        start = 0
        while start < total:
            end = start + chunk_size
            chunks.append(content[start:end])
            if end >= total:
                break
            start += step
        return chunks

    def _chunk_code(self, content: str, file_path: str) -> List[str]:
        """AST-aware chunking for Python code.

        Each top-level function/class (including decorated ones) becomes one
        chunk. Falls back to character chunking when parsing fails, when no
        definitions are found, or when the file is shorter than one chunk.
        ``file_path`` is accepted for interface compatibility and is unused.
        """
        try:
            parser = ASTParser("python")
            tree = parser.parse(content)
            # Node offsets are byte offsets, so slice the encoded source rather
            # than the str (they differ as soon as the file has non-ASCII text).
            # NOTE(review): assumes ASTParser parses the UTF-8 encoding of
            # ``content`` — confirm against src.file_cache.
            source = content.encode("utf-8")
            chunks = [
                source[node.start_byte:node.end_byte].decode("utf-8", errors="replace")
                for node in tree.root_node.children
                if node.type in _CODE_CHUNK_NODE_TYPES
            ]
        except Exception:
            return self._chunk_text(content)
        if not chunks or len(content) < self.config.chunk_size:
            return self._chunk_text(content)
        return chunks

    def _resolve_path(self, file_path: str) -> Optional[str]:
        """Resolve ``file_path`` against base_dir, then CWD; None if absent."""
        full_path = os.path.join(self.base_dir, file_path)
        if os.path.exists(full_path):
            return full_path
        # CWD fallback: handle the case where base_dir was resolved to a
        # parent directory (e.g. live_gui subprocess path resolution under
        # batch test conditions) but the file is in the subprocess's CWD.
        # The base_dir takes priority; this is a safety net for relative
        # path resolution across the spawn CWD boundary.
        cwd_path = os.path.join(os.getcwd(), file_path)
        if os.path.exists(cwd_path):
            return cwd_path
        return None

    def index_file(self, file_path: str):
        """Reads, chunks, and indexes a file into the vector store.

        Skips files that cannot be found or read, and files whose stored
        ``mtime`` metadata matches the file on disk. Otherwise all existing
        chunks for ``file_path`` are deleted and replaced.
        """
        if not self._has_local_store():
            return
        full_path = self._resolve_path(file_path)
        if full_path is None:
            return
        try:
            mtime = os.path.getmtime(full_path)
        except (OSError, ValueError):
            return
        try:
            res = self.collection.get(where={"path": file_path}, limit=1, include=["metadatas"])
            if res and res["metadatas"] and res["metadatas"][0]:
                if res["metadatas"][0].get("mtime") == mtime:
                    return
        except Exception:
            # Best-effort freshness check: if the lookup fails we simply
            # re-index the file below.
            pass
        try:
            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                content = f.read()
        except (OSError, ValueError):
            return
        # Drop stale chunks first so a shrinking file leaves no orphans.
        self.collection.delete(where={"path": file_path})
        if file_path.lower().endswith(".py"):
            chunks = self._chunk_code(content, file_path)
        else:
            chunks = self._chunk_text(content)
        if not chunks:
            return
        ids = [f"{file_path}_{i}" for i in range(len(chunks))]
        metadatas = [{"path": file_path, "chunk": i, "mtime": mtime} for i in range(len(chunks))]
        self.add_documents(ids, chunks, metadatas)

    def _search_mcp(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Run the query through the configured MCP tool (default "rag_search").

        The tool's response is parsed as JSON; a top-level list is returned
        as-is, a dict contributes its "results" entry, and anything else
        (including unparseable output) yields an empty list.
        """
        async def _async_search_mcp():
            tool_name = self.config.vector_store.mcp_tool or "rag_search"
            args = {"query": query, "top_k": top_k}
            res_str = await mcp_client.async_dispatch(tool_name, args)
            try:
                data = json.loads(res_str)
            except (ValueError, TypeError):
                # Malformed or non-string tool output: treat as "no results"
                # without also trapping KeyboardInterrupt/SystemExit.
                return []
            if isinstance(data, list):
                return data
            if isinstance(data, dict) and "results" in data:
                return data["results"]
            return []
        return asyncio.run(_async_search_mcp())

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Return up to ``top_k`` hits for ``query`` as dicts with the keys
        "id", "document", "metadata" and "distance".
        [C: tests/mock_concurrent_mma.py:main, tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self.config.enabled:
            return []
        if self.config.vector_store.provider == 'mcp':
            return self._search_mcp(query, top_k)
        if self.collection is None or self.collection == "mock":
            return []
        query_embedding = self.embedding_provider.embed([query])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
        )
        if not results or not results["ids"] or not results["ids"][0]:
            return []
        # Results are batched per query; we issued exactly one query.
        documents = results["documents"][0]
        metadatas = results["metadatas"][0] if results["metadatas"] else None
        has_distances = "distances" in results and results["distances"]
        distances = results["distances"][0] if has_distances else None
        return [
            {
                "id": doc_id,
                "document": documents[i],
                # Per-document metadata may be None; normalise to a dict.
                "metadata": (metadatas[i] or {}) if metadatas is not None else {},
                "distance": distances[i] if distances is not None else 0.0,
            }
            for i, doc_id in enumerate(results["ids"][0])
        ]

    def delete_documents(self, ids: List[str]):
        """
        Delete the documents with the given ``ids``. No-op for an empty list.
        [C: tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self._has_local_store():
            return
        if not ids:
            # Never issue an unfiltered delete for an empty id list.
            return
        self.collection.delete(ids=ids)

    def get_all_indexed_paths(self) -> List[str]:
        """Return the distinct "path" metadata values present in the store."""
        if not self._has_local_store():
            return []
        res = self.collection.get(include=["metadatas"])
        if not res or not res["metadatas"]:
            return []
        # Entries added without metadata come back as None; skip them.
        return list({m["path"] for m in res["metadatas"] if m and m.get("path")})

    def delete_documents_by_path(self, file_paths: List[str]):
        """Delete every chunk whose "path" metadata matches one of ``file_paths``."""
        if not self._has_local_store():
            return
        for path in file_paths:
            self.collection.delete(where={"path": path})