Adjustments to the RAG engine

This commit is contained in:
2026-05-13 06:32:26 -04:00
parent 1a529ed750
commit 8e9725792f
+226 -217
View File
@@ -8,248 +8,257 @@ from chromadb.config import Settings
from src import models from src import models
from src import mcp_client from src import mcp_client
try: _SENTENCE_TRANSFORMERS = None
from sentence_transformers import SentenceTransformer _GOOGLE_GENAI = None
except ImportError:
SentenceTransformer = None
from google import genai def _get_sentence_transformers():
from google.genai import types global _SENTENCE_TRANSFORMERS
from src import ai_client if _SENTENCE_TRANSFORMERS is None:
from sentence_transformers import SentenceTransformer
_SENTENCE_TRANSFORMERS = SentenceTransformer
return _SENTENCE_TRANSFORMERS
def _get_google_genai():
global _GOOGLE_GENAI
if _GOOGLE_GENAI is None:
from google import genai
from google.genai import types
_GOOGLE_GENAI = (genai, types)
return _GOOGLE_GENAI
class BaseEmbeddingProvider:
    """Abstract interface for embedding backends.

    Concrete providers must override :meth:`embed`.
    """

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text. Must be overridden."""
        raise NotImplementedError()
class LocalEmbeddingProvider(BaseEmbeddingProvider):
    """Embedding provider backed by a local sentence-transformers model.

    Raises ImportError at construction time when the optional
    sentence-transformers dependency is missing.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        transformer_cls = _get_sentence_transformers()
        if transformer_cls is None:
            raise ImportError("sentence-transformers is not installed")
        self.model = transformer_cls(model_name)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Encode *texts* locally and return plain Python lists of floats."""
        return self.model.encode(texts).tolist()
class GeminiEmbeddingProvider(BaseEmbeddingProvider):
    """Embedding provider that calls the Gemini ``text-embedding`` API."""

    def __init__(self, model_name: str = 'text-embedding-004'):
        self.model_name = model_name

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* remotely via the shared Gemini client.

        Raises:
            ImportError: when google-genai is not installed.
            ValueError: when the shared Gemini client is not initialized.
        """
        bundle = _get_google_genai()
        if bundle is None:
            raise ImportError("google-genai is not installed")
        _genai, types = bundle

        from src import ai_client
        ai_client._ensure_gemini_client()
        client = ai_client._gemini_client
        if not client:
            raise ValueError("Gemini client not initialized")

        # text-embedding-004 accepts the whole batch in a single call.
        response = client.models.embed_content(
            model=self.model_name,
            contents=texts,
            config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
        )
        return [item.values for item in response.embeddings]
