feat(rag): Implement indexing and retrieval logic with AppController integration
This commit is contained in:
@@ -2245,9 +2245,20 @@ def send(
|
|||||||
enable_tools: bool = True,
|
enable_tools: bool = True,
|
||||||
stream_callback: Optional[Callable[[str], None]] = None,
|
stream_callback: Optional[Callable[[str], None]] = None,
|
||||||
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None,
|
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None,
|
||||||
|
rag_engine: Optional[Any] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
monitor = performance_monitor.get_monitor()
|
monitor = performance_monitor.get_monitor()
|
||||||
if monitor.enabled: monitor.start_component("ai_client.send")
|
if monitor.enabled: monitor.start_component("ai_client.send")
|
||||||
|
|
||||||
|
if rag_engine and getattr(rag_engine.config, "enabled", False):
|
||||||
|
chunks = rag_engine.search(user_message)
|
||||||
|
if chunks:
|
||||||
|
context_block = "## Retrieved Context\n\n"
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
path = chunk.get("metadata", {}).get("path", "unknown")
|
||||||
|
context_block += f"### Chunk {i+1} (Source: {path})\n{chunk.get('document', '')}\n\n"
|
||||||
|
user_message = context_block + user_message
|
||||||
|
|
||||||
_append_comms("OUT", "request", {"message": user_message, "system": _get_combined_system_prompt(_active_tool_preset, _active_bias_profile)})
|
_append_comms("OUT", "request", {"message": user_message, "system": _get_combined_system_prompt(_active_tool_preset, _active_bias_profile)})
|
||||||
with _send_lock:
|
with _send_lock:
|
||||||
if _provider == "gemini":
|
if _provider == "gemini":
|
||||||
|
|||||||
+49
-3
@@ -35,6 +35,7 @@ from src import orchestrator_pm
|
|||||||
from src import conductor_tech_lead
|
from src import conductor_tech_lead
|
||||||
from src import multi_agent_conductor
|
from src import multi_agent_conductor
|
||||||
from src import tool_presets
|
from src import tool_presets
|
||||||
|
from src import rag_engine
|
||||||
from src import theme_2 as theme
|
from src import theme_2 as theme
|
||||||
|
|
||||||
def hide_tk_root() -> Tk:
|
def hide_tk_root() -> Tk:
|
||||||
@@ -202,6 +203,8 @@ class AppController:
|
|||||||
self._pending_ask_dialog: bool = False
|
self._pending_ask_dialog: bool = False
|
||||||
self.mcp_config: models.MCPConfiguration = models.MCPConfiguration()
|
self.mcp_config: models.MCPConfiguration = models.MCPConfiguration()
|
||||||
self.rag_config: Optional[models.RAGConfig] = None
|
self.rag_config: Optional[models.RAGConfig] = None
|
||||||
|
self.rag_engine: Optional[rag_engine.RAGEngine] = None
|
||||||
|
self.rag_status: str = 'idle'
|
||||||
# AI settings state
|
# AI settings state
|
||||||
self._current_provider: str = "gemini"
|
self._current_provider: str = "gemini"
|
||||||
self._current_model: str = "gemini-2.5-flash-lite"
|
self._current_model: str = "gemini-2.5-flash-lite"
|
||||||
@@ -353,6 +356,7 @@ class AppController:
|
|||||||
'show_confirm_modal': 'show_confirm_modal',
|
'show_confirm_modal': 'show_confirm_modal',
|
||||||
'mma_epic_input': 'ui_epic_input',
|
'mma_epic_input': 'ui_epic_input',
|
||||||
'mma_status': 'mma_status',
|
'mma_status': 'mma_status',
|
||||||
|
'rag_status': 'rag_status',
|
||||||
'mma_active_tier': 'active_tier',
|
'mma_active_tier': 'active_tier',
|
||||||
'ui_new_track_name': 'ui_new_track_name',
|
'ui_new_track_name': 'ui_new_track_name',
|
||||||
'ui_new_track_desc': 'ui_new_track_desc',
|
'ui_new_track_desc': 'ui_new_track_desc',
|
||||||
@@ -560,6 +564,32 @@ class AppController:
|
|||||||
"payload": status
|
"payload": status
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def _set_rag_status(self, status: str) -> None:
|
||||||
|
"""Thread-safe update of rag_status via the GUI task queue."""
|
||||||
|
with self._pending_gui_tasks_lock:
|
||||||
|
self._pending_gui_tasks.append({
|
||||||
|
"action": "set_value",
|
||||||
|
"item": "rag_status",
|
||||||
|
"value": status
|
||||||
|
})
|
||||||
|
|
||||||
|
def _rebuild_rag_index(self) -> None:
|
||||||
|
"""Background thread that re-indexes all files in the current project."""
|
||||||
|
if not self.rag_config or not self.rag_config.enabled or not self.rag_engine:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _run():
|
||||||
|
try:
|
||||||
|
self._set_rag_status("indexing...")
|
||||||
|
for f in self.files:
|
||||||
|
path = f.path if hasattr(f, "path") else str(f)
|
||||||
|
self.rag_engine.index_file(path)
|
||||||
|
self._set_rag_status("ready")
|
||||||
|
except Exception as e:
|
||||||
|
self._set_rag_status(f"error: {e}")
|
||||||
|
|
||||||
|
threading.Thread(target=_run, daemon=True).start()
|
||||||
|
|
||||||
def _trigger_gui_refresh(self):
|
def _trigger_gui_refresh(self):
|
||||||
with self._pending_gui_tasks_lock:
|
with self._pending_gui_tasks_lock:
|
||||||
self._pending_gui_tasks.append({'action': 'set_comms_dirty'})
|
self._pending_gui_tasks.append({'action': 'set_comms_dirty'})
|
||||||
@@ -955,6 +985,8 @@ class AppController:
|
|||||||
else:
|
else:
|
||||||
self.rag_config = models.RAGConfig()
|
self.rag_config = models.RAGConfig()
|
||||||
|
|
||||||
|
self.rag_engine = rag_engine.RAGEngine(self.rag_config, self.active_project_root)
|
||||||
|
|
||||||
from src.personas import PersonaManager
|
from src.personas import PersonaManager
|
||||||
self.persona_manager = PersonaManager(Path(self.active_project_path).parent if self.active_project_path else None)
|
self.persona_manager = PersonaManager(Path(self.active_project_path).parent if self.active_project_path else None)
|
||||||
self.personas = self.persona_manager.load_all()
|
self.personas = self.persona_manager.load_all()
|
||||||
@@ -1448,7 +1480,8 @@ class AppController:
|
|||||||
stream_callback=lambda text: self._on_ai_stream(text),
|
stream_callback=lambda text: self._on_ai_stream(text),
|
||||||
pre_tool_callback=self._confirm_and_run,
|
pre_tool_callback=self._confirm_and_run,
|
||||||
qa_callback=ai_client.run_tier4_analysis,
|
qa_callback=ai_client.run_tier4_analysis,
|
||||||
patch_callback=ai_client.run_tier4_patch_callback
|
patch_callback=ai_client.run_tier4_patch_callback,
|
||||||
|
rag_engine=self.rag_engine
|
||||||
)
|
)
|
||||||
self.event_queue.put("response", {"text": resp, "status": "done", "role": "AI"})
|
self.event_queue.put("response", {"text": resp, "status": "done", "role": "AI"})
|
||||||
except ai_client.ProviderError as e:
|
except ai_client.ProviderError as e:
|
||||||
@@ -1867,7 +1900,7 @@ class AppController:
|
|||||||
"ts": project_manager.now_ts()
|
"ts": project_manager.now_ts()
|
||||||
})
|
})
|
||||||
try:
|
try:
|
||||||
resp = ai_client.send(stable_md, user_msg, base_dir, self.last_file_items, disc_text)
|
resp = ai_client.send(stable_md, user_msg, base_dir, self.last_file_items, disc_text, rag_engine=self.rag_engine)
|
||||||
if req.auto_add_history:
|
if req.auto_add_history:
|
||||||
with self._pending_history_adds_lock:
|
with self._pending_history_adds_lock:
|
||||||
self._pending_history_adds.append({
|
self._pending_history_adds.append({
|
||||||
@@ -2024,7 +2057,17 @@ class AppController:
|
|||||||
self._set_status(f"switched to: {Path(path).stem}")
|
self._set_status(f"switched to: {Path(path).stem}")
|
||||||
|
|
||||||
def _refresh_from_project(self) -> None:
|
def _refresh_from_project(self) -> None:
|
||||||
self.files = list(self.project.get("files", {}).get("paths", []))
|
# Deserialize FileItems in files.paths
|
||||||
|
raw_paths = self.project.get("files", {}).get("paths", [])
|
||||||
|
self.files = []
|
||||||
|
for p in raw_paths:
|
||||||
|
if isinstance(p, models.FileItem):
|
||||||
|
self.files.append(p)
|
||||||
|
elif isinstance(p, dict):
|
||||||
|
self.files.append(models.FileItem.from_dict(p))
|
||||||
|
else:
|
||||||
|
self.files.append(models.FileItem(path=str(p)))
|
||||||
|
|
||||||
self.screenshots = list(self.project.get("screenshots", {}).get("paths", []))
|
self.screenshots = list(self.project.get("screenshots", {}).get("paths", []))
|
||||||
disc_sec = self.project.get("discussion", {})
|
disc_sec = self.project.get("discussion", {})
|
||||||
self.disc_roles = list(disc_sec.get("roles", ["User", "AI", "Vendor API", "System"]))
|
self.disc_roles = list(disc_sec.get("roles", ["User", "AI", "Vendor API", "System"]))
|
||||||
@@ -2090,6 +2133,9 @@ class AppController:
|
|||||||
self.tool_presets = self.tool_preset_manager.load_all_presets()
|
self.tool_presets = self.tool_preset_manager.load_all_presets()
|
||||||
self.bias_profiles = self.tool_preset_manager.load_all_bias_profiles()
|
self.bias_profiles = self.tool_preset_manager.load_all_bias_profiles()
|
||||||
|
|
||||||
|
if self.rag_config and self.rag_config.enabled:
|
||||||
|
self._rebuild_rag_index()
|
||||||
|
|
||||||
def _apply_preset(self, name: str, scope: str) -> None:
|
def _apply_preset(self, name: str, scope: str) -> None:
|
||||||
print(f"[DEBUG] _apply_preset: name={name}, scope={scope}")
|
print(f"[DEBUG] _apply_preset: name={name}, scope={scope}")
|
||||||
if name == "None":
|
if name == "None":
|
||||||
|
|||||||
@@ -91,6 +91,72 @@ class RAGEngine:
|
|||||||
metadatas=metadatas
|
metadatas=metadatas
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _chunk_text(self, content: str) -> List[str]:
|
||||||
|
"""Character-based chunking with overlap."""
|
||||||
|
chunks = []
|
||||||
|
if not content:
|
||||||
|
return chunks
|
||||||
|
chunk_size = self.config.chunk_size
|
||||||
|
overlap = self.config.chunk_overlap
|
||||||
|
start = 0
|
||||||
|
while start < len(content):
|
||||||
|
end = start + chunk_size
|
||||||
|
chunks.append(content[start:end])
|
||||||
|
if end >= len(content):
|
||||||
|
break
|
||||||
|
start += (chunk_size - overlap)
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def _chunk_code(self, content: str, file_path: str) -> List[str]:
|
||||||
|
"""AST-aware chunking for Python code."""
|
||||||
|
try:
|
||||||
|
from src.file_cache import ASTParser
|
||||||
|
parser = ASTParser("python")
|
||||||
|
tree = parser.parse(content)
|
||||||
|
chunks = []
|
||||||
|
|
||||||
|
# Capture classes and top-level functions
|
||||||
|
for node in tree.root_node.children:
|
||||||
|
if node.type in ("function_definition", "class_definition"):
|
||||||
|
chunks.append(content[node.start_byte:node.end_byte])
|
||||||
|
|
||||||
|
# Fallback if no structural chunks found or if file is small
|
||||||
|
if not chunks or len(content) < self.config.chunk_size:
|
||||||
|
return self._chunk_text(content)
|
||||||
|
return chunks
|
||||||
|
except Exception:
|
||||||
|
return self._chunk_text(content)
|
||||||
|
|
||||||
|
def index_file(self, file_path: str):
|
||||||
|
"""Reads, chunks, and indexes a file into the vector store."""
|
||||||
|
if not self.config.enabled or self.collection == "mock":
|
||||||
|
return
|
||||||
|
|
||||||
|
full_path = os.path.join(self.base_dir, file_path)
|
||||||
|
if not os.path.exists(full_path):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
|
||||||
|
content = f.read()
|
||||||
|
except Exception:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Remove old entries for this file
|
||||||
|
self.collection.delete(where={"path": file_path})
|
||||||
|
|
||||||
|
if file_path.lower().endswith(".py"):
|
||||||
|
chunks = self._chunk_code(content, file_path)
|
||||||
|
else:
|
||||||
|
chunks = self._chunk_text(content)
|
||||||
|
|
||||||
|
if not chunks:
|
||||||
|
return
|
||||||
|
|
||||||
|
ids = [f"{file_path}_{i}" for i in range(len(chunks))]
|
||||||
|
metadatas = [{"path": file_path, "chunk": i} for i in range(len(chunks))]
|
||||||
|
self.add_documents(ids, chunks, metadatas)
|
||||||
|
|
||||||
def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
|
||||||
if not self.config.enabled or self.collection == "mock":
|
if not self.config.enabled or self.collection == "mock":
|
||||||
return []
|
return []
|
||||||
|
|||||||
@@ -0,0 +1,111 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.app_controller import AppController
|
||||||
|
from src import ai_client
|
||||||
|
from src import events
|
||||||
|
from src import models
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_project():
|
||||||
|
# Use a temporary directory for the mock project
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
try:
|
||||||
|
# Create a minimal manual_slop.toml
|
||||||
|
with open(os.path.join(temp_dir, "manual_slop.toml"), "w") as f:
|
||||||
|
f.write('discussion_history = []\n')
|
||||||
|
yield temp_dir
|
||||||
|
finally:
|
||||||
|
# Clean up after test
|
||||||
|
shutil.rmtree(temp_dir)
|
||||||
|
|
||||||
|
def test_rag_integration(mock_project):
|
||||||
|
"""
|
||||||
|
Integration test verifying the flow from AppController through RAGEngine to ai_client.
|
||||||
|
"""
|
||||||
|
# 1. Initializes a mock project and AppController.
|
||||||
|
# We patch several components to avoid side effects during initialization.
|
||||||
|
with patch('src.app_controller.AppController._fetch_models'), \
|
||||||
|
patch('src.models.load_config', return_value={}), \
|
||||||
|
patch('src.paths.get_full_path_info', return_value={'logs_dir': {'path': mock_project}, 'scripts_dir': {'path': mock_project}}), \
|
||||||
|
patch('src.theme.load_from_config'):
|
||||||
|
app = AppController()
|
||||||
|
# Minimal state setup for _handle_request_event
|
||||||
|
app.ui_global_system_prompt = ""
|
||||||
|
app.ui_project_system_prompt = ""
|
||||||
|
app.ui_base_system_prompt = ""
|
||||||
|
app.ui_use_default_base_prompt = True
|
||||||
|
app.ui_project_context_marker = ""
|
||||||
|
app.temperature = 0.0
|
||||||
|
app.max_tokens = 100
|
||||||
|
app.history_trunc_limit = 1000
|
||||||
|
app.top_p = 1.0
|
||||||
|
app.ui_agent_tools = {}
|
||||||
|
app.ui_gemini_cli_path = "gemini"
|
||||||
|
app.current_model = "gemini-1.5-flash"
|
||||||
|
app.active_project_path = os.path.join(mock_project, "manual_slop.toml")
|
||||||
|
|
||||||
|
# Ensure the provider is set to 'gemini' for our test
|
||||||
|
ai_client.set_provider("gemini", "gemini-1.5-flash")
|
||||||
|
|
||||||
|
# 2. Configures a mock RAG setup (enabled=True, provider='mock').
|
||||||
|
rag_config = models.RAGConfig(
|
||||||
|
enabled=True,
|
||||||
|
vector_store=models.VectorStoreConfig(provider='mock')
|
||||||
|
)
|
||||||
|
app.rag_config = rag_config
|
||||||
|
|
||||||
|
# 3. Mocks rag_engine.search to return a known chunk.
|
||||||
|
mock_rag_engine = MagicMock()
|
||||||
|
mock_rag_engine.config = rag_config
|
||||||
|
mock_rag_engine.search.return_value = [
|
||||||
|
{"document": "This is a retrieved chunk from RAG.", "metadata": {"path": "test_file.py"}}
|
||||||
|
]
|
||||||
|
app.rag_engine = mock_rag_engine
|
||||||
|
|
||||||
|
# 4. Mocks ai_client.send to verify that the retrieved chunk appears in the
|
||||||
|
# message sent to the provider. We use 'wraps' to let the real logic run
|
||||||
|
# while still having a mock we can inspect. We also mock the internal
|
||||||
|
# _send_gemini which is what actually "sends to the provider".
|
||||||
|
with patch('src.ai_client.send', wraps=ai_client.send) as mock_send:
|
||||||
|
with patch('src.ai_client._send_gemini') as mock_provider:
|
||||||
|
mock_provider.return_value = "Mock AI Response"
|
||||||
|
|
||||||
|
# Create a UserRequestEvent as if the user clicked "Gen + Send"
|
||||||
|
event = events.UserRequestEvent(
|
||||||
|
prompt="Tell me about the code.",
|
||||||
|
stable_md="Context",
|
||||||
|
file_items=[],
|
||||||
|
disc_text="History",
|
||||||
|
base_dir=mock_project
|
||||||
|
)
|
||||||
|
|
||||||
|
# Trigger the request event processing logic in AppController
|
||||||
|
app._handle_request_event(event)
|
||||||
|
|
||||||
|
# 5. This verifies the wiring from AppController through RAGEngine to ai_client.
|
||||||
|
|
||||||
|
# Verify that ai_client.send was called by AppController
|
||||||
|
assert mock_send.called
|
||||||
|
_, kwargs = mock_send.call_args
|
||||||
|
assert kwargs['rag_engine'] == mock_rag_engine
|
||||||
|
|
||||||
|
# Verify that the internal provider call was made
|
||||||
|
assert mock_provider.called
|
||||||
|
|
||||||
|
# Extract the user_message passed to the provider call
|
||||||
|
args, _ = mock_provider.call_args
|
||||||
|
# _send_gemini(md_content, user_message, ...) -> user_message is index 1
|
||||||
|
sent_user_message = args[1]
|
||||||
|
|
||||||
|
# Verify that the RAG chunk was prepended to the original prompt
|
||||||
|
assert "This is a retrieved chunk from RAG." in sent_user_message
|
||||||
|
assert "Tell me about the code." in sent_user_message
|
||||||
|
assert "## Retrieved Context" in sent_user_message
|
||||||
|
assert "Source: test_file.py" in sent_user_message
|
||||||
|
|
||||||
|
# Verify that rag_engine.search was called with the original prompt
|
||||||
|
mock_rag_engine.search.assert_called_once_with("Tell me about the code.")
|
||||||
Reference in New Issue
Block a user