adjustments to rag engine
This commit is contained in:
+24
-15
@@ -8,14 +8,23 @@ from chromadb.config import Settings
|
|||||||
from src import models
|
from src import models
|
||||||
from src import mcp_client
|
from src import mcp_client
|
||||||
|
|
||||||
try:
|
_SENTENCE_TRANSFORMERS = None
|
||||||
from sentence_transformers import SentenceTransformer
|
_GOOGLE_GENAI = None
|
||||||
except ImportError:
|
|
||||||
SentenceTransformer = None
|
|
||||||
|
|
||||||
from google import genai
|
def _get_sentence_transformers():
|
||||||
from google.genai import types
|
global _SENTENCE_TRANSFORMERS
|
||||||
from src import ai_client
|
if _SENTENCE_TRANSFORMERS is None:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
_SENTENCE_TRANSFORMERS = SentenceTransformer
|
||||||
|
return _SENTENCE_TRANSFORMERS
|
||||||
|
|
||||||
|
def _get_google_genai():
|
||||||
|
global _GOOGLE_GENAI
|
||||||
|
if _GOOGLE_GENAI is None:
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types
|
||||||
|
_GOOGLE_GENAI = (genai, types)
|
||||||
|
return _GOOGLE_GENAI
|
||||||
|
|
||||||
class BaseEmbeddingProvider:
|
class BaseEmbeddingProvider:
|
||||||
def embed(self, texts: List[str]) -> List[List[float]]:
|
def embed(self, texts: List[str]) -> List[List[float]]:
|
||||||
@@ -23,9 +32,10 @@ class BaseEmbeddingProvider:
|
|||||||
|
|
||||||
class LocalEmbeddingProvider(BaseEmbeddingProvider):
|
class LocalEmbeddingProvider(BaseEmbeddingProvider):
|
||||||
def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
|
def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
|
||||||
if SentenceTransformer is None:
|
ST = _get_sentence_transformers()
|
||||||
|
if ST is None:
|
||||||
raise ImportError("sentence-transformers is not installed")
|
raise ImportError("sentence-transformers is not installed")
|
||||||
self.model = SentenceTransformer(model_name)
|
self.model = ST(model_name)
|
||||||
|
|
||||||
def embed(self, texts: List[str]) -> List[List[float]]:
|
def embed(self, texts: List[str]) -> List[List[float]]:
|
||||||
embeddings = self.model.encode(texts)
|
embeddings = self.model.encode(texts)
|
||||||
@@ -36,12 +46,15 @@ class GeminiEmbeddingProvider(BaseEmbeddingProvider):
|
|||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
def embed(self, texts: List[str]) -> List[List[float]]:
|
def embed(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
google_module = _get_google_genai()
|
||||||
|
if google_module is None:
|
||||||
|
raise ImportError("google-genai is not installed")
|
||||||
|
genai_pkg, types = google_module
|
||||||
|
from src import ai_client
|
||||||
ai_client._ensure_gemini_client()
|
ai_client._ensure_gemini_client()
|
||||||
client = ai_client._gemini_client
|
client = ai_client._gemini_client
|
||||||
if not client:
|
if not client:
|
||||||
raise ValueError("Gemini client not initialized")
|
raise ValueError("Gemini client not initialized")
|
||||||
|
|
||||||
# For text-embedding-004, we can embed a batch
|
|
||||||
res = client.models.embed_content(
|
res = client.models.embed_content(
|
||||||
model=self.model_name,
|
model=self.model_name,
|
||||||
contents=texts,
|
contents=texts,
|
||||||
@@ -131,12 +144,10 @@ class RAGEngine:
|
|||||||
tree = parser.parse(content)
|
tree = parser.parse(content)
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
||||||
# Capture classes and top-level functions
|
|
||||||
for node in tree.root_node.children:
|
for node in tree.root_node.children:
|
||||||
if node.type in ("function_definition", "class_definition"):
|
if node.type in ("function_definition", "class_definition"):
|
||||||
chunks.append(content[node.start_byte:node.end_byte])
|
chunks.append(content[node.start_byte:node.end_byte])
|
||||||
|
|
||||||
# Fallback if no structural chunks found or if file is small
|
|
||||||
if not chunks or len(content) < self.config.chunk_size:
|
if not chunks or len(content) < self.config.chunk_size:
|
||||||
return self._chunk_text(content)
|
return self._chunk_text(content)
|
||||||
return chunks
|
return chunks
|
||||||
@@ -157,7 +168,6 @@ class RAGEngine:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Incremental check: see if we already have this file with the same mtime
|
|
||||||
try:
|
try:
|
||||||
res = self.collection.get(where={"path": file_path}, limit=1, include=["metadatas"])
|
res = self.collection.get(where={"path": file_path}, limit=1, include=["metadatas"])
|
||||||
if res and res["metadatas"] and res["metadatas"][0]:
|
if res and res["metadatas"] and res["metadatas"][0]:
|
||||||
@@ -172,7 +182,6 @@ class RAGEngine:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Remove old entries for this file
|
|
||||||
self.collection.delete(where={"path": file_path})
|
self.collection.delete(where={"path": file_path})
|
||||||
|
|
||||||
if file_path.lower().endswith(".py"):
|
if file_path.lower().endswith(".py"):
|
||||||
|
|||||||
Reference in New Issue
Block a user