644d88ab93
The wipe path called self._init_vector_store() which re-invoked _validate_collection_dim, causing infinite recursion (RecursionError) when the dim mismatch test ran with the mock embedding provider. Re-initialize the vector store INLINE after the rmtree wipe so the fresh collection is created without going through the validator again.
385 lines
13 KiB
Python
385 lines
13 KiB
Python
import asyncio
|
|
import copy
|
|
import json
|
|
import os
|
|
import sys
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
from src import ai_client
|
|
from src import models
|
|
from src import mcp_client
|
|
|
|
from src.file_cache import ASTParser
|
|
|
|
|
|
# Lazily-populated caches for optional heavy dependencies. Each one stays
# ``None`` until the matching ``_get_*`` loader in this module stores its
# import result here, so importing this module does not import them.
_SENTENCE_TRANSFORMERS = None
_GOOGLE_GENAI = None
_CHROMADB = None

# Message used for the ImportError raised when sentence-transformers is missing.
LOCAL_RAG_INSTALL_HINT = (
    "Local RAG embeddings require sentence-transformers. "
    "Install with manual_slop[local-rag] to use local embeddings."
)
|
|
|
|
|
|
def _get_sentence_transformers():
    """
    Return the ``SentenceTransformer`` class, importing it on first use.

    The class is cached in the module-level ``_SENTENCE_TRANSFORMERS`` so the
    import only happens once.

    Raises:
        ImportError: with ``LOCAL_RAG_INSTALL_HINT`` when the
            ``sentence_transformers`` package itself is missing. A missing
            transitive module is re-raised unchanged.
        Exception: any other import-time failure, after it is reported on stderr.
    """
    global _SENTENCE_TRANSFORMERS
    if _SENTENCE_TRANSFORMERS is not None:
        return _SENTENCE_TRANSFORMERS

    try:
        from sentence_transformers import SentenceTransformer
    except ModuleNotFoundError as missing:
        # Only translate the error when the top-level package is the one missing.
        if missing.name == "sentence_transformers":
            raise ImportError(LOCAL_RAG_INSTALL_HINT) from missing
        raise
    except Exception as err:
        sys.stderr.write(f"FAILED to import sentence_transformers: {err}\n")
        sys.stderr.flush()
        raise err

    _SENTENCE_TRANSFORMERS = SentenceTransformer
    return _SENTENCE_TRANSFORMERS
|
|
|
|
def _get_google_genai():
    """
    Return the cached ``(genai, types)`` pair, importing google-genai on first use.

    The tuple is stored in the module-level ``_GOOGLE_GENAI``; an import
    failure propagates to the caller.
    """
    global _GOOGLE_GENAI
    if _GOOGLE_GENAI is not None:
        return _GOOGLE_GENAI

    from google import genai
    from google.genai import types

    _GOOGLE_GENAI = (genai, types)
    return _GOOGLE_GENAI
|
|
|
|
def _get_chromadb():
    """
    Return the cached ``(chromadb, Settings)`` pair, importing chromadb on first use.

    The tuple is stored in the module-level ``_CHROMADB``; an import failure
    propagates to the caller.
    """
    global _CHROMADB
    if _CHROMADB is not None:
        return _CHROMADB

    import chromadb
    from chromadb.config import Settings

    _CHROMADB = (chromadb, Settings)
    return _CHROMADB
|
|
|
|
class BaseEmbeddingProvider:
    """Common interface for embedding backends; subclasses override ``embed``."""

    def embed(self, texts: List[str]) -> List[List[float]]:
        """
        Turn each input text into one embedding vector.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
|
|
|
|
class LocalEmbeddingProvider(BaseEmbeddingProvider):
    """Embedding provider backed by a sentence-transformers model."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """
        Load the named model.

        Raises whatever ``_get_sentence_transformers`` raises when the
        package cannot be imported.
        """
        model_cls = _get_sentence_transformers()
        self.model = model_cls(model_name)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Encode ``texts`` with the loaded model and return the result via ``.tolist()``."""
        return self.model.encode(texts).tolist()
