324 lines
10 KiB
Python
324 lines
10 KiB
Python
import asyncio
|
|
import copy
|
|
import json
|
|
import os
|
|
import sys
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
from src import ai_client
|
|
from src import models
|
|
from src import mcp_client
|
|
|
|
from src.file_cache import ASTParser
|
|
|
|
|
|
# Caches for the lazily imported optional dependencies. Each one stays None
# until the matching _get_*() accessor below performs the import.
_SENTENCE_TRANSFORMERS = None  # the SentenceTransformer class
_GOOGLE_GENAI = None  # (genai, types) tuple
_CHROMADB = None  # (chromadb, Settings) tuple

# Message carried by the ImportError raised when sentence-transformers is missing.
LOCAL_RAG_INSTALL_HINT = (
    "Local RAG embeddings require sentence-transformers. "
    "Install with manual_slop[local-rag] to use local embeddings."
)
|
|
|
|
|
|
def _get_sentence_transformers():
    """Return the ``SentenceTransformer`` class, importing it on first use.

    The class is cached in the module-level ``_SENTENCE_TRANSFORMERS`` so the
    import is attempted at most once per successful load.

    Raises:
        ImportError: with ``LOCAL_RAG_INSTALL_HINT`` when the
            ``sentence_transformers`` package itself is not installed.
        Exception: any other import failure is re-raised. A
            ``ModuleNotFoundError`` for a different module propagates
            unchanged; every other exception type is first reported on stderr.
    """
    global _SENTENCE_TRANSFORMERS
    if _SENTENCE_TRANSFORMERS is not None:
        return _SENTENCE_TRANSFORMERS

    try:
        from sentence_transformers import SentenceTransformer
    except ModuleNotFoundError as missing:
        # Only a missing sentence_transformers gets the install hint; a missing
        # transitive dependency is surfaced as-is.
        if missing.name != "sentence_transformers":
            raise
        raise ImportError(LOCAL_RAG_INSTALL_HINT) from missing
    except Exception as e:
        sys.stderr.write(f"FAILED to import sentence_transformers: {e}\n")
        sys.stderr.flush()
        raise e

    _SENTENCE_TRANSFORMERS = SentenceTransformer
    return _SENTENCE_TRANSFORMERS
|
|
|
|
def _get_google_genai():
    """Return the cached ``(genai, types)`` pair from google-genai.

    The two modules are imported on the first call and stored in the
    module-level ``_GOOGLE_GENAI``; later calls return the cached tuple. An
    import failure propagates to the caller.
    """
    global _GOOGLE_GENAI
    if _GOOGLE_GENAI is not None:
        return _GOOGLE_GENAI

    from google import genai
    from google.genai import types

    _GOOGLE_GENAI = (genai, types)
    return _GOOGLE_GENAI
|
|
|
|
def _get_chromadb():
    """Return the cached ``(chromadb, Settings)`` pair.

    ``chromadb`` and ``chromadb.config.Settings`` are imported on the first
    call and stored in the module-level ``_CHROMADB``; later calls return the
    cached tuple. An import failure propagates to the caller.
    """
    global _CHROMADB
    if _CHROMADB is not None:
        return _CHROMADB

    import chromadb
    from chromadb.config import Settings

    _CHROMADB = (chromadb, Settings)
    return _CHROMADB
|
|
|
|
class BaseEmbeddingProvider:
    """Interface for embedding backends: one vector per input text."""

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts``; concrete providers must override this method.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
|
|
|
class LocalEmbeddingProvider(BaseEmbeddingProvider):
    """Embedding provider backed by a sentence-transformers model."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Instantiate the model named ``model_name``.

        Raises:
            ImportError: propagated from ``_get_sentence_transformers`` when
                the sentence-transformers package is unavailable.
        """
        model_cls = _get_sentence_transformers()
        self.model = model_cls(model_name)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Encode ``texts`` with the model and return the result of ``.tolist()``."""
        return self.model.encode(texts).tolist()
|
|
|
|
class GeminiEmbeddingProvider(BaseEmbeddingProvider):
    """Embedding provider that calls Gemini through the client held by ``ai_client``."""

    def __init__(self, model_name: str = 'gemini-embedding-001'):
        # Only the model name is stored; the client is looked up on every embed() call.
        self.model_name = model_name

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with ``self.model_name`` and return one vector per text.

        Every request uses the ``RETRIEVAL_DOCUMENT`` task type.

        Raises:
            ImportError: if the google-genai loader yields ``None``.
            ValueError: if ``ai_client`` has no Gemini client after
                ``_ensure_gemini_client()`` ran.
        """
        loaded = _get_google_genai()
        if loaded is None:
            raise ImportError("google-genai is not installed")
        # Only the `types` module is needed here; the genai package itself is unused.
        _, genai_types = loaded

        ai_client._ensure_gemini_client()
        gemini = ai_client._gemini_client
        if not gemini:
            raise ValueError("Gemini client not initialized")

        response = gemini.models.embed_content(
            model=self.model_name,
            contents=texts,
            config=genai_types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT"),
        )
        return [item.values for item in response.embeddings]
