feat(rag): Implement indexing and retrieval logic with AppController integration
@@ -2245,9 +2245,20 @@ def send(
    enable_tools: bool = True,
    stream_callback: Optional[Callable[[str], None]] = None,
    patch_callback: Optional[Callable[[str, str], Optional[str]]] = None,
    rag_engine: Optional[Any] = None,
) -> str:
    monitor = performance_monitor.get_monitor()
    if monitor.enabled: monitor.start_component("ai_client.send")

    if rag_engine and getattr(rag_engine.config, "enabled", False):
        chunks = rag_engine.search(user_message)
        if chunks:
            context_block = "## Retrieved Context\n\n"
            for i, chunk in enumerate(chunks):
                path = chunk.get("metadata", {}).get("path", "unknown")
                context_block += f"### Chunk {i+1} (Source: {path})\n{chunk.get('document', '')}\n\n"
            user_message = context_block + user_message

    _append_comms("OUT", "request", {"message": user_message, "system": _get_combined_system_prompt(_active_tool_preset, _active_bias_profile)})
    with _send_lock:
        if _provider == "gemini":
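For clarity, a standalone sketch of what this context assembly produces for a single retrieved chunk; the chunk text and source path below are invented purely for illustration:

# Illustration only: replays the context-assembly loop from send() with made-up data.
chunks = [{"document": "def add(a, b):\n    return a + b",
           "metadata": {"path": "src/math_utils.py"}}]
user_message = "How is addition implemented?"

context_block = "## Retrieved Context\n\n"
for i, chunk in enumerate(chunks):
    path = chunk.get("metadata", {}).get("path", "unknown")
    context_block += f"### Chunk {i+1} (Source: {path})\n{chunk.get('document', '')}\n\n"
user_message = context_block + user_message

# user_message now reads:
#   ## Retrieved Context
#
#   ### Chunk 1 (Source: src/math_utils.py)
#   def add(a, b):
#       return a + b
#
#   How is addition implemented?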
@@ -35,6 +35,7 @@ from src import orchestrator_pm
from src import conductor_tech_lead
from src import multi_agent_conductor
from src import tool_presets
from src import rag_engine
from src import theme_2 as theme

def hide_tk_root() -> Tk:
@@ -202,6 +203,8 @@ class AppController:
        self._pending_ask_dialog: bool = False
        self.mcp_config: models.MCPConfiguration = models.MCPConfiguration()
        self.rag_config: Optional[models.RAGConfig] = None
        self.rag_engine: Optional[rag_engine.RAGEngine] = None
        self.rag_status: str = 'idle'
        # AI settings state
        self._current_provider: str = "gemini"
        self._current_model: str = "gemini-2.5-flash-lite"
@@ -353,6 +356,7 @@ class AppController:
            'show_confirm_modal': 'show_confirm_modal',
            'mma_epic_input': 'ui_epic_input',
            'mma_status': 'mma_status',
            'rag_status': 'rag_status',
            'mma_active_tier': 'active_tier',
            'ui_new_track_name': 'ui_new_track_name',
            'ui_new_track_desc': 'ui_new_track_desc',
@@ -560,6 +564,32 @@ class AppController:
                "payload": status
            })

    def _set_rag_status(self, status: str) -> None:
        """Thread-safe update of rag_status via the GUI task queue."""
        with self._pending_gui_tasks_lock:
            self._pending_gui_tasks.append({
                "action": "set_value",
                "item": "rag_status",
                "value": status
            })

    def _rebuild_rag_index(self) -> None:
        """Background thread that re-indexes all files in the current project."""
        if not self.rag_config or not self.rag_config.enabled or not self.rag_engine:
            return

        def _run():
            try:
                self._set_rag_status("indexing...")
                for f in self.files:
                    path = f.path if hasattr(f, "path") else str(f)
                    self.rag_engine.index_file(path)
                self._set_rag_status("ready")
            except Exception as e:
                self._set_rag_status(f"error: {e}")

        threading.Thread(target=_run, daemon=True).start()

    def _trigger_gui_refresh(self):
        with self._pending_gui_tasks_lock:
            self._pending_gui_tasks.append({'action': 'set_comms_dirty'})
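The consumer that drains _pending_gui_tasks on the Tk main loop is not part of this diff. The following standalone sketch (all names hypothetical) only illustrates the producer/consumer pattern the two methods above rely on: the worker thread queues plain dicts, and the GUI thread is the only side that touches UI state:

import threading

# Hypothetical standalone model of the pattern; the real queue lives on AppController
# and the real drain loop is not shown in this commit.
_pending_gui_tasks: list = []
_lock = threading.Lock()

def set_rag_status(status: str) -> None:
    # Worker-thread side: mirrors AppController._set_rag_status.
    with _lock:
        _pending_gui_tasks.append({"action": "set_value", "item": "rag_status", "value": status})

def drain_gui_tasks(ui_state: dict) -> None:
    # GUI-thread side: apply queued updates; only this function mutates UI state.
    with _lock:
        tasks = list(_pending_gui_tasks)
        _pending_gui_tasks.clear()
    for task in tasks:
        if task["action"] == "set_value":
            ui_state[task["item"]] = task["value"]

set_rag_status("indexing...")
set_rag_status("ready")
ui_state = {}
drain_gui_tasks(ui_state)   # ui_state == {"rag_status": "ready"}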
@@ -955,6 +985,8 @@ class AppController:
        else:
            self.rag_config = models.RAGConfig()

        self.rag_engine = rag_engine.RAGEngine(self.rag_config, self.active_project_root)

        from src.personas import PersonaManager
        self.persona_manager = PersonaManager(Path(self.active_project_path).parent if self.active_project_path else None)
        self.personas = self.persona_manager.load_all()
@@ -1448,7 +1480,8 @@ class AppController:
                stream_callback=lambda text: self._on_ai_stream(text),
                pre_tool_callback=self._confirm_and_run,
                qa_callback=ai_client.run_tier4_analysis,
-               patch_callback=ai_client.run_tier4_patch_callback
+               patch_callback=ai_client.run_tier4_patch_callback,
+               rag_engine=self.rag_engine
            )
            self.event_queue.put("response", {"text": resp, "status": "done", "role": "AI"})
        except ai_client.ProviderError as e:
@@ -1867,7 +1900,7 @@ class AppController:
                "ts": project_manager.now_ts()
            })
        try:
-           resp = ai_client.send(stable_md, user_msg, base_dir, self.last_file_items, disc_text)
+           resp = ai_client.send(stable_md, user_msg, base_dir, self.last_file_items, disc_text, rag_engine=self.rag_engine)
            if req.auto_add_history:
                with self._pending_history_adds_lock:
                    self._pending_history_adds.append({
@@ -2024,7 +2057,17 @@ class AppController:
        self._set_status(f"switched to: {Path(path).stem}")

    def _refresh_from_project(self) -> None:
-       self.files = list(self.project.get("files", {}).get("paths", []))
+       # Deserialize FileItems in files.paths
+       raw_paths = self.project.get("files", {}).get("paths", [])
+       self.files = []
+       for p in raw_paths:
+           if isinstance(p, models.FileItem):
+               self.files.append(p)
+           elif isinstance(p, dict):
+               self.files.append(models.FileItem.from_dict(p))
+           else:
+               self.files.append(models.FileItem(path=str(p)))

        self.screenshots = list(self.project.get("screenshots", {}).get("paths", []))
        disc_sec = self.project.get("discussion", {})
        self.disc_roles = list(disc_sec.get("roles", ["User", "AI", "Vendor API", "System"]))
@@ -2090,6 +2133,9 @@ class AppController:
        self.tool_presets = self.tool_preset_manager.load_all_presets()
        self.bias_profiles = self.tool_preset_manager.load_all_bias_profiles()

        if self.rag_config and self.rag_config.enabled:
            self._rebuild_rag_index()

    def _apply_preset(self, name: str, scope: str) -> None:
        print(f"[DEBUG] _apply_preset: name={name}, scope={scope}")
        if name == "None":
@@ -91,6 +91,72 @@ class RAGEngine:
            metadatas=metadatas
        )

    def _chunk_text(self, content: str) -> List[str]:
        """Character-based chunking with overlap."""
        chunks = []
        if not content:
            return chunks
        chunk_size = self.config.chunk_size
        overlap = self.config.chunk_overlap
        start = 0
        while start < len(content):
            end = start + chunk_size
            chunks.append(content[start:end])
            if end >= len(content):
                break
            start += (chunk_size - overlap)
        return chunks
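A quick worked example of the sliding-window arithmetic above, using toy values chunk_size=10 and overlap=3 (the real values come from RAGConfig):

# Illustration only: same sliding-window logic as _chunk_text, with toy sizes.
content = "abcdefghijklmnopqrst"               # 20 characters
chunk_size, overlap = 10, 3
chunks, start = [], 0
while start < len(content):
    end = start + chunk_size
    chunks.append(content[start:end])
    if end >= len(content):
        break
    start += chunk_size - overlap              # advance by 7, so consecutive chunks share 3 characters
assert chunks == ["abcdefghij", "hijklmnopq", "opqrst"]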
    def _chunk_code(self, content: str, file_path: str) -> List[str]:
        """AST-aware chunking for Python code."""
        try:
            from src.file_cache import ASTParser
            parser = ASTParser("python")
            tree = parser.parse(content)
            chunks = []

            # Capture classes and top-level functions
            for node in tree.root_node.children:
                if node.type in ("function_definition", "class_definition"):
                    chunks.append(content[node.start_byte:node.end_byte])

            # Fallback if no structural chunks found or if file is small
            if not chunks or len(content) < self.config.chunk_size:
                return self._chunk_text(content)
            return chunks
        except Exception:
            return self._chunk_text(content)
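ASTParser here comes from src.file_cache and exposes a tree-sitter style node API (node.type, node.start_byte, node.end_byte). For readers without that module, a rough stdlib-only equivalent of the same idea, shown purely as an illustration rather than the code in this commit:

import ast

def chunk_python_source(content: str) -> list:
    """Illustration only: one chunk per top-level function or class, via the stdlib ast module."""
    try:
        tree = ast.parse(content)
    except SyntaxError:
        return [content]                       # unparseable file -> single chunk
    lines = content.splitlines(keepends=True)
    chunks = []
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            # lineno/end_lineno are 1-based and inclusive (Python 3.8+)
            chunks.append("".join(lines[node.lineno - 1:node.end_lineno]))
    return chunks or [content]

sample = "def greet(name):\n    return f'hi {name}'\n\nclass Greeter:\n    pass\n"
print(chunk_python_source(sample))             # -> two chunks: the greet() function and the Greeter class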
    def index_file(self, file_path: str):
        """Reads, chunks, and indexes a file into the vector store."""
        if not self.config.enabled or self.collection == "mock":
            return

        full_path = os.path.join(self.base_dir, file_path)
        if not os.path.exists(full_path):
            return

        try:
            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                content = f.read()
        except Exception:
            return

        # Remove old entries for this file
        self.collection.delete(where={"path": file_path})

        if file_path.lower().endswith(".py"):
            chunks = self._chunk_code(content, file_path)
        else:
            chunks = self._chunk_text(content)

        if not chunks:
            return

        ids = [f"{file_path}_{i}" for i in range(len(chunks))]
        metadatas = [{"path": file_path, "chunk": i} for i in range(len(chunks))]
        self.add_documents(ids, chunks, metadatas)

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        if not self.config.enabled or self.collection == "mock":
            return []
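Taken together, an index-then-retrieve round trip through RAGEngine might look like the sketch below; the RAGConfig defaults, the positional constructor arguments, and a non-mock vector store backend are assumptions made for illustration:

from src import models, rag_engine

config = models.RAGConfig(enabled=True)                   # chunking params fall back to their defaults
engine = rag_engine.RAGEngine(config, "/path/to/project")

engine.index_file("src/app_controller.py")                # read -> chunk -> delete old entries -> add_documents
hits = engine.search("where is rag_status updated?", top_k=3)
for hit in hits:
    print(hit["metadata"]["path"], hit["document"][:80])  # each hit carries its source path and text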
@@ -0,0 +1,111 @@
import os
import shutil
import tempfile
from unittest.mock import MagicMock, patch
import pytest

from src.app_controller import AppController
from src import ai_client
from src import events
from src import models

@pytest.fixture
def mock_project():
    # Use a temporary directory for the mock project
    temp_dir = tempfile.mkdtemp()
    try:
        # Create a minimal manual_slop.toml
        with open(os.path.join(temp_dir, "manual_slop.toml"), "w") as f:
            f.write('discussion_history = []\n')
        yield temp_dir
    finally:
        # Clean up after test
        shutil.rmtree(temp_dir)

def test_rag_integration(mock_project):
    """
    Integration test verifying the flow from AppController through RAGEngine to ai_client.
    """
    # 1. Initializes a mock project and AppController.
    # We patch several components to avoid side effects during initialization.
    with patch('src.app_controller.AppController._fetch_models'), \
         patch('src.models.load_config', return_value={}), \
         patch('src.paths.get_full_path_info', return_value={'logs_dir': {'path': mock_project}, 'scripts_dir': {'path': mock_project}}), \
         patch('src.theme.load_from_config'):
        app = AppController()
        # Minimal state setup for _handle_request_event
        app.ui_global_system_prompt = ""
        app.ui_project_system_prompt = ""
        app.ui_base_system_prompt = ""
        app.ui_use_default_base_prompt = True
        app.ui_project_context_marker = ""
        app.temperature = 0.0
        app.max_tokens = 100
        app.history_trunc_limit = 1000
        app.top_p = 1.0
        app.ui_agent_tools = {}
        app.ui_gemini_cli_path = "gemini"
        app.current_model = "gemini-1.5-flash"
        app.active_project_path = os.path.join(mock_project, "manual_slop.toml")

        # Ensure the provider is set to 'gemini' for our test
        ai_client.set_provider("gemini", "gemini-1.5-flash")

        # 2. Configures a mock RAG setup (enabled=True, provider='mock').
        rag_config = models.RAGConfig(
            enabled=True,
            vector_store=models.VectorStoreConfig(provider='mock')
        )
        app.rag_config = rag_config

        # 3. Mocks rag_engine.search to return a known chunk.
        mock_rag_engine = MagicMock()
        mock_rag_engine.config = rag_config
        mock_rag_engine.search.return_value = [
            {"document": "This is a retrieved chunk from RAG.", "metadata": {"path": "test_file.py"}}
        ]
        app.rag_engine = mock_rag_engine

        # 4. Mocks ai_client.send to verify that the retrieved chunk appears in the
        # message sent to the provider. We use 'wraps' to let the real logic run
        # while still having a mock we can inspect. We also mock the internal
        # _send_gemini which is what actually "sends to the provider".
        with patch('src.ai_client.send', wraps=ai_client.send) as mock_send:
            with patch('src.ai_client._send_gemini') as mock_provider:
                mock_provider.return_value = "Mock AI Response"

                # Create a UserRequestEvent as if the user clicked "Gen + Send"
                event = events.UserRequestEvent(
                    prompt="Tell me about the code.",
                    stable_md="Context",
                    file_items=[],
                    disc_text="History",
                    base_dir=mock_project
                )

                # Trigger the request event processing logic in AppController
                app._handle_request_event(event)

                # 5. This verifies the wiring from AppController through RAGEngine to ai_client.

                # Verify that ai_client.send was called by AppController
                assert mock_send.called
                _, kwargs = mock_send.call_args
                assert kwargs['rag_engine'] == mock_rag_engine

                # Verify that the internal provider call was made
                assert mock_provider.called

                # Extract the user_message passed to the provider call
                args, _ = mock_provider.call_args
                # _send_gemini(md_content, user_message, ...) -> user_message is index 1
                sent_user_message = args[1]

                # Verify that the RAG chunk was prepended to the original prompt
                assert "This is a retrieved chunk from RAG." in sent_user_message
                assert "Tell me about the code." in sent_user_message
                assert "## Retrieved Context" in sent_user_message
                assert "Source: test_file.py" in sent_user_message

                # Verify that rag_engine.search was called with the original prompt
                mock_rag_engine.search.assert_called_once_with("Tell me about the code.")