Private
Public Access
0
0
Files
manual_slop/src/rag_engine.py
T
2026-05-13 06:32:26 -04:00

264 lines
9.4 KiB
Python

import os
import sys
import asyncio
import json
from typing import List, Dict, Any, Optional
import chromadb
from chromadb.config import Settings
from src import models
from src import mcp_client
_SENTENCE_TRANSFORMERS = None
_GOOGLE_GENAI = None
def _get_sentence_transformers():
global _SENTENCE_TRANSFORMERS
if _SENTENCE_TRANSFORMERS is None:
from sentence_transformers import SentenceTransformer
_SENTENCE_TRANSFORMERS = SentenceTransformer
return _SENTENCE_TRANSFORMERS
def _get_google_genai():
global _GOOGLE_GENAI
if _GOOGLE_GENAI is None:
from google import genai
from google.genai import types
_GOOGLE_GENAI = (genai, types)
return _GOOGLE_GENAI
class BaseEmbeddingProvider:
    """Common interface for objects that turn text into embedding vectors.

    Concrete providers override :meth:`embed`; calling it on this base class
    only signals that no implementation was supplied.
    """

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Return embedding vectors for ``texts``.

        Callers in this module expect one vector per input text, in input order.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
class LocalEmbeddingProvider(BaseEmbeddingProvider):
    """Embedding provider backed by a locally loaded sentence-transformers model."""

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load ``model_name`` with the lazily imported ``SentenceTransformer`` class.

        Raises:
            ImportError: when sentence-transformers is unavailable.
        """
        transformer_cls = _get_sentence_transformers()
        if transformer_cls is None:
            raise ImportError("sentence-transformers is not installed")
        self.model = transformer_cls(model_name)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Encode ``texts`` with the model and return the vectors as nested Python lists."""
        return self.model.encode(texts).tolist()
