import asyncio
import copy
import json
import os
import sys
from typing import List, Dict, Any, Optional

from src import ai_client
from src import models
from src import mcp_client
from src.result_types import ErrorInfo, ErrorKind, NilRAGState, Result
from src.file_cache import ASTParser

# Lazily-populated caches for optional heavy dependencies; each is filled on
# first use by the matching _get_* helper below.
_SENTENCE_TRANSFORMERS = None
_GOOGLE_GENAI = None
_CHROMADB = None

LOCAL_RAG_INSTALL_HINT = "Local RAG embeddings require sentence-transformers. Install with manual_slop[local-rag] to use local embeddings."

# Top-level tree-sitter node types kept as standalone chunks by AST-aware
# chunking. "decorated_definition" wraps any function/class that carries a
# decorator; without it, decorated code would never be indexed.
_CODE_CHUNK_NODE_TYPES = ("function_definition", "class_definition", "decorated_definition")


def _get_sentence_transformers():
    """Return the ``SentenceTransformer`` class, importing it on first call.

    Raises:
        ImportError: with LOCAL_RAG_INSTALL_HINT when sentence-transformers
            itself is missing. A missing transitive module is re-raised
            unchanged; other ImportError/AttributeError failures are logged
            to stderr and re-raised.
    """
    global _SENTENCE_TRANSFORMERS
    if _SENTENCE_TRANSFORMERS is None:
        try:
            from sentence_transformers import SentenceTransformer
            _SENTENCE_TRANSFORMERS = SentenceTransformer
        except ModuleNotFoundError as e:
            if e.name == "sentence_transformers":
                raise ImportError(LOCAL_RAG_INSTALL_HINT) from e
            raise
        except (ImportError, AttributeError) as e:
            sys.stderr.write(f"FAILED to import sentence_transformers: {e}\n")
            sys.stderr.flush()
            raise
    return _SENTENCE_TRANSFORMERS


def _get_google_genai():
    """Return ``(genai, types)`` from google-genai, importing on first call.

    Raises ImportError if google-genai is not installed (never returns None).
    """
    global _GOOGLE_GENAI
    if _GOOGLE_GENAI is None:
        from google import genai
        from google.genai import types
        _GOOGLE_GENAI = (genai, types)
    return _GOOGLE_GENAI


def _get_chromadb():
    """Return ``(chromadb, Settings)``, importing chromadb on first call.

    Raises ImportError if chromadb is not installed (never returns None).
    """
    global _CHROMADB
    if _CHROMADB is None:
        import chromadb
        from chromadb.config import Settings
        _CHROMADB = (chromadb, Settings)
    return _CHROMADB


class BaseEmbeddingProvider:
    """Interface for objects that turn a batch of texts into embedding vectors."""

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text."""
        raise NotImplementedError()


class LocalEmbeddingProvider(BaseEmbeddingProvider):
    """Embeds texts locally with a sentence-transformers model."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        ST = _get_sentence_transformers()
        self.model = ST(model_name)

    def embed(self, texts: List[str]) -> List[List[float]]:
        # encode() presumably returns a numpy array; convert it to plain
        # nested lists for the vector store.
        embeddings = self.model.encode(texts)
        return embeddings.tolist()