class RAGEngine:
    """Retrieval engine: chunks files, embeds them, and serves similarity search.

    Vector-store providers:
      * ``'chroma'`` — persistent local ChromaDB under ``<base_dir>/.slop_cache/chroma_db``
      * ``'mock'``   — inert stub for tests; ``self.collection`` is the sentinel string ``"mock"``
      * ``'mcp'``    — search is delegated to an external MCP tool (see :meth:`_search_mcp`)
    """

    def __init__(self, config: models.RAGConfig, base_dir: str = "."):
        self.config = config
        self.base_dir = base_dir
        self.client = None
        self.collection = None
        self.embedding_provider = None

        # A disabled engine stays inert: no embedding backend, no vector store.
        if not self.config.enabled:
            return

        self._init_embedding_provider()
        self._init_vector_store()

    def _init_embedding_provider(self):
        """Instantiate the configured embedding backend.

        Raises:
            ValueError: if ``config.embedding_provider`` is not 'gemini' or 'local'.
        """
        if self.config.embedding_provider == 'gemini':
            self.embedding_provider = GeminiEmbeddingProvider()
        elif self.config.embedding_provider == 'local':
            self.embedding_provider = LocalEmbeddingProvider()
        else:
            raise ValueError(f"Unknown embedding provider: {self.config.embedding_provider}")

    def _init_vector_store(self):
        """Create the vector-store client and collection.

        Raises:
            ValueError: if the configured provider is not 'chroma' or 'mock'.
        """
        vs_config = self.config.vector_store
        if vs_config.provider == 'chroma':
            db_path = os.path.join(self.base_dir, ".slop_cache", "chroma_db")
            os.makedirs(db_path, exist_ok=True)
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
        elif vs_config.provider == 'mock':
            # Sentinel strings; every method below short-circuits on collection == "mock".
            self.client = "mock"
            self.collection = "mock"
        else:
            raise ValueError(f"Unknown vector store provider: {vs_config.provider}")

    def is_empty(self) -> bool:
        """Return True when nothing is indexed (disabled, mocked, or empty collection)."""
        if not self.config.enabled:
            return True
        if self.config.vector_store.provider == 'mock' or self.collection == "mock":
            return True
        if self.collection is None:
            return True
        return self.collection.count() == 0

    def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None):
        """Embed *texts* and upsert them into the collection under *ids*.

        [C: tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self.config.enabled or self.collection == "mock":
            return
        embeddings = self.embedding_provider.embed(texts)
        self.collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas
        )

    def _chunk_text(self, content: str) -> List[str]:
        """Character-based chunking with overlap.

        Returns a list of ``chunk_size``-character windows advancing by
        ``chunk_size - chunk_overlap`` characters per step.
        """
        chunks = []
        if not content:
            return chunks
        chunk_size = self.config.chunk_size
        overlap = self.config.chunk_overlap
        # BUGFIX: clamp the step to at least 1 — with chunk_overlap >= chunk_size
        # the original `start += (chunk_size - overlap)` never advanced and hung.
        step = max(1, chunk_size - overlap)
        start = 0
        while start < len(content):
            end = start + chunk_size
            chunks.append(content[start:end])
            if end >= len(content):
                break
            start += step
        return chunks

    def _chunk_code(self, content: str, file_path: str) -> List[str]:
        """AST-aware chunking for Python code; falls back to text chunking."""
        try:
            from src.file_cache import ASTParser
            parser = ASTParser("python")
            tree = parser.parse(content)
            chunks = []

            # Capture classes and top-level functions as whole chunks.
            for node in tree.root_node.children:
                if node.type in ("function_definition", "class_definition"):
                    chunks.append(content[node.start_byte:node.end_byte])

            # Fallback if no structural chunks found or if file is small.
            if not chunks or len(content) < self.config.chunk_size:
                return self._chunk_text(content)
            return chunks
        except Exception:
            # Any parser failure degrades gracefully to plain text chunking.
            return self._chunk_text(content)

    def index_file(self, file_path: str):
        """Read, chunk, and index a file; unchanged files (same mtime) are skipped."""
        if not self.config.enabled or self.collection == "mock":
            return

        full_path = os.path.join(self.base_dir, file_path)
        if not os.path.exists(full_path):
            return

        try:
            mtime = os.path.getmtime(full_path)
        except Exception:
            return

        # Incremental check: skip when the stored mtime matches exactly.
        try:
            res = self.collection.get(where={"path": file_path}, limit=1, include=["metadatas"])
            if res and res["metadatas"] and res["metadatas"][0]:
                if res["metadatas"][0].get("mtime") == mtime:
                    return
        except Exception:
            pass  # best effort: a failed lookup just forces a re-index

        try:
            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                content = f.read()
        except Exception:
            return

        # Remove old entries for this file before re-adding fresh chunks.
        self.collection.delete(where={"path": file_path})

        if file_path.lower().endswith(".py"):
            chunks = self._chunk_code(content, file_path)
        else:
            chunks = self._chunk_text(content)

        if not chunks:
            return

        ids = [f"{file_path}_{i}" for i in range(len(chunks))]
        metadatas = [{"path": file_path, "chunk": i, "mtime": mtime} for i in range(len(chunks))]
        self.add_documents(ids, chunks, metadatas)

    def _search_mcp(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Delegate search to the configured MCP tool; malformed replies yield []."""
        async def _async_search_mcp():
            tool_name = self.config.vector_store.mcp_tool or "rag_search"
            args = {"query": query, "top_k": top_k}
            res_str = await mcp_client.async_dispatch(tool_name, args)
            try:
                data = json.loads(res_str)
            except (ValueError, TypeError):
                # BUGFIX: was a bare `except:` that also swallowed
                # KeyboardInterrupt/SystemExit. JSONDecodeError is a ValueError;
                # TypeError covers non-string payloads.
                return []
            if isinstance(data, list):
                return data
            if isinstance(data, dict) and "results" in data:
                return data["results"]
            return []
        # NOTE(review): asyncio.run fails if called from a running event loop —
        # assumed to only be invoked from synchronous call sites.
        return asyncio.run(_async_search_mcp())

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Similarity search; returns dicts with id/document/metadata/distance.

        [C: tests/mock_concurrent_mma.py:main, tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self.config.enabled:
            return []
        if self.config.vector_store.provider == 'mcp':
            return self._search_mcp(query, top_k)
        if self.collection == "mock":
            return []
        query_embedding = self.embedding_provider.embed([query])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k
        )
        ret = []
        if results and results["ids"] and results["ids"][0]:
            for i in range(len(results["ids"][0])):
                ret.append({
                    "id": results["ids"][0][i],
                    "document": results["documents"][0][i],
                    "metadata": results["metadatas"][0][i] if results["metadatas"] else {},
                    "distance": results["distances"][0][i] if "distances" in results and results["distances"] else 0.0
                })
        return ret

    def delete_documents(self, ids: List[str]):
        """Delete documents by id.

        [C: tests/test_rag_engine.py:test_rag_engine_chroma]
        """
        if not self.config.enabled or self.collection == "mock":
            return
        self.collection.delete(ids=ids)

    def get_all_indexed_paths(self) -> List[str]:
        """Return the distinct file paths currently present in the index."""
        if not self.config.enabled or self.collection == "mock":
            return []
        res = self.collection.get(include=["metadatas"])
        if not res or not res["metadatas"]:
            return []
        return list(set(m.get("path") for m in res["metadatas"] if m.get("path")))

    def delete_documents_by_path(self, file_paths: List[str]):
        """Delete every chunk belonging to each of the given file paths."""
        if not self.config.enabled or self.collection == "mock":
            return
        for path in file_paths:
            self.collection.delete(where={"path": path})