diff --git a/src/rag_engine.py b/src/rag_engine.py index 90da6593..ae010abe 100644 --- a/src/rag_engine.py +++ b/src/rag_engine.py @@ -8,248 +8,257 @@ from chromadb.config import Settings from src import models from src import mcp_client -try: - from sentence_transformers import SentenceTransformer -except ImportError: - SentenceTransformer = None +_SENTENCE_TRANSFORMERS = None +_GOOGLE_GENAI = None -from google import genai -from google.genai import types -from src import ai_client +def _get_sentence_transformers(): + global _SENTENCE_TRANSFORMERS + if _SENTENCE_TRANSFORMERS is None: + from sentence_transformers import SentenceTransformer + _SENTENCE_TRANSFORMERS = SentenceTransformer + return _SENTENCE_TRANSFORMERS + +def _get_google_genai(): + global _GOOGLE_GENAI + if _GOOGLE_GENAI is None: + from google import genai + from google.genai import types + _GOOGLE_GENAI = (genai, types) + return _GOOGLE_GENAI class BaseEmbeddingProvider: - def embed(self, texts: List[str]) -> List[List[float]]: - raise NotImplementedError() + def embed(self, texts: List[str]) -> List[List[float]]: + raise NotImplementedError() class LocalEmbeddingProvider(BaseEmbeddingProvider): - def __init__(self, model_name: str = 'all-MiniLM-L6-v2'): - if SentenceTransformer is None: - raise ImportError("sentence-transformers is not installed") - self.model = SentenceTransformer(model_name) + def __init__(self, model_name: str = 'all-MiniLM-L6-v2'): + ST = _get_sentence_transformers() + if ST is None: + raise ImportError("sentence-transformers is not installed") + self.model = ST(model_name) - def embed(self, texts: List[str]) -> List[List[float]]: - embeddings = self.model.encode(texts) - return embeddings.tolist() + def embed(self, texts: List[str]) -> List[List[float]]: + embeddings = self.model.encode(texts) + return embeddings.tolist() class GeminiEmbeddingProvider(BaseEmbeddingProvider): - def __init__(self, model_name: str = 'text-embedding-004'): - self.model_name = model_name + def __init__(self, model_name: str = 'text-embedding-004'): + self.model_name = model_name - def embed(self, texts: List[str]) -> List[List[float]]: - ai_client._ensure_gemini_client() - client = ai_client._gemini_client - if not client: - raise ValueError("Gemini client not initialized") - - # For text-embedding-004, we can embed a batch - res = client.models.embed_content( - model=self.model_name, - contents=texts, - config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT") - ) - return [e.values for e in res.embeddings] + def embed(self, texts: List[str]) -> List[List[float]]: + google_module = _get_google_genai() + if google_module is None: + raise ImportError("google-genai is not installed") + genai_pkg, types = google_module + from src import ai_client + ai_client._ensure_gemini_client() + client = ai_client._gemini_client + if not client: + raise ValueError("Gemini client not initialized") + res = client.models.embed_content( + model=self.model_name, + contents=texts, + config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT") + ) + return [e.values for e in res.embeddings] class RAGEngine: - def __init__(self, config: models.RAGConfig, base_dir: str = "."): - self.config = config - self.base_dir = base_dir - self.client = None - self.collection = None - self.embedding_provider = None - - if not self.config.enabled: - return - - self._init_embedding_provider() - self._init_vector_store() + def __init__(self, config: models.RAGConfig, base_dir: str = "."): + self.config = config + self.base_dir = base_dir + self.client = None + self.collection = None + self.embedding_provider = None - def _init_embedding_provider(self): - if self.config.embedding_provider == 'gemini': - self.embedding_provider = GeminiEmbeddingProvider() - elif self.config.embedding_provider == 'local': - self.embedding_provider = LocalEmbeddingProvider() - else: - raise ValueError(f"Unknown embedding provider: {self.config.embedding_provider}") + if not self.config.enabled: + return - def _init_vector_store(self): - vs_config = self.config.vector_store - if vs_config.provider == 'chroma': - db_path = os.path.join(self.base_dir, ".slop_cache", "chroma_db") - os.makedirs(db_path, exist_ok=True) - self.client = chromadb.PersistentClient(path=db_path) - self.collection = self.client.get_or_create_collection(name=vs_config.collection_name) - elif vs_config.provider == 'mock': - self.client = "mock" - self.collection = "mock" - else: - raise ValueError(f"Unknown vector store provider: {vs_config.provider}") + self._init_embedding_provider() + self._init_vector_store() - def is_empty(self) -> bool: - if not self.config.enabled: - return True - if self.config.vector_store.provider == 'mock' or self.collection == "mock": - return True - if self.collection is None: - return True - return self.collection.count() == 0 + def _init_embedding_provider(self): + if self.config.embedding_provider == 'gemini': + self.embedding_provider = GeminiEmbeddingProvider() + elif self.config.embedding_provider == 'local': + self.embedding_provider = LocalEmbeddingProvider() + else: + raise ValueError(f"Unknown embedding provider: {self.config.embedding_provider}") - def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None): - """ - [C: tests/test_rag_engine.py:test_rag_engine_chroma] - """ - if not self.config.enabled or self.collection == "mock": - return - embeddings = self.embedding_provider.embed(texts) - self.collection.upsert( - ids=ids, - embeddings=embeddings, - documents=texts, - metadatas=metadatas - ) + def _init_vector_store(self): + vs_config = self.config.vector_store + if vs_config.provider == 'chroma': + db_path = os.path.join(self.base_dir, ".slop_cache", "chroma_db") + os.makedirs(db_path, exist_ok=True) + self.client = chromadb.PersistentClient(path=db_path) + self.collection = self.client.get_or_create_collection(name=vs_config.collection_name) + elif vs_config.provider == 'mock': + self.client = "mock" + self.collection = "mock" + else: + raise ValueError(f"Unknown vector store provider: {vs_config.provider}") - def _chunk_text(self, content: str) -> List[str]: - """Character-based chunking with overlap.""" - chunks = [] - if not content: - return chunks - chunk_size = self.config.chunk_size - overlap = self.config.chunk_overlap - start = 0 - while start < len(content): - end = start + chunk_size - chunks.append(content[start:end]) - if end >= len(content): - break - start += (chunk_size - overlap) - return chunks + def is_empty(self) -> bool: + if not self.config.enabled: + return True + if self.config.vector_store.provider == 'mock' or self.collection == "mock": + return True + if self.collection is None: + return True + return self.collection.count() == 0 - def _chunk_code(self, content: str, file_path: str) -> List[str]: - """AST-aware chunking for Python code.""" - try: - from src.file_cache import ASTParser - parser = ASTParser("python") - tree = parser.parse(content) - chunks = [] - - # Capture classes and top-level functions - for node in tree.root_node.children: - if node.type in ("function_definition", "class_definition"): - chunks.append(content[node.start_byte:node.end_byte]) - - # Fallback if no structural chunks found or if file is small - if not chunks or len(content) < self.config.chunk_size: - return self._chunk_text(content) - return chunks - except Exception: - return self._chunk_text(content) + def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None): + """ + [C: tests/test_rag_engine.py:test_rag_engine_chroma] + """ + if not self.config.enabled or self.collection == "mock": + return + embeddings = self.embedding_provider.embed(texts) + self.collection.upsert( + ids=ids, + embeddings=embeddings, + documents=texts, + metadatas=metadatas + ) - def index_file(self, file_path: str): - """Reads, chunks, and indexes a file into the vector store.""" - if not self.config.enabled or self.collection == "mock": - return - - full_path = os.path.join(self.base_dir, file_path) - if not os.path.exists(full_path): - return - - try: - mtime = os.path.getmtime(full_path) - except Exception: - return + def _chunk_text(self, content: str) -> List[str]: + """Character-based chunking with overlap.""" + chunks = [] + if not content: + return chunks + chunk_size = self.config.chunk_size + overlap = self.config.chunk_overlap + start = 0 + while start < len(content): + end = start + chunk_size + chunks.append(content[start:end]) + if end >= len(content): + break + start += (chunk_size - overlap) + return chunks - # Incremental check: see if we already have this file with the same mtime - try: - res = self.collection.get(where={"path": file_path}, limit=1, include=["metadatas"]) - if res and res["metadatas"] and res["metadatas"][0]: - if res["metadatas"][0].get("mtime") == mtime: - return - except Exception: - pass + def _chunk_code(self, content: str, file_path: str) -> List[str]: + """AST-aware chunking for Python code.""" + try: + from src.file_cache import ASTParser + parser = ASTParser("python") + tree = parser.parse(content) + chunks = [] - try: - with open(full_path, "r", encoding="utf-8", errors="ignore") as f: - content = f.read() - except Exception: - return + for node in tree.root_node.children: + if node.type in ("function_definition", "class_definition"): + chunks.append(content[node.start_byte:node.end_byte]) - # Remove old entries for this file - self.collection.delete(where={"path": file_path}) - - if file_path.lower().endswith(".py"): - chunks = self._chunk_code(content, file_path) - else: - chunks = self._chunk_text(content) - - if not chunks: - return + if not chunks or len(content) < self.config.chunk_size: + return self._chunk_text(content) + return chunks + except Exception: + return self._chunk_text(content) - ids = [f"{file_path}_{i}" for i in range(len(chunks))] - metadatas = [{"path": file_path, "chunk": i, "mtime": mtime} for i in range(len(chunks))] - self.add_documents(ids, chunks, metadatas) + def index_file(self, file_path: str): + """Reads, chunks, and indexes a file into the vector store.""" + if not self.config.enabled or self.collection == "mock": + return - def _search_mcp(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: - async def _async_search_mcp(): - tool_name = self.config.vector_store.mcp_tool or "rag_search" - args = {"query": query, "top_k": top_k} - res_str = await mcp_client.async_dispatch(tool_name, args) - try: - data = json.loads(res_str) - if isinstance(data, list): - return data - elif isinstance(data, dict) and "results" in data: - return data["results"] - return [] - except: - return [] + full_path = os.path.join(self.base_dir, file_path) + if not os.path.exists(full_path): + return - return asyncio.run(_async_search_mcp()) + try: + mtime = os.path.getmtime(full_path) + except Exception: + return - def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: - """ - [C: tests/mock_concurrent_mma.py:main, tests/test_rag_engine.py:test_rag_engine_chroma] - """ - if not self.config.enabled: - return [] - if self.config.vector_store.provider == 'mcp': - return self._search_mcp(query, top_k) - if self.collection == "mock": - return [] - - query_embedding = self.embedding_provider.embed([query])[0] - results = self.collection.query( - query_embeddings=[query_embedding], - n_results=top_k - ) - - ret = [] - if results and results["ids"] and results["ids"][0]: - for i in range(len(results["ids"][0])): - ret.append({ - "id": results["ids"][0][i], - "document": results["documents"][0][i], - "metadata": results["metadatas"][0][i] if results["metadatas"] else {}, - "distance": results["distances"][0][i] if "distances" in results and results["distances"] else 0.0 - }) - return ret + try: + res = self.collection.get(where={"path": file_path}, limit=1, include=["metadatas"]) + if res and res["metadatas"] and res["metadatas"][0]: + if res["metadatas"][0].get("mtime") == mtime: + return + except Exception: + pass - def delete_documents(self, ids: List[str]): - """ - [C: tests/test_rag_engine.py:test_rag_engine_chroma] - """ - if not self.config.enabled or self.collection == "mock": - return - self.collection.delete(ids=ids) + try: + with open(full_path, "r", encoding="utf-8", errors="ignore") as f: + content = f.read() + except Exception: + return - def get_all_indexed_paths(self) -> List[str]: - if not self.config.enabled or self.collection == "mock": - return [] - res = self.collection.get(include=["metadatas"]) - if not res or not res["metadatas"]: - return [] - return list(set(m.get("path") for m in res["metadatas"] if m.get("path"))) + self.collection.delete(where={"path": file_path}) - def delete_documents_by_path(self, file_paths: List[str]): - if not self.config.enabled or self.collection == "mock": - return - for path in file_paths: - self.collection.delete(where={"path": path}) \ No newline at end of file + if file_path.lower().endswith(".py"): + chunks = self._chunk_code(content, file_path) + else: + chunks = self._chunk_text(content) + + if not chunks: + return + + ids = [f"{file_path}_{i}" for i in range(len(chunks))] + metadatas = [{"path": file_path, "chunk": i, "mtime": mtime} for i in range(len(chunks))] + self.add_documents(ids, chunks, metadatas) + + def _search_mcp(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + async def _async_search_mcp(): + tool_name = self.config.vector_store.mcp_tool or "rag_search" + args = {"query": query, "top_k": top_k} + res_str = await mcp_client.async_dispatch(tool_name, args) + try: + data = json.loads(res_str) + if isinstance(data, list): + return data + elif isinstance(data, dict) and "results" in data: + return data["results"] + return [] + except: + return [] + + return asyncio.run(_async_search_mcp()) + + def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + """ + [C: tests/mock_concurrent_mma.py:main, tests/test_rag_engine.py:test_rag_engine_chroma] + """ + if not self.config.enabled: + return [] + if self.config.vector_store.provider == 'mcp': + return self._search_mcp(query, top_k) + if self.collection == "mock": + return [] + + query_embedding = self.embedding_provider.embed([query])[0] + results = self.collection.query( + query_embeddings=[query_embedding], + n_results=top_k + ) + + ret = [] + if results and results["ids"] and results["ids"][0]: + for i in range(len(results["ids"][0])): + ret.append({ + "id": results["ids"][0][i], + "document": results["documents"][0][i], + "metadata": results["metadatas"][0][i] if results["metadatas"] else {}, + "distance": results["distances"][0][i] if "distances" in results and results["distances"] else 0.0 + }) + return ret + + def delete_documents(self, ids: List[str]): + """ + [C: tests/test_rag_engine.py:test_rag_engine_chroma] + """ + if not self.config.enabled or self.collection == "mock": + return + self.collection.delete(ids=ids) + + def get_all_indexed_paths(self) -> List[str]: + if not self.config.enabled or self.collection == "mock": + return [] + res = self.collection.get(include=["metadatas"]) + if not res or not res["metadatas"]: + return [] + return list(set(m.get("path") for m in res["metadatas"] if m.get("path"))) + + def delete_documents_by_path(self, file_paths: List[str]): + if not self.config.enabled or self.collection == "mock": + return + for path in file_paths: + self.collection.delete(where={"path": path}) \ No newline at end of file