class GeminiEmbeddingProvider(BaseEmbeddingProvider):
    """Embedding provider that calls the Gemini ``embed_content`` API.

    The Gemini client is the shared one managed by ``src.ai_client``.
    """

    def __init__(self, model_name: str = 'text-embedding-004', batch_size: int = 100):
        """
        Args:
            model_name: Gemini embedding model to call.
            batch_size: maximum number of texts sent in one ``embed_content``
                request. NOTE(review): 100 is the Gemini API's batch-embed
                request limit as far as known — confirm for other backends/models.
        """
        self.model_name = model_name
        # Guard against 0/negative values, which would make range() fail.
        self.batch_size = max(1, batch_size)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` with task type RETRIEVAL_DOCUMENT, one vector per text, in order.

        Large inputs are split into requests of at most ``batch_size`` texts so
        that indexing a big file does not exceed the API's per-request limit.

        Raises:
            ImportError: when google-genai is unavailable.
            ValueError: when the Gemini client is not initialised, or the API
                returns a different number of embeddings than texts sent.
        """
        if not texts:
            # Nothing to embed; avoid a pointless (and likely rejected) API call.
            return []
        google_module = _get_google_genai()
        if google_module is None:
            raise ImportError("google-genai is not installed")
        _, types = google_module
        # Imported here rather than at module level; presumably to avoid an
        # import cycle with ai_client — TODO confirm.
        from src import ai_client
        ai_client._ensure_gemini_client()
        client = ai_client._gemini_client
        if not client:
            raise ValueError("Gemini client not initialized")
        config = types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT")
        vectors: List[List[float]] = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            res = client.models.embed_content(
                model=self.model_name,
                contents=batch,
                config=config
            )
            embeddings = res.embeddings or []
            if len(embeddings) != len(batch):
                raise ValueError(
                    f"Gemini returned {len(embeddings)} embeddings for {len(batch)} texts"
                )
            vectors.extend(e.values for e in embeddings)
        return vectors
class RAGEngine:
    """Retrieval-augmented-generation index over project files.

    Files are split into chunks, embedded with the configured embedding
    provider and stored in the vector store selected by
    ``config.vector_store.provider``:

    * ``'chroma'`` -- persistent Chroma DB under ``<base_dir>/.slop_cache/chroma_db``.
    * ``'mock'``   -- writes are no-ops and queries return nothing.
    * ``'mcp'``    -- :meth:`search` is delegated to an MCP tool; there is no
      local collection, so local write operations are no-ops.

    When ``config.enabled`` is false nothing is initialised and every method
    is a no-op or returns an empty result.
    """

    def __init__(self, config: "models.RAGConfig", base_dir: str = "."):
        """
        Args:
            config: RAG settings; this class reads ``enabled``,
                ``embedding_provider``, ``vector_store``, ``chunk_size`` and
                ``chunk_overlap`` from it.
            base_dir: directory that indexed file paths are relative to and
                under which the Chroma cache directory is created.

        Raises:
            ValueError: for an unknown embedding or vector-store provider.
        """
        self.config = config
        self.base_dir = base_dir
        self.client = None      # chromadb client, "mock", or None (disabled / mcp)
        self.collection = None  # chroma collection, "mock", or None (disabled / mcp)
        self.embedding_provider = None
        if not self.config.enabled:
            return
        self._init_embedding_provider()
        self._init_vector_store()

    def _init_embedding_provider(self):
        """Instantiate the provider named by ``config.embedding_provider`` ('gemini' or 'local')."""
        provider = self.config.embedding_provider
        if provider == 'gemini':
            self.embedding_provider = GeminiEmbeddingProvider()
        elif provider == 'local':
            self.embedding_provider = LocalEmbeddingProvider()
        else:
            raise ValueError(f"Unknown embedding provider: {provider}")

    def _init_vector_store(self):
        """Set up ``client``/``collection`` for ``config.vector_store.provider``.

        Raises:
            ValueError: for a provider other than 'chroma', 'mock' or 'mcp'.
        """
        vs_config = self.config.vector_store
        if vs_config.provider == 'chroma':
            db_path = os.path.join(self.base_dir, ".slop_cache", "chroma_db")
            os.makedirs(db_path, exist_ok=True)
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_or_create_collection(name=vs_config.collection_name)
        elif vs_config.provider == 'mock':
            self.client = "mock"
            self.collection = "mock"
        elif vs_config.provider == 'mcp':
            # The index lives behind an MCP tool (see _search_mcp); there is no
            # local client or collection, and search() dispatches to the tool.
            self.client = None
            self.collection = None
        else:
            raise ValueError(f"Unknown vector store provider: {vs_config.provider}")

    def _has_local_store(self) -> bool:
        """True when the engine is enabled and backed by a real local collection (not mock/mcp)."""
        return (
            bool(self.config.enabled)
            and self.collection is not None
            and self.collection != "mock"
        )

    def is_empty(self) -> bool:
        """Return True when there is nothing searchable.

        Disabled engines, the mock store and a missing collection count as
        empty. For the 'mcp' store the index is remote and its size cannot be
        checked here, so False is returned so that search() consults the tool.
        NOTE(review): confirm callers treat is_empty() as "skip search" for mcp.
        """
        if not self.config.enabled:
            return True
        provider = self.config.vector_store.provider
        if provider == 'mock' or self.collection == "mock":
            return True
        if provider == 'mcp':
            return False
        if self.collection is None:
            return True
        return self.collection.count() == 0

    def add_documents(self, ids: List[str], texts: List[str], metadatas: Optional[List[Dict[str, Any]]] = None):
        """
        [C: tests/test_rag_engine.py:test_rag_engine_chroma]

        Embed ``texts`` and upsert them into the collection under ``ids``.

        No-op when the engine is disabled, the store is mock/mcp, or ``ids``
        is empty (Chroma rejects empty upserts).
        """
        if not self._has_local_store() or not ids:
            return
        embeddings = self.embedding_provider.embed(texts)
        self.collection.upsert(
            ids=ids,
            embeddings=embeddings,
            documents=texts,
            metadatas=metadatas
        )

    def _chunk_text(self, content: str) -> List[str]:
        """Character-based chunking with overlap.

        Windows of ``chunk_size`` characters start ``chunk_size - chunk_overlap``
        characters apart, so consecutive chunks share ``chunk_overlap``
        characters; the last chunk may be shorter.

        Raises:
            ValueError: if ``chunk_size <= 0`` or ``chunk_overlap`` is not in
                ``[0, chunk_size)``. Such settings would otherwise loop forever
                (overlap >= size) or silently skip text (negative overlap).
        """
        if not content:
            return []
        chunk_size = self.config.chunk_size
        overlap = self.config.chunk_overlap
        if chunk_size <= 0 or not 0 <= overlap < chunk_size:
            raise ValueError(
                f"Invalid chunking config: chunk_size={chunk_size}, chunk_overlap={overlap} "
                "(need chunk_size > 0 and 0 <= chunk_overlap < chunk_size)"
            )
        step = chunk_size - overlap
        chunks = []
        for start in range(0, len(content), step):
            end = start + chunk_size
            chunks.append(content[start:end])
            if end >= len(content):
                break
        return chunks

    def _chunk_code(self, content: str, file_path: str) -> List[str]:
        """AST-aware chunking for Python code.

        Every top-level function/class definition (including decorated ones)
        becomes its own chunk; module-level code between and after definitions
        is chunked with _chunk_text so it stays searchable. Falls back to
        _chunk_text for the whole file when parsing fails, when there are no
        definitions, or when the file is shorter than ``chunk_size``.
        """
        try:
            from src.file_cache import ASTParser
            tree = ASTParser("python").parse(content)
            # tree-sitter node offsets are byte offsets, so slice the UTF-8
            # encoding rather than the str (they differ for non-ASCII text).
            # NOTE(review): assumes ASTParser feeds tree-sitter UTF-8 bytes — confirm.
            source = content.encode("utf-8")

            def _text(start: int, end: int) -> str:
                return source[start:end].decode("utf-8", errors="ignore")

            chunks: List[str] = []
            found_definition = False
            cursor = 0  # byte offset just past the last emitted definition
            for node in tree.root_node.children:
                if node.type not in ("function_definition", "class_definition", "decorated_definition"):
                    continue
                found_definition = True
                gap = _text(cursor, node.start_byte)
                if gap.strip():
                    chunks.extend(self._chunk_text(gap))
                chunks.append(_text(node.start_byte, node.end_byte))
                cursor = node.end_byte
            if not found_definition or len(content) < self.config.chunk_size:
                return self._chunk_text(content)
            tail = _text(cursor, len(source))
            if tail.strip():
                chunks.extend(self._chunk_text(tail))
            return chunks
        except Exception:
            # Best effort: any parser problem degrades to plain text chunking.
            return self._chunk_text(content)

    def index_file(self, file_path: str):
        """Read, chunk, and index one file into the vector store.

        ``file_path`` is relative to ``base_dir`` and is stored as the ``path``
        metadata of each chunk, together with the chunk index and the file's
        mtime. A file whose stored mtime equals its current mtime is skipped;
        missing or unreadable files are ignored. No-op for mock/mcp/disabled.
        """
        if not self._has_local_store():
            return
        full_path = os.path.join(self.base_dir, file_path)
        try:
            mtime = os.path.getmtime(full_path)
        except (OSError, ValueError):
            # Missing file, permission problem, or an invalid path.
            return
        try:
            res = self.collection.get(where={"path": file_path}, limit=1, include=["metadatas"])
            if res and res["metadatas"] and res["metadatas"][0]:
                if res["metadatas"][0].get("mtime") == mtime:
                    return
        except Exception:
            # Best effort: if the lookup fails, fall through and re-index.
            pass
        try:
            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                content = f.read()
        except OSError:
            return
        if file_path.lower().endswith(".py"):
            chunks = self._chunk_code(content, file_path)
        else:
            chunks = self._chunk_text(content)
        # Remove stale chunks only once the new ones were produced, so a
        # chunking failure does not wipe the file's existing entries.
        self.collection.delete(where={"path": file_path})
        if not chunks:
            return
        ids = [f"{file_path}_{i}" for i in range(len(chunks))]
        metadatas = [{"path": file_path, "chunk": i, "mtime": mtime} for i in range(len(chunks))]
        self.add_documents(ids, chunks, metadatas)

    def _search_mcp(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Search via the MCP tool ``vector_store.mcp_tool`` (default ``"rag_search"``).

        The tool result is parsed as JSON; a bare list of hits or an object
        with a ``"results"`` list is accepted, anything else yields ``[]``.
        Uses asyncio.run(), so it must not be called from a thread that is
        already running an event loop.
        """
        tool_name = self.config.vector_store.mcp_tool or "rag_search"
        args = {"query": query, "top_k": top_k}

        async def _async_search_mcp():
            return await mcp_client.async_dispatch(tool_name, args)

        res_str = asyncio.run(_async_search_mcp())
        try:
            data = json.loads(res_str)
        except (TypeError, ValueError):
            # Non-string result or malformed JSON from the tool.
            return []
        if isinstance(data, dict):
            data = data.get("results")
        return data if isinstance(data, list) else []

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        [C: tests/mock_concurrent_mma.py:main, tests/test_rag_engine.py:test_rag_engine_chroma]

        Return up to ``top_k`` hits for ``query``. Each hit is a dict with
        keys ``id``, ``document``, ``metadata`` (``{}`` when absent) and
        ``distance`` (``0.0`` when absent). The 'mcp' store returns the tool's
        hits as-is; mock/disabled stores and ``top_k <= 0`` return ``[]``.
        """
        if not self.config.enabled:
            return []
        if self.config.vector_store.provider == 'mcp':
            return self._search_mcp(query, top_k)
        if self.collection is None or self.collection == "mock" or top_k <= 0:
            return []
        query_embedding = self.embedding_provider.embed([query])[0]
        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k
        )
        if not results or not results["ids"] or not results["ids"][0]:
            return []
        documents = results.get("documents")
        metadatas = results.get("metadatas")
        distances = results.get("distances")
        hits = []
        # Chroma returns one inner list per query embedding; we sent exactly one.
        for i, doc_id in enumerate(results["ids"][0]):
            hits.append({
                "id": doc_id,
                "document": documents[0][i] if documents else None,
                "metadata": (metadatas[0][i] if metadatas else None) or {},
                "distance": distances[0][i] if distances else 0.0
            })
        return hits

    def delete_documents(self, ids: List[str]):
        """
        [C: tests/test_rag_engine.py:test_rag_engine_chroma]

        Delete the chunks with the given ids. An empty id list is skipped
        rather than handed to Chroma. No-op for mock/mcp/disabled.
        """
        if not self._has_local_store() or not ids:
            return
        self.collection.delete(ids=ids)

    def get_all_indexed_paths(self) -> List[str]:
        """Return the distinct ``path`` metadata values in the collection, sorted."""
        if not self._has_local_store():
            return []
        res = self.collection.get(include=["metadatas"])
        if not res or not res["metadatas"]:
            return []
        # Individual metadata entries may be None for documents stored without metadata.
        return sorted({m["path"] for m in res["metadatas"] if m and m.get("path")})

    def delete_documents_by_path(self, file_paths: List[str]):
        """Delete every chunk whose ``path`` metadata is one of ``file_paths``."""
        if not self._has_local_store():
            return
        for path in file_paths:
            self.collection.delete(where={"path": path})