class GeminiEmbeddingProvider(BaseEmbeddingProvider):
    """Embeds texts through the shared Gemini client held by ``ai_client``."""

    def __init__(self, model_name: str = 'gemini-embedding-001'):
        self.model_name = model_name

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with the RETRIEVAL_DOCUMENT task type.

        Raises:
            ImportError: google-genai is not installed.
            ValueError: the shared Gemini client could not be initialized.
        """
        try:
            _genai, types = _get_google_genai()
        except ImportError as e:
            # _get_google_genai raises instead of returning None, so the
            # friendly message has to be attached here.
            raise ImportError("google-genai is not installed") from e
        ai_client._ensure_gemini_client()
        client = ai_client._gemini_client
        if not client:
            raise ValueError("Gemini client not initialized")
        res = client.models.embed_content(
            model=self.model_name,
            contents=texts,
            config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT"),
        )
        return [e.values for e in res.embeddings]


def _parse_search_response_result(res_str: str) -> Result[List[Dict[str, Any]]]:
    """Parse the MCP rag_search response.

    Returns Result[List[dict]]. A bare JSON list is returned as-is, a dict
    with a "results" key yields that value, anything else yields [].
    On JSON parse failure, returns Result(errors=[ErrorInfo]). The legacy
    caller returns [] on errors, preserving the original behavior.
    """
    try:
        data = json.loads(res_str)
    except (ValueError, TypeError) as e:
        return Result(
            data=None,
            errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"_search_mcp JSON parse failed: {e}", source="rag_engine._parse_search_response_result", original=e)],
        )
    if isinstance(data, list):
        return Result(data=data)
    if isinstance(data, dict) and "results" in data:
        return Result(data=data["results"])
    return Result(data=[])


class RAGEngine:
    """Chunk, embed and search files under ``base_dir``.

    Supported vector stores: "chroma" (persistent, under
    ``<base_dir>/.slop_cache``), "mock" (no-op), and "mcp" (search delegated
    to an MCP tool; no local collection to index into).
    """

    def __init__(self, config: models.RAGConfig, base_dir: str = "."):
        # Deep-copy so later mutation of the caller's config object cannot
        # change this engine's behavior.
        self.config = copy.deepcopy(config)
        self.base_dir = base_dir
        self.client = None
        self.collection = None
        self.embedding_provider = None
        if not self.config.enabled:
            return
        self._init_embedding_provider()
        r = self._init_vector_store_result()
        if not r.ok:
            # Keep the engine constructible but inert: every collection-backed
            # method checks _has_live_collection() before touching the store.
            self.collection = None
            for err in r.errors:
                sys.stderr.write(f"RAG: vector store unavailable: {err.message}\n")
            sys.stderr.flush()

    def _init_embedding_provider(self):
        """Instantiate the embedding provider named by the config.

        Raises ValueError for an unknown provider name.
        """
        if self.config.embedding_provider == 'gemini':
            self.embedding_provider = GeminiEmbeddingProvider()
        elif self.config.embedding_provider == 'local':
            self.embedding_provider = LocalEmbeddingProvider()
        else:
            raise ValueError(f"Unknown embedding provider: {self.config.embedding_provider}")

    def _init_vector_store_result(self) -> Result[None]:
        """Create the vector-store client and collection named by the config.

        Returns Result(data=None) on success, or a Result carrying a CONFIG /
        INTERNAL ErrorInfo when the store cannot be set up; __init__ then
        leaves ``self.collection`` as None.
        """
        vs_config = self.config.vector_store
        provider = vs_config.provider
        if provider == 'chroma':
            try:
                chromadb, _settings = _get_chromadb()
            except ImportError as e:
                return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.CONFIG, message="chromadb is not installed", source="rag._init_vector_store", original=e)])
            db_path = os.path.abspath(os.path.join(self.base_dir, ".slop_cache", f"chroma_{vs_config.collection_name}"))
            os.makedirs(db_path, exist_ok=True)
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
            return self._validate_collection_dim_result()
        if provider == 'mock':
            self.client = "mock"
            self.collection = "mock"
            return Result(data=None)
        if provider == 'mcp':
            # Search is delegated to an MCP tool (see _search_mcp); there is
            # no local collection, so self.collection stays None.
            return Result(data=None)
        return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.CONFIG, message=f"Unknown vector store provider: {vs_config.provider}", source="rag._init_vector_store")])

    def _validate_collection_dim_result(self) -> Result[None]:
        """
        Detect a dimension mismatch between vectors already stored in the
        collection and the current embedding provider's output.

        When they differ (e.g. the user switched from Gemini 3072-dim to local
        384-dim, or vice versa), the collection is dropped with
        ``client.delete_collection`` and recreated empty, so the next index
        pass repopulates it with the correct dim. This prevents silent
        corruption that would later surface as a search error ("Collection
        expecting embedding with dimension of X, got Y") and hang live_gui
        tests.

        Any exception is returned as an INTERNAL error Result. That includes
        a delete_collection failure, which on corrupted state in chromadb
        1.5.x reads "RustBindingsAPI object has no attribute bindings".
        The constructor then disables the collection.

        NOTE(review): an earlier docstring described wiping the collection at
        the directory level instead of calling delete_collection; this code
        uses delete_collection. Confirm which behavior is intended.
        """
        if self.collection is None or self.collection == "mock" or self.embedding_provider is None:
            return Result(data=None)
        try:
            res = self.collection.get(limit=1, include=["embeddings"])
            if not res:
                return Result(data=None)
            embeddings = res.get("embeddings") if isinstance(res, dict) else None
            if embeddings is None or len(embeddings) == 0:
                return Result(data=None)
            existing_dim = len(embeddings[0])
            expected_dim = len(self.embedding_provider.embed(["__rag_dim_check__"])[0])
            if existing_dim == expected_dim:
                return Result(data=None)
            sys.stderr.write(
                f"RAG: Collection '{self.collection.name}' dim mismatch "
                f"(existing={existing_dim}, expected={expected_dim}). "
                f"Recreating collection to prevent silent corruption.\n"
            )
            sys.stderr.flush()
            name = self.collection.name
            self.client.delete_collection(name)
            self.collection = self.client.get_or_create_collection(name=name)
            return Result(data=None)
        except Exception as e:
            return Result(data=None, errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"Failed to validate collection dim: {e}", source="rag._validate_collection_dim", original=e)])

    def _get_state(self) -> NilRAGState:
        return NilRAGState(enabled=self.config.enabled)

    def _has_live_collection(self) -> bool:
        """True when RAG is enabled and a real (non-mock) collection exists."""
        return bool(self.config.enabled) and self.collection is not None and self.collection != "mock"

    def is_empty(self) -> bool:
        """True when disabled, mocked, without a collection, or holding no documents."""
        if not self._has_live_collection():
            return True
        return self.collection.count() == 0

    def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None):
        """
        Embed ``texts`` and upsert them under ``ids``; no-op without a live collection.

        [C: tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self._has_live_collection():
            return
        embeddings = self.embedding_provider.embed(texts)
        self.collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
        )

    def _chunk_text(self, content: str) -> List[str]:
        """Character-based chunking with overlap.

        Chunks are ``chunk_size`` characters long and start
        ``chunk_size - chunk_overlap`` characters apart; the last chunk may be
        shorter. A non-positive ``chunk_size`` yields the whole content as a
        single chunk. An overlap that would stop the window advancing
        (``chunk_overlap >= chunk_size``) is ignored, so the loop always
        terminates.
        """
        if not content:
            return []
        chunk_size = self.config.chunk_size
        if chunk_size <= 0:
            return [content]
        step = chunk_size - self.config.chunk_overlap
        if step <= 0:
            # Invalid overlap: fall back to non-overlapping chunks instead of
            # spinning forever on the same start offset.
            step = chunk_size
        chunks = []
        start = 0
        while start < len(content):
            end = start + chunk_size
            chunks.append(content[start:end])
            if end >= len(content):
                break
            start += step
        return chunks

    def _chunk_code_result(self, content: str, file_path: str) -> Result[List[str]]:
        """AST-aware chunking for Python code.

        Each top-level function, class, or decorated definition becomes one
        chunk; other top-level statements are not included.
        Returns Result[List[str]]. On AST parse failure, returns
        Result(errors=[ErrorInfo]). The legacy caller (index_file) decides
        whether to fall back to text chunking.
        """
        try:
            parser = ASTParser("python")
            tree = parser.parse(content)
            # tree-sitter node positions are byte offsets into the encoded
            # source, not str indices; they diverge as soon as a non-ASCII
            # character precedes the node, so slice the encoded bytes.
            # NOTE(review): assumes ASTParser.parse feeds tree-sitter
            # content.encode("utf-8") — confirm in src/file_cache.
            source_bytes = content.encode("utf-8")
            chunks: List[str] = [
                source_bytes[node.start_byte:node.end_byte].decode("utf-8", errors="replace")
                for node in tree.root_node.children
                if node.type in _CODE_CHUNK_NODE_TYPES
            ]
            return Result(data=chunks)
        except Exception as e:
            return Result(
                data=None,
                errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"AST chunking failed for {file_path}: {e}", source="rag_engine._chunk_code_result", original=e)],
            )

    def _get_file_mtime_result(self, full_path: str) -> Result[float]:
        """Get file modification time. Returns Result[float]."""
        try:
            return Result(data=os.path.getmtime(full_path))
        except OSError as e:
            return Result(
                data=None,
                errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to get mtime for {full_path}: {e}", source="rag_engine._get_file_mtime_result", original=e)],
            )

    def _check_existing_index_result(self, file_path: str, mtime: float) -> Result[bool]:
        """Check if the file is already indexed at the current mtime.

        Returns Result(data=True) if already indexed (skip),
        Result(data=False) if needs re-indexing,
        Result(data=False, errors=[ErrorInfo]) on collection failure.
        """
        try:
            res = self.collection.get(where={"path": file_path}, limit=1, include=["metadatas"])
            if res and res["metadatas"] and res["metadatas"][0]:
                if res["metadatas"][0].get("mtime") == mtime:
                    return Result(data=True)
            return Result(data=False)
        except Exception as e:
            return Result(
                data=False,
                errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to check existing index for {file_path}: {e}", source="rag_engine._check_existing_index_result", original=e)],
            )

    def _read_file_content_result(self, full_path: str) -> Result[str]:
        """Read file contents as UTF-8, dropping undecodable bytes. Returns Result[str]."""
        try:
            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                return Result(data=f.read())
        except (OSError, UnicodeDecodeError) as e:
            return Result(
                data=None,
                errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to read {full_path}: {e}", source="rag_engine._read_file_content_result", original=e)],
            )

    def index_file(self, file_path: str):
        """Read, chunk, and index a file into the vector store.

        ``file_path`` is resolved against ``base_dir`` (falling back to the
        current working directory) but stored in chunk metadata exactly as
        given. A file whose indexed ``mtime`` matches the one on disk is
        skipped; otherwise its existing chunks are deleted and replaced.
        Missing or unreadable files are skipped silently.
        """
        if not self._has_live_collection():
            return
        full_path = os.path.join(self.base_dir, file_path)
        if not os.path.exists(full_path):
            # CWD fallback: handle the case where base_dir was resolved to a
            # parent directory (e.g. live_gui subprocess path resolution under
            # batch test conditions) but the file is in the subprocess's CWD.
            # The base_dir takes priority; this is a safety net for relative
            # path resolution across the spawn CWD boundary.
            cwd_path = os.path.join(os.getcwd(), file_path)
            if not os.path.exists(cwd_path):
                return
            full_path = cwd_path

        mtime_result = self._get_file_mtime_result(full_path)
        if not mtime_result.ok:
            return
        mtime = mtime_result.data

        # A failed lookup (not ok) falls through to re-indexing.
        existing_result = self._check_existing_index_result(file_path, mtime)
        if existing_result.ok and existing_result.data:
            return

        content_result = self._read_file_content_result(full_path)
        if not content_result.ok:
            return
        content = content_result.data

        self.collection.delete(where={"path": file_path})

        if file_path.lower().endswith(".py"):
            chunk_result = self._chunk_code_result(content, file_path)
            # Use plain text chunks when AST parsing failed, found no
            # top-level definitions, or the file fits in one text chunk.
            use_text = not chunk_result.ok or not chunk_result.data or len(content) < self.config.chunk_size
            chunks = self._chunk_text(content) if use_text else chunk_result.data
        else:
            chunks = self._chunk_text(content)

        if not chunks:
            return

        ids = [f"{file_path}_{i}" for i in range(len(chunks))]
        metadatas = [{"path": file_path, "chunk": i, "mtime": mtime} for i in range(len(chunks))]
        self.add_documents(ids, chunks, metadatas)

    def _search_mcp(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Run the configured MCP search tool (default "rag_search").

        Returns the parsed result list, or [] when the response is not valid
        JSON. Uses asyncio.run, which raises RuntimeError if called while an
        event loop is already running in this thread.
        """
        async def _async_search_mcp() -> List[Dict[str, Any]]:
            tool_name = self.config.vector_store.mcp_tool or "rag_search"
            args = {"query": query, "top_k": top_k}
            res_str = await mcp_client.async_dispatch(tool_name, args)
            parse_result = _parse_search_response_result(res_str)
            return parse_result.data if parse_result.ok else []

        return asyncio.run(_async_search_mcp())

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Return up to ``top_k`` hits as dicts with "id", "document",
        "metadata" and "distance" keys ([] when disabled, mocked, or without
        a collection).

        [C: tests/mock_concurrent_mma.py:main, tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self.config.enabled:
            return []
        if self.config.vector_store.provider == 'mcp':
            return self._search_mcp(query, top_k)
        if not self._has_live_collection():
            return []

        query_embedding = self.embedding_provider.embed([query])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
        )

        ret = []
        if results and results["ids"] and results["ids"][0]:
            # Query results are batched per query embedding; we sent one, so
            # only row 0 is relevant.
            documents = results["documents"][0]
            metadatas = results["metadatas"][0] if results["metadatas"] else None
            distances = results["distances"][0] if "distances" in results and results["distances"] else None
            for i, doc_id in enumerate(results["ids"][0]):
                ret.append({
                    "id": doc_id,
                    "document": documents[i],
                    "metadata": metadatas[i] if metadatas is not None else {},
                    "distance": distances[i] if distances is not None else 0.0,
                })
        return ret

    def delete_documents(self, ids: List[str]):
        """
        Delete documents by id; no-op without a live collection.

        [C: tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self._has_live_collection():
            return
        self.collection.delete(ids=ids)

    def get_all_indexed_paths(self) -> List[str]:
        """Return the distinct, sorted ``path`` metadata values in the collection."""
        if not self._has_live_collection():
            return []
        res = self.collection.get(include=["metadatas"])
        if not res or not res["metadatas"]:
            return []
        return sorted({m["path"] for m in res["metadatas"] if m is not None and m.get("path")})

    def delete_documents_by_path(self, file_paths: List[str]):
        """Delete every chunk whose ``path`` metadata matches one of ``file_paths``."""
        if not self._has_live_collection():
            return
        for path in file_paths:
            self.collection.delete(where={"path": path})