|
|
|
|
class GeminiEmbeddingProvider(BaseEmbeddingProvider):
    """Embedding provider that calls the Gemini client owned by ``ai_client``."""

    def __init__(self, model_name: str = 'gemini-embedding-001'):
        """Remember which embedding model to request; no client is created here."""
        self.model_name = model_name

    def embed(self, texts: List[str]) -> List[List[float]]:
        """
        Embed ``texts`` with the configured Gemini model.

        Returns:
            One ``values`` list per entry in the response's ``embeddings``.

        Raises:
            ImportError: if the google-genai loader yields ``None``.
            ValueError: if ``ai_client`` has no Gemini client after
                ``_ensure_gemini_client()``.
        """
        loaded = _get_google_genai()
        if loaded is None:
            raise ImportError("google-genai is not installed")
        _genai, genai_types = loaded

        ai_client._ensure_gemini_client()
        gemini = ai_client._gemini_client
        if not gemini:
            raise ValueError("Gemini client not initialized")

        # Every request uses the RETRIEVAL_DOCUMENT task type.
        response = gemini.models.embed_content(
            model=self.model_name,
            contents=texts,
            config=genai_types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT"),
        )
        return [item.values for item in response.embeddings]
|
|
|
|
class RAGEngine:
    """
    Index text/code chunks into a vector store and query them by similarity.

    The engine deep-copies the config it is given. When ``config.enabled`` is
    false it stays inert (no embedding provider, no store) and every public
    method returns immediately with an empty result.
    """

    def __init__(self, config: "models.RAGConfig", base_dir: str = "."):
        """
        Args:
            config: RAG settings; copied so later caller mutations do not leak in.
            base_dir: root used both for the on-disk chroma cache
                (``<base_dir>/.slop_cache/chroma_<collection>``) and for
                resolving relative paths in ``index_file``.
        """
        self.config = copy.deepcopy(config)
        self.base_dir = base_dir
        self.client = None
        self.collection = None
        self.embedding_provider = None

        if not self.config.enabled:
            return

        self._init_embedding_provider()
        self._init_vector_store()

    def _init_embedding_provider(self):
        """
        Instantiate the provider named by ``config.embedding_provider``.

        Raises:
            ValueError: for any name other than ``'gemini'`` or ``'local'``.
        """
        if self.config.embedding_provider == 'gemini':
            self.embedding_provider = GeminiEmbeddingProvider()
        elif self.config.embedding_provider == 'local':
            self.embedding_provider = LocalEmbeddingProvider()
        else:
            raise ValueError(f"Unknown embedding provider: {self.config.embedding_provider}")

    def _chroma_db_path(self) -> str:
        """Return the absolute, collection-specific chroma directory under ``base_dir``."""
        # Use a collection-specific path to avoid dimension conflicts and locks between tests
        return os.path.abspath(
            os.path.join(
                self.base_dir,
                ".slop_cache",
                f"chroma_{self.config.vector_store.collection_name}",
            )
        )

    def _open_chroma_collection(self, db_path: str) -> None:
        """
        Create ``db_path`` if needed and bind ``self.client`` / ``self.collection`` to it.

        This performs no dimension validation, so the wipe path can call it
        without recursing back into ``_validate_collection_dim``.

        Raises:
            ImportError: if the chromadb loader yields ``None``.
        """
        os.makedirs(db_path, exist_ok=True)
        chroma_module = _get_chromadb()
        if chroma_module is None:
            raise ImportError("chromadb is not installed")
        chromadb, _settings = chroma_module
        self.client = chromadb.PersistentClient(path=db_path)
        self.collection = self.client.get_or_create_collection(
            name=self.config.vector_store.collection_name
        )

    def _init_vector_store(self):
        """
        Open the store selected by ``config.vector_store.provider``.

        ``'chroma'`` opens a persistent collection and validates its vector
        dimension. ``'mock'`` installs the string sentinel ``"mock"`` as both
        client and collection.

        Raises:
            ValueError: for any other provider name.

        NOTE(review): ``search`` has an ``'mcp'`` branch, but ``'mcp'`` is
        rejected here, so that branch is only reachable when the provider is
        changed after construction. Confirm this is intended.
        """
        vs_config = self.config.vector_store
        if vs_config.provider == 'chroma':
            self._open_chroma_collection(self._chroma_db_path())
            self._validate_collection_dim()
        elif vs_config.provider == 'mock':
            self.client = "mock"
            self.collection = "mock"
        else:
            raise ValueError(f"Unknown vector store provider: {vs_config.provider}")

    def _validate_collection_dim(self) -> None:
        """
        Detect dimension mismatch between an existing collection's vectors and
        the current embedding provider's output. When mismatched (e.g. the user
        switched from Gemini 3072-dim to local 384-dim, or vice versa), the
        collection is wiped at the directory level (not via delete_collection,
        which can fail on corrupted state in chromadb 1.5.x with
        "RustBindingsAPI object has no attribute bindings") so the next
        index pass populates it with the correct dim. Prevents silent
        corruption that would later surface as a search error
        ("Collection expecting embedding with dimension of X, got Y") and
        hang live_gui tests.

        Any failure while reading the collection or computing the expected
        dimension is reported on stderr and the check is skipped.
        [C: tests/test_rag_engine.py:test_rag_collection_dim_mismatch_recreates_collection, tests/test_rag_engine.py:test_rag_collection_dim_match_preserves_collection]
        """
        if self.collection is None or self.collection == "mock" or self.embedding_provider is None:
            return
        try:
            res = self.collection.get(limit=1, include=["embeddings"])
        except Exception as e:
            sys.stderr.write(f"RAG: Failed to read collection for dim check: {e}\n")
            sys.stderr.flush()
            return
        if not res:
            return
        embeddings = res.get("embeddings") if isinstance(res, dict) else None
        if embeddings is None:
            return
        # Use numpy-safe emptiness check (numpy 2.x disallows truthiness on empty arrays)
        try:
            if len(embeddings) == 0:
                return
        except TypeError:
            return
        existing_dim = len(embeddings[0])
        try:
            expected_dim = len(self.embedding_provider.embed(["__rag_dim_check__"])[0])
        except Exception as e:
            sys.stderr.write(f"RAG: Failed to compute expected dim: {e}\n")
            sys.stderr.flush()
            return
        if existing_dim == expected_dim:
            return
        sys.stderr.write(
            f"RAG: Collection '{self.collection.name}' dim mismatch "
            f"(existing={existing_dim}, expected={expected_dim}). "
            f"Wiping chroma dir to prevent silent corruption.\n"
        )
        sys.stderr.flush()
        self._wipe_and_reopen_store()

    def _wipe_and_reopen_store(self) -> None:
        """
        Delete the on-disk chroma directory and bind a fresh, empty collection.

        The whole directory is removed instead of calling delete_collection,
        which fails on corrupted state in chromadb 1.5.x with
        "RustBindingsAPI object has no attribute bindings". The store is
        re-opened WITHOUT going through ``_init_vector_store`` because that
        method calls ``_validate_collection_dim`` and would recurse.
        """
        import shutil

        # Close the chroma client first to release file handles. Without
        # this, rmtree fails with WinError 32 on Windows.
        try:
            if hasattr(self, 'client') and self.client and self.client != "mock":
                self.client.close()
        except Exception:
            # Best-effort: a client that cannot be closed must not block the wipe.
            pass
        self.client = None
        self.collection = None

        # Always computed, so the re-open below can never hit an unbound path.
        db_path = self._chroma_db_path()
        if os.path.isdir(db_path):
            try:
                shutil.rmtree(db_path)
            except Exception as e:
                sys.stderr.write(f"RAG: Failed to wipe chroma dir: {e}\n")
                sys.stderr.flush()

        vs_config = self.config.vector_store
        if vs_config.provider == 'chroma':
            self._open_chroma_collection(db_path)
        elif vs_config.provider == 'mock':
            self.client = "mock"
            self.collection = "mock"

    def is_empty(self) -> bool:
        """Return True when disabled, mocked, unopened, or the collection holds no items."""
        if not self.config.enabled:
            return True
        if self.config.vector_store.provider == 'mock' or self.collection == "mock":
            return True
        if self.collection is None:
            return True
        return self.collection.count() == 0

    def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None):
        """
        Embed ``texts`` and upsert them under ``ids`` with optional ``metadatas``.

        No-op when the engine is disabled or mocked.
        [C: tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self.config.enabled or self.collection == "mock":
            return
        embeddings = self.embedding_provider.embed(texts)
        self.collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas,
        )

    def _chunk_text(self, content: str) -> List[str]:
        """
        Character-based chunking with overlap.

        Uses ``config.chunk_size`` and ``config.chunk_overlap``. Empty input
        gives ``[]``. A non-positive ``chunk_size`` returns the content as one
        chunk. An overlap outside ``[0, chunk_size)`` is treated as 0, because
        either case would otherwise stall the window or skip characters.
        """
        if not content:
            return []
        chunk_size = self.config.chunk_size
        if chunk_size <= 0:
            return [content]
        overlap = self.config.chunk_overlap
        if not 0 <= overlap < chunk_size:
            overlap = 0
        step = chunk_size - overlap  # always >= 1, so the loop terminates

        chunks = []
        start = 0
        total = len(content)
        while start < total:
            end = start + chunk_size
            chunks.append(content[start:end])
            if end >= total:
                break
            start += step
        return chunks

    def _chunk_code(self, content: str, file_path: str) -> List[str]:
        """
        AST-aware chunking for Python code.

        Each top-level function or class definition becomes one chunk. This
        falls back to ``_chunk_text`` when no such node exists, when the file
        is shorter than ``config.chunk_size``, or when parsing raises.
        ``file_path`` is currently unused.
        """
        try:
            parser = ASTParser("python")
            tree = parser.parse(content)
            # start_byte/end_byte are byte offsets, so slice the encoded source
            # rather than the str (they diverge once non-ASCII text appears).
            # NOTE(review): assumes ASTParser feeds the parser UTF-8 bytes - confirm.
            source_bytes = content.encode("utf-8")
            chunks = [
                source_bytes[node.start_byte:node.end_byte].decode("utf-8", errors="replace")
                for node in tree.root_node.children
                if node.type in ("function_definition", "class_definition")
            ]
            if not chunks or len(content) < self.config.chunk_size:
                return self._chunk_text(content)
            return chunks
        except Exception:
            return self._chunk_text(content)

    def index_file(self, file_path: str):
        """
        Reads, chunks, and indexes a file into the vector store.

        The file is skipped when it cannot be found or read, or when the
        stored ``mtime`` metadata equals the file's current mtime. Otherwise
        all existing entries for ``file_path`` are deleted and replaced.
        """
        if not self.config.enabled or self.collection == "mock":
            return

        full_path = os.path.join(self.base_dir, file_path)
        if not os.path.exists(full_path):
            # CWD fallback: handle the case where base_dir was resolved to a
            # parent directory (e.g. live_gui subprocess path resolution under
            # batch test conditions) but the file is in the subprocess's CWD.
            # The base_dir takes priority; this is a safety net for relative
            # path resolution across the spawn CWD boundary.
            cwd_path = os.path.join(os.getcwd(), file_path)
            if not os.path.exists(cwd_path):
                return
            full_path = cwd_path

        try:
            mtime = os.path.getmtime(full_path)
        except OSError:
            return

        # Best-effort freshness check: any store error just forces a re-index.
        try:
            res = self.collection.get(where={"path": file_path}, limit=1, include=["metadatas"])
            if res and res["metadatas"] and res["metadatas"][0]:
                if res["metadatas"][0].get("mtime") == mtime:
                    return
        except Exception:
            pass

        try:
            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                content = f.read()
        except OSError:
            return

        # Drop stale chunks first so a shrunken file leaves no orphans behind.
        self.collection.delete(where={"path": file_path})

        if file_path.lower().endswith(".py"):
            chunks = self._chunk_code(content, file_path)
        else:
            chunks = self._chunk_text(content)

        if not chunks:
            return

        ids = [f"{file_path}_{i}" for i in range(len(chunks))]
        metadatas = [{"path": file_path, "chunk": i, "mtime": mtime} for i in range(len(chunks))]
        self.add_documents(ids, chunks, metadatas)

    def _search_mcp(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Run the search through an MCP tool and return its result list.

        The tool name is ``config.vector_store.mcp_tool``, or ``"rag_search"``
        when that is unset. The reply is parsed as JSON: a list is returned
        as-is, a dict yields its ``"results"`` value, and anything else
        (including unparseable output) yields ``[]``. This drives the
        coroutine with ``asyncio.run``.
        """
        async def _async_search_mcp():
            tool_name = self.config.vector_store.mcp_tool or "rag_search"
            args = {"query": query, "top_k": top_k}
            res_str = await mcp_client.async_dispatch(tool_name, args)
            try:
                data = json.loads(res_str)
            except (TypeError, ValueError):
                # Non-string reply or invalid JSON (JSONDecodeError is a ValueError).
                return []
            if isinstance(data, list):
                return data
            if isinstance(data, dict) and "results" in data:
                return data["results"]
            return []

        return asyncio.run(_async_search_mcp())

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Return up to ``top_k`` matches for ``query``.

        Each match is a dict with ``id``, ``document``, ``metadata`` and
        ``distance``. ``metadata`` defaults to ``{}`` and ``distance`` to
        ``0.0`` when the store omits them. Returns ``[]`` when disabled or
        mocked.
        [C: tests/mock_concurrent_mma.py:main, tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self.config.enabled:
            return []
        if self.config.vector_store.provider == 'mcp':
            return self._search_mcp(query, top_k)
        if self.collection == "mock":
            return []

        query_embedding = self.embedding_provider.embed([query])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
        )

        if not (results and results["ids"] and results["ids"][0]):
            return []

        # Results are nested per query; only one query was sent, so take row 0.
        ids = results["ids"][0]
        documents = results["documents"][0]
        metadatas = results["metadatas"][0] if results["metadatas"] else None
        has_distances = "distances" in results and results["distances"]
        distances = results["distances"][0] if has_distances else None
        return [
            {
                "id": doc_id,
                "document": documents[i],
                "metadata": metadatas[i] if metadatas is not None else {},
                "distance": distances[i] if distances is not None else 0.0,
            }
            for i, doc_id in enumerate(ids)
        ]

    def delete_documents(self, ids: List[str]):
        """
        Delete the entries with the given ids; no-op when disabled or mocked.
        [C: tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self.config.enabled or self.collection == "mock":
            return
        self.collection.delete(ids=ids)

    def get_all_indexed_paths(self) -> List[str]:
        """
        Return the distinct ``path`` metadata values in the collection (unordered).

        Entries whose metadata is missing or has no ``path`` are ignored.
        """
        if not self.config.enabled or self.collection == "mock":
            return []
        res = self.collection.get(include=["metadatas"])
        if not res or not res["metadatas"]:
            return []
        # A stored item may have no metadata at all (None); skip those.
        return list({m["path"] for m in res["metadatas"] if m and m.get("path")})

    def delete_documents_by_path(self, file_paths: List[str]):
        """Delete every entry whose ``path`` metadata matches one of ``file_paths``."""
        if not self.config.enabled or self.collection == "mock":
            return
        for path in file_paths:
            self.collection.delete(where={"path": path})
|