From fe0069c046c642db1796e51ea9490012dc741bad Mon Sep 17 00:00:00 2001
From: Ed_
Date: Mon, 4 May 2026 06:53:32 -0400
Subject: [PATCH] feat(rag): Implement indexing and retrieval logic with
 AppController integration

---
 src/ai_client.py              |  11 ++++
 src/app_controller.py         |  52 +++++++++++++++-
 src/rag_engine.py             |  66 ++++++++++++++++++++
 tests/test_rag_integration.py | 111 ++++++++++++++++++++++++++++++++++
 4 files changed, 237 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_rag_integration.py

diff --git a/src/ai_client.py b/src/ai_client.py
index ea89c33..96577f7 100644
--- a/src/ai_client.py
+++ b/src/ai_client.py
@@ -2245,10 +2245,21 @@ def send(
     enable_tools: bool = True,
     stream_callback: Optional[Callable[[str], None]] = None,
     patch_callback: Optional[Callable[[str, str], Optional[str]]] = None,
+    rag_engine: Optional[Any] = None,
 ) -> str:
     monitor = performance_monitor.get_monitor()
     if monitor.enabled:
         monitor.start_component("ai_client.send")
+
+    if rag_engine and getattr(rag_engine.config, "enabled", False):
+        chunks = rag_engine.search(user_message)
+        if chunks:
+            context_block = "## Retrieved Context\n\n"
+            for i, chunk in enumerate(chunks):
+                path = chunk.get("metadata", {}).get("path", "unknown")
+                context_block += f"### Chunk {i+1} (Source: {path})\n{chunk.get('document', '')}\n\n"
+            user_message = context_block + user_message
+
     _append_comms("OUT", "request", {"message": user_message, "system": _get_combined_system_prompt(_active_tool_preset, _active_bias_profile)})
     with _send_lock:
         if _provider == "gemini":
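
Reviewer note: the new branch in send() rewrites user_message in place before
it is logged and dispatched. A minimal standalone sketch of the resulting
prompt shape, assuming search() returns Chroma-style dicts with "document" and
"metadata" keys (the same shape the integration test below mocks):

    # Sketch only: reproduces the prepend logic from send() outside ai_client.
    chunks = [{"document": "def foo(): ...", "metadata": {"path": "src/foo.py"}}]
    context_block = "## Retrieved Context\n\n"
    for i, chunk in enumerate(chunks):
        path = chunk.get("metadata", {}).get("path", "unknown")
        context_block += f"### Chunk {i+1} (Source: {path})\n{chunk.get('document', '')}\n\n"
    prompt = context_block + "Tell me about the code."
    # prompt now opens with "## Retrieved Context", then the user's message.
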
diff --git a/src/app_controller.py b/src/app_controller.py
index c144737..9d7b48d 100644
--- a/src/app_controller.py
+++ b/src/app_controller.py
@@ -35,6 +35,7 @@ from src import orchestrator_pm
 from src import conductor_tech_lead
 from src import multi_agent_conductor
 from src import tool_presets
+from src import rag_engine
 from src import theme_2 as theme
 
 def hide_tk_root() -> Tk:
@@ -202,6 +203,8 @@ class AppController:
         self._pending_ask_dialog: bool = False
         self.mcp_config: models.MCPConfiguration = models.MCPConfiguration()
         self.rag_config: Optional[models.RAGConfig] = None
+        self.rag_engine: Optional[rag_engine.RAGEngine] = None
+        self.rag_status: str = 'idle'
         # AI settings state
         self._current_provider: str = "gemini"
         self._current_model: str = "gemini-2.5-flash-lite"
@@ -353,6 +356,7 @@ class AppController:
             'show_confirm_modal': 'show_confirm_modal',
             'mma_epic_input': 'ui_epic_input',
             'mma_status': 'mma_status',
+            'rag_status': 'rag_status',
             'mma_active_tier': 'active_tier',
             'ui_new_track_name': 'ui_new_track_name',
             'ui_new_track_desc': 'ui_new_track_desc',
@@ -560,6 +564,32 @@ class AppController:
             "payload": status
         })
 
+    def _set_rag_status(self, status: str) -> None:
+        """Thread-safe update of rag_status via the GUI task queue."""
+        with self._pending_gui_tasks_lock:
+            self._pending_gui_tasks.append({
+                "action": "set_value",
+                "item": "rag_status",
+                "value": status
+            })
+
+    def _rebuild_rag_index(self) -> None:
+        """Re-index all files in the current project on a background thread."""
+        if not self.rag_config or not self.rag_config.enabled or not self.rag_engine:
+            return
+
+        def _run():
+            try:
+                self._set_rag_status("indexing...")
+                for f in self.files:
+                    path = f.path if hasattr(f, "path") else str(f)
+                    self.rag_engine.index_file(path)
+                self._set_rag_status("ready")
+            except Exception as e:
+                self._set_rag_status(f"error: {e}")
+
+        threading.Thread(target=_run, daemon=True).start()
+
     def _trigger_gui_refresh(self):
         with self._pending_gui_tasks_lock:
             self._pending_gui_tasks.append({'action': 'set_comms_dirty'})
@@ -955,6 +985,8 @@ class AppController:
         else:
             self.rag_config = models.RAGConfig()
 
+        self.rag_engine = rag_engine.RAGEngine(self.rag_config, self.active_project_root)
+
         from src.personas import PersonaManager
         self.persona_manager = PersonaManager(Path(self.active_project_path).parent if self.active_project_path else None)
         self.personas = self.persona_manager.load_all()
@@ -1448,7 +1480,8 @@ class AppController:
                 stream_callback=lambda text: self._on_ai_stream(text),
                 pre_tool_callback=self._confirm_and_run,
                 qa_callback=ai_client.run_tier4_analysis,
-                patch_callback=ai_client.run_tier4_patch_callback
+                patch_callback=ai_client.run_tier4_patch_callback,
+                rag_engine=self.rag_engine
             )
             self.event_queue.put("response", {"text": resp, "status": "done", "role": "AI"})
         except ai_client.ProviderError as e:
@@ -1867,7 +1900,7 @@ class AppController:
                 "ts": project_manager.now_ts()
             })
         try:
-            resp = ai_client.send(stable_md, user_msg, base_dir, self.last_file_items, disc_text)
+            resp = ai_client.send(stable_md, user_msg, base_dir, self.last_file_items, disc_text, rag_engine=self.rag_engine)
             if req.auto_add_history:
                 with self._pending_history_adds_lock:
                     self._pending_history_adds.append({
@@ -2024,7 +2057,17 @@ class AppController:
         self._set_status(f"switched to: {Path(path).stem}")
 
     def _refresh_from_project(self) -> None:
-        self.files = list(self.project.get("files", {}).get("paths", []))
+        # Deserialize FileItem entries from files.paths
+        raw_paths = self.project.get("files", {}).get("paths", [])
+        self.files = []
+        for p in raw_paths:
+            if isinstance(p, models.FileItem):
+                self.files.append(p)
+            elif isinstance(p, dict):
+                self.files.append(models.FileItem.from_dict(p))
+            else:
+                self.files.append(models.FileItem(path=str(p)))
+
         self.screenshots = list(self.project.get("screenshots", {}).get("paths", []))
         disc_sec = self.project.get("discussion", {})
         self.disc_roles = list(disc_sec.get("roles", ["User", "AI", "Vendor API", "System"]))
@@ -2090,6 +2133,9 @@ class AppController:
         self.tool_presets = self.tool_preset_manager.load_all_presets()
         self.bias_profiles = self.tool_preset_manager.load_all_bias_profiles()
 
+        if self.rag_config and self.rag_config.enabled:
+            self._rebuild_rag_index()
+
     def _apply_preset(self, name: str, scope: str) -> None:
         print(f"[DEBUG] _apply_preset: name={name}, scope={scope}")
         if name == "None":
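
Reviewer note: _set_rag_status never touches Tk directly; it only enqueues a
task dict, and the 'rag_status' entry added to the state map above maps the
item name to the attribute of the same name. The consumer is the existing GUI
drain loop, which this patch does not show; a hypothetical sketch of how such
a loop would apply the new task (function name invented for illustration):

    # Hypothetical consumer; the real drain loop lives elsewhere in
    # app_controller.py and runs on the GUI thread.
    def drain_gui_tasks(controller) -> None:
        with controller._pending_gui_tasks_lock:
            tasks, controller._pending_gui_tasks = controller._pending_gui_tasks, []
        for task in tasks:
            if task.get("action") == "set_value":
                # For 'rag_status' the state map maps the item to itself.
                setattr(controller, task["item"], task["value"])
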
diff --git a/src/rag_engine.py b/src/rag_engine.py
index c486b27..ab09a4b 100644
--- a/src/rag_engine.py
+++ b/src/rag_engine.py
@@ -91,6 +91,72 @@ class RAGEngine:
             metadatas=metadatas
         )
 
+    def _chunk_text(self, content: str) -> List[str]:
+        """Character-based chunking with overlap."""
+        chunks = []
+        if not content:
+            return chunks
+        chunk_size = self.config.chunk_size
+        overlap = self.config.chunk_overlap
+        start = 0
+        while start < len(content):
+            end = start + chunk_size
+            chunks.append(content[start:end])
+            if end >= len(content):
+                break
+            start += (chunk_size - overlap)
+        return chunks
+
+    def _chunk_code(self, content: str, file_path: str) -> List[str]:
+        """AST-aware chunking for Python code."""
+        try:
+            from src.file_cache import ASTParser
+            parser = ASTParser("python")
+            tree = parser.parse(content)
+            chunks = []
+
+            # Capture classes and top-level functions
+            for node in tree.root_node.children:
+                if node.type in ("function_definition", "class_definition"):
+                    chunks.append(content[node.start_byte:node.end_byte])
+
+            # Fall back to text chunking if no structural chunks were found or the file is small
+            if not chunks or len(content) < self.config.chunk_size:
+                return self._chunk_text(content)
+            return chunks
+        except Exception:
+            return self._chunk_text(content)
+
+    def index_file(self, file_path: str):
+        """Reads, chunks, and indexes a file into the vector store."""
+        if not self.config.enabled or self.collection == "mock":
+            return
+
+        full_path = os.path.join(self.base_dir, file_path)
+        if not os.path.exists(full_path):
+            return
+
+        try:
+            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
+                content = f.read()
+        except Exception:
+            return
+
+        # Remove old entries for this file
+        self.collection.delete(where={"path": file_path})
+
+        if file_path.lower().endswith(".py"):
+            chunks = self._chunk_code(content, file_path)
+        else:
+            chunks = self._chunk_text(content)
+
+        if not chunks:
+            return
+
+        ids = [f"{file_path}_{i}" for i in range(len(chunks))]
+        metadatas = [{"path": file_path, "chunk": i} for i in range(len(chunks))]
+        self.add_documents(ids, chunks, metadatas)
+
     def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
         if not self.config.enabled or self.collection == "mock":
             return []
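
Reviewer note: _chunk_text advances by chunk_size - chunk_overlap per
iteration and stops as soon as a window reaches the end of the content. A
worked example with illustrative values (chunk_size=100, chunk_overlap=20):

    # Mirrors the while-loop in _chunk_text; the config values here are
    # illustrative, not the project defaults.
    def chunk_starts(n: int, size: int, overlap: int) -> list:
        out, start = [], 0
        while start < n:
            out.append(start)
            if start + size >= n:  # last window reaches the end of the content
                break
            start += size - overlap
        return out

    assert chunk_starts(250, 100, 20) == [0, 80, 160]
    # Chunks cover [0:100], [80:180], [160:250]; neighbours share 20 characters.
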
diff --git a/tests/test_rag_integration.py b/tests/test_rag_integration.py
new file mode 100644
index 0000000..cd91618
--- /dev/null
+++ b/tests/test_rag_integration.py
@@ -0,0 +1,111 @@
+import os
+import shutil
+import tempfile
+from unittest.mock import MagicMock, patch
+import pytest
+
+from src.app_controller import AppController
+from src import ai_client
+from src import events
+from src import models
+
+@pytest.fixture
+def mock_project():
+    # Use a temporary directory for the mock project
+    temp_dir = tempfile.mkdtemp()
+    try:
+        # Create a minimal manual_slop.toml
+        with open(os.path.join(temp_dir, "manual_slop.toml"), "w") as f:
+            f.write('discussion_history = []\n')
+        yield temp_dir
+    finally:
+        # Clean up after the test
+        shutil.rmtree(temp_dir)
+
+def test_rag_integration(mock_project):
+    """
+    Integration test verifying the flow from AppController through RAGEngine to ai_client.
+    """
+    # 1. Initialize a mock project and AppController. We patch several
+    #    components to avoid side effects during initialization.
+    with patch('src.app_controller.AppController._fetch_models'), \
+         patch('src.models.load_config', return_value={}), \
+         patch('src.paths.get_full_path_info', return_value={'logs_dir': {'path': mock_project}, 'scripts_dir': {'path': mock_project}}), \
+         patch('src.app_controller.theme.load_from_config'):
+        app = AppController()
+        # Minimal state setup for _handle_request_event
+        app.ui_global_system_prompt = ""
+        app.ui_project_system_prompt = ""
+        app.ui_base_system_prompt = ""
+        app.ui_use_default_base_prompt = True
+        app.ui_project_context_marker = ""
+        app.temperature = 0.0
+        app.max_tokens = 100
+        app.history_trunc_limit = 1000
+        app.top_p = 1.0
+        app.ui_agent_tools = {}
+        app.ui_gemini_cli_path = "gemini"
+        app.current_model = "gemini-1.5-flash"
+        app.active_project_path = os.path.join(mock_project, "manual_slop.toml")
+
+        # Ensure the provider is set to 'gemini' for this test
+        ai_client.set_provider("gemini", "gemini-1.5-flash")
+
+        # 2. Configure a mock RAG setup (enabled=True, provider='mock').
+        rag_config = models.RAGConfig(
+            enabled=True,
+            vector_store=models.VectorStoreConfig(provider='mock')
+        )
+        app.rag_config = rag_config
+
+        # 3. Mock rag_engine.search to return a known chunk.
+        mock_rag_engine = MagicMock()
+        mock_rag_engine.config = rag_config
+        mock_rag_engine.search.return_value = [
+            {"document": "This is a retrieved chunk from RAG.", "metadata": {"path": "test_file.py"}}
+        ]
+        app.rag_engine = mock_rag_engine
+
+        # 4. Mock ai_client.send to verify that the retrieved chunk appears in
+        #    the message sent to the provider. We use 'wraps' to let the real
+        #    logic run while still having a mock we can inspect. We also mock
+        #    the internal _send_gemini, which is what actually "sends to the
+        #    provider".
+        with patch('src.ai_client.send', wraps=ai_client.send) as mock_send:
+            with patch('src.ai_client._send_gemini') as mock_provider:
+                mock_provider.return_value = "Mock AI Response"
+
+                # Create a UserRequestEvent as if the user clicked "Gen + Send"
+                event = events.UserRequestEvent(
+                    prompt="Tell me about the code.",
+                    stable_md="Context",
+                    file_items=[],
+                    disc_text="History",
+                    base_dir=mock_project
+                )
+
+                # Trigger the request-event processing logic in AppController
+                app._handle_request_event(event)
+
+                # 5. Verify the wiring from AppController through RAGEngine to
+                #    ai_client: first, that send was called with our engine.
+                assert mock_send.called
+                _, kwargs = mock_send.call_args
+                assert kwargs['rag_engine'] == mock_rag_engine
+
+                # Verify that the internal provider call was made
+                assert mock_provider.called
+
+                # Extract the user_message passed to the provider call:
+                # _send_gemini(md_content, user_message, ...) -> index 1
+                args, _ = mock_provider.call_args
+                sent_user_message = args[1]
+
+                # Verify that the RAG chunk was prepended to the original prompt
+                assert "This is a retrieved chunk from RAG." in sent_user_message
+                assert "Tell me about the code." in sent_user_message
+                assert "## Retrieved Context" in sent_user_message
+                assert "Source: test_file.py" in sent_user_message
+
+                # Verify that rag_engine.search was called with the original prompt
+                mock_rag_engine.search.assert_called_once_with("Tell me about the code.")
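
Reviewer note: the test only exercises provider='mock', which short-circuits
index_file() and search() via the collection == "mock" guard. For manual
verification against a real store, the engine can be driven directly. A usage
sketch under stated assumptions: the RAGEngine(config, base_dir) constructor
shape comes from the AppController change above, while the "chroma" provider
name and the RAGConfig field defaults are assumptions not confirmed by this
patch:

    # Usage sketch; "chroma" and the config defaults are assumptions, and the
    # result dicts are assumed to be shaped like the mock in the test above.
    from src import models
    from src.rag_engine import RAGEngine

    config = models.RAGConfig(
        enabled=True,
        vector_store=models.VectorStoreConfig(provider="chroma"),
    )
    engine = RAGEngine(config, "/path/to/project")  # (config, base_dir)
    engine.index_file("src/app_controller.py")      # chunk + upsert into the store
    for hit in engine.search("how is rag_status updated?"):
        print(hit.get("metadata", {}).get("path"), len(hit.get("document", "")))

The integration test itself runs with: python -m pytest
tests/test_rag_integration.py -q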