|
|
|
|
class RAGEngine:
    """Chunk files, embed the chunks, and search them in a vector store.

    Behaviour depends on ``config.vector_store.provider``:

    * ``'chroma'`` - a persistent Chroma collection stored under
      ``<base_dir>/.slop_cache/chroma_<collection_name>``.
    * ``'mock'``   - no store; every data operation is a no-op.
    * ``'mcp'``    - no local store; ``search`` is delegated to an MCP tool.

    When ``config.enabled`` is false nothing is initialised and every method
    returns its empty result.
    """

    # Top-level tree-sitter node types that each become one code chunk.
    # "decorated_definition" wraps decorated functions/classes, which would
    # otherwise never be emitted as chunks.
    _CODE_CHUNK_NODE_TYPES = ("function_definition", "class_definition", "decorated_definition")

    def __init__(self, config: "models.RAGConfig", base_dir: str = "."):
        """Copy ``config`` and, when enabled, set up the provider and the store.

        Args:
            config: RAG configuration; deep-copied so later changes made by the
                caller do not affect this engine.
            base_dir: Directory that relative file paths and the Chroma cache
                directory are resolved against.

        Raises:
            ValueError: for an unknown embedding or vector store provider.
            ImportError: when a required optional dependency is missing.
        """
        self.config = copy.deepcopy(config)
        self.base_dir = base_dir
        self.client = None
        self.collection = None
        self.embedding_provider = None

        if not self.config.enabled:
            return

        # The 'mcp' store never embeds locally (search() hands the raw query to
        # the MCP tool), so no embedding model is loaded for it.
        if self.config.vector_store.provider != 'mcp':
            self._init_embedding_provider()
        self._init_vector_store()

    def _has_collection(self) -> bool:
        """Return True when the engine is enabled and holds a real (non-mock) collection."""
        if not self.config.enabled:
            return False
        return self.collection is not None and self.collection != "mock"

    def _init_embedding_provider(self):
        """Instantiate the provider named by ``config.embedding_provider``.

        Raises:
            ValueError: if the name is neither ``'gemini'`` nor ``'local'``.
        """
        provider_name = self.config.embedding_provider
        if provider_name == 'gemini':
            self.embedding_provider = GeminiEmbeddingProvider()
        elif provider_name == 'local':
            self.embedding_provider = LocalEmbeddingProvider()
        else:
            raise ValueError(f"Unknown embedding provider: {self.config.embedding_provider}")

    def _init_vector_store(self):
        """Open or create the vector store selected by ``config.vector_store``.

        Raises:
            ImportError: if the chromadb loader yields ``None``.
            ValueError: if the provider is not ``'chroma'``, ``'mock'`` or ``'mcp'``.
        """
        vs_config = self.config.vector_store
        if vs_config.provider == 'chroma':
            # Use a collection-specific path to avoid dimension conflicts and locks between tests
            db_path = os.path.abspath(
                os.path.join(self.base_dir, ".slop_cache", f"chroma_{vs_config.collection_name}")
            )
            os.makedirs(db_path, exist_ok=True)
            chroma_module = _get_chromadb()
            if chroma_module is None:
                raise ImportError("chromadb is not installed")
            chromadb, _settings = chroma_module
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
            self._validate_collection_dim()
        elif vs_config.provider == 'mock':
            self.client = "mock"
            self.collection = "mock"
        elif vs_config.provider == 'mcp':
            # search() routes 'mcp' to _search_mcp(); there is no local store to
            # open, so the engine is simply left without a client/collection.
            self.client = None
            self.collection = None
        else:
            raise ValueError(f"Unknown vector store provider: {vs_config.provider}")

    def _validate_collection_dim(self) -> None:
        """
        Detect dimension mismatch between an existing collection's vectors and
        the current embedding provider's output. When mismatched (e.g. the user
        switched from Gemini 3072-dim to local 384-dim, or vice versa), the
        collection is deleted and recreated empty so the next index pass
        populates it with the correct dim. Prevents silent corruption that
        would later surface as a search error ("Collection expecting
        embedding with dimension of X, got Y") and hang live_gui tests.

        Any failure during the check is reported on stderr and otherwise
        ignored, leaving the collection untouched.
        [C: tests/test_rag_engine.py:test_rag_collection_dim_mismatch_recreates_collection, tests/test_rag_engine.py:test_rag_collection_dim_match_preserves_collection]
        """
        if self.collection is None or self.collection == "mock" or self.embedding_provider is None:
            return
        try:
            res = self.collection.get(limit=1, include=["embeddings"])
            embeddings = res.get("embeddings") if isinstance(res, dict) else None
            # Explicit None/len checks: the store may hand embeddings back as a
            # numpy array, and `not array` raises ValueError for more than one
            # element - which the broad except below would swallow, silently
            # skipping the whole validation.
            if embeddings is None or len(embeddings) == 0:
                return
            existing_dim = len(embeddings[0])
            expected_dim = len(self.embedding_provider.embed(["__rag_dim_check__"])[0])
            if existing_dim == expected_dim:
                return
            # Capture the name before deleting so the deleted object is not used afterwards.
            name = self.collection.name
            sys.stderr.write(
                f"RAG: Collection '{name}' dim mismatch "
                f"(existing={existing_dim}, expected={expected_dim}). "
                f"Recreating collection to prevent silent corruption.\n"
            )
            sys.stderr.flush()
            self.client.delete_collection(name)
            self.collection = self.client.get_or_create_collection(name=name)
        except Exception as e:
            sys.stderr.write(f"RAG: Failed to validate collection dim: {e}\n")
            sys.stderr.flush()

    def is_empty(self) -> bool:
        """Return True when the engine is disabled, mocked, has no collection, or the collection holds no entries."""
        if not self.config.enabled:
            return True
        if self.config.vector_store.provider == 'mock' or self.collection == "mock":
            return True
        if self.collection is None:
            return True
        return self.collection.count() == 0

    def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None):
        """
        Embed ``texts`` and upsert them into the collection under ``ids``.

        Does nothing when the engine has no real collection or ``ids`` is empty.
        [C: tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self._has_collection() or not ids:
            return
        embeddings = self.embedding_provider.embed(texts)
        self.collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
        )

    def _chunk_text(self, content: str) -> List[str]:
        """Character-based chunking with overlap.

        Uses ``config.chunk_size`` and ``config.chunk_overlap``. The overlap is
        clamped to ``[0, chunk_size - 1]`` so the window always advances by at
        least one character; without the clamp an overlap >= chunk_size would
        never terminate. A non-positive ``chunk_size`` disables chunking and
        returns the content as a single chunk.
        """
        if not content:
            return []
        chunk_size = self.config.chunk_size
        if chunk_size <= 0:
            return [content]
        overlap = min(max(self.config.chunk_overlap, 0), chunk_size - 1)
        step = chunk_size - overlap

        chunks = []
        for start in range(0, len(content), step):
            end = start + chunk_size
            chunks.append(content[start:end])
            # Stop once a window reaches the end, so no trailing pure-overlap chunk is emitted.
            if end >= len(content):
                break
        return chunks

    def _chunk_code(self, content: str, file_path: str) -> List[str]:
        """AST-aware chunking for Python code.

        Each top-level function, class or decorated definition becomes one
        chunk. Falls back to ``_chunk_text`` when parsing fails, when no such
        definition exists, or when the content is shorter than
        ``config.chunk_size``. ``file_path`` is currently unused.
        """
        try:
            tree = ASTParser("python").parse(content)
            # Node offsets are byte offsets, so slice the encoded source rather
            # than the str; the two only coincide for pure-ASCII files.
            # NOTE(review): assumes ASTParser feeds the parser UTF-8 - confirm.
            source_bytes = content.encode("utf-8")
            chunks = [
                source_bytes[node.start_byte:node.end_byte].decode("utf-8", errors="replace")
                for node in tree.root_node.children
                if node.type in self._CODE_CHUNK_NODE_TYPES
            ]
        except Exception:
            # Best effort: any parser problem degrades to plain text chunking.
            return self._chunk_text(content)

        if not chunks or len(content) < self.config.chunk_size:
            return self._chunk_text(content)
        return chunks

    def _is_indexed_at(self, file_path: str, mtime: float) -> bool:
        """Return True when the store already holds ``file_path`` with this exact ``mtime``."""
        try:
            res = self.collection.get(where={"path": file_path}, limit=1, include=["metadatas"])
            if res and res["metadatas"] and res["metadatas"][0]:
                return res["metadatas"][0].get("mtime") == mtime
        except Exception:
            # Best effort: a failed lookup only costs a redundant re-index.
            pass
        return False

    def index_file(self, file_path: str):
        """Reads, chunks, and indexes a file into the vector store.

        ``file_path`` is resolved against ``base_dir``. The file is skipped
        when it cannot be stat'ed or read, or when the store already holds it
        with the same modification time. Otherwise the entries stored for the
        path are deleted and replaced by the fresh chunks.
        """
        if not self._has_collection():
            return

        full_path = os.path.join(self.base_dir, file_path)
        try:
            # A missing path raises OSError here, so no separate exists() check is needed.
            mtime = os.path.getmtime(full_path)
        except (OSError, ValueError):
            return

        if self._is_indexed_at(file_path, mtime):
            return

        try:
            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                content = f.read()
        except (OSError, ValueError):
            return

        # Drop stale chunks first: the new version may produce fewer chunks than the old one.
        self.collection.delete(where={"path": file_path})

        if file_path.lower().endswith(".py"):
            chunks = self._chunk_code(content, file_path)
        else:
            chunks = self._chunk_text(content)
        if not chunks:
            return

        ids = [f"{file_path}_{i}" for i in range(len(chunks))]
        metadatas = [{"path": file_path, "chunk": i, "mtime": mtime} for i in range(len(chunks))]
        self.add_documents(ids, chunks, metadatas)

    def _search_mcp(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Run the search through the configured MCP tool (default ``"rag_search"``).

        The tool response is parsed as JSON; a list, or a dict whose
        ``"results"`` entry is a list, is returned. Anything else, including
        unparseable output, yields ``[]``.

        Uses ``asyncio.run``, so it must be called from synchronous code;
        ``asyncio.run`` raises ``RuntimeError`` inside a running event loop.
        """
        async def _async_search_mcp():
            tool_name = self.config.vector_store.mcp_tool or "rag_search"
            args = {"query": query, "top_k": top_k}
            res_str = await mcp_client.async_dispatch(tool_name, args)
            try:
                data = json.loads(res_str)
            except (ValueError, TypeError):
                # Invalid JSON (ValueError) or a non-string payload (TypeError).
                return []
            if isinstance(data, list):
                return data
            if isinstance(data, dict) and isinstance(data.get("results"), list):
                return data["results"]
            return []

        return asyncio.run(_async_search_mcp())

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Return up to ``top_k`` hits for ``query``.

        Each hit is a dict with the keys ``id``, ``document``, ``metadata``
        (``{}`` when the store has none) and ``distance`` (``0.0`` when the
        store reports none). Returns ``[]`` when the engine is disabled, mocked
        or has no collection / embedding provider.
        [C: tests/mock_concurrent_mma.py:main, tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self.config.enabled:
            return []
        if self.config.vector_store.provider == 'mcp':
            return self._search_mcp(query, top_k)
        if self.collection is None or self.collection == "mock" or self.embedding_provider is None:
            return []

        query_embedding = self.embedding_provider.embed([query])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
        )
        if not results or not results["ids"] or not results["ids"][0]:
            return []

        def first_row(key):
            # Results are nested one list per query; this engine sends exactly one query.
            rows = results.get(key)
            if rows is None or len(rows) == 0:
                return None
            return rows[0]

        documents = first_row("documents")
        metadatas = first_row("metadatas")
        distances = first_row("distances")

        hits = []
        for i, doc_id in enumerate(results["ids"][0]):
            metadata = metadatas[i] if metadatas is not None else None
            hits.append({
                "id": doc_id,
                "document": documents[i] if documents is not None else None,
                "metadata": metadata if metadata else {},
                "distance": distances[i] if distances is not None else 0.0,
            })
        return hits

    def delete_documents(self, ids: List[str]):
        """
        Delete the entries with the given ``ids``; a no-op for an empty list.
        [C: tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self._has_collection() or not ids:
            return
        self.collection.delete(ids=ids)

    def get_all_indexed_paths(self) -> List[str]:
        """Return the distinct non-empty ``path`` metadata values in the collection, in no particular order."""
        if not self._has_collection():
            return []
        res = self.collection.get(include=["metadatas"])
        if not res or not res["metadatas"]:
            return []
        # Skip entries stored without metadata; the store may report them as None.
        return list({m.get("path") for m in res["metadatas"] if m and m.get("path")})

    def delete_documents_by_path(self, file_paths: List[str]):
        """Delete every entry whose ``path`` metadata equals one of ``file_paths``."""
        if not self._has_collection():
            return
        for path in file_paths:
            self.collection.delete(where={"path": path})
|