diff --git a/src/ai_client.py b/src/ai_client.py index 858caf1d..8699c6de 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -3260,9 +3260,8 @@ def send( if chunks: context_block = "## Retrieved Context\n\n" for i, chunk in enumerate(chunks): - chunk_meta = chunk["metadata"] if "metadata" in chunk else {} - path = chunk_meta["path"] if "path" in chunk_meta else "unknown" - doc = chunk["document"] if "document" in chunk else "" + path = chunk.path if chunk.path else "unknown" + doc = chunk.document context_block += f"### Chunk {i+1} (Source: {path})\n{doc}\n\n" user_message = context_block + user_message diff --git a/src/rag_engine.py b/src/rag_engine.py index 12be8046..a9880edd 100644 --- a/src/rag_engine.py +++ b/src/rag_engine.py @@ -18,6 +18,7 @@ from src.file_cache import ASTParser @dataclass(frozen=True) class RAGChunk: + id: str = "" document: str = "" path: str = "" score: float = 0.0 @@ -364,7 +365,7 @@ class RAGEngine: return asyncio.run(_async_search_mcp()) - def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]: + def search(self, query: str, top_k: int = 5) -> List["RAGChunk"]: """ [C: tests/mock_concurrent_mma.py:main, tests/test_rag_engine.py:test_rag_engine_chroma] """ @@ -381,12 +382,16 @@ class RAGEngine: ret = [] if results and results["ids"] and results["ids"][0]: for i in range(len(results["ids"][0])): - ret.append({ - "id": results["ids"][0][i], - "document": results["documents"][0][i], - "metadata": results["metadatas"][0][i] if results["metadatas"] else {}, - "distance": results["distances"][0][i] if "distances" in results and results["distances"] else 0.0 - }) + raw_meta = results["metadatas"][0][i] if results["metadatas"] else {} + distance = results["distances"][0][i] if "distances" in results and results["distances"] else 0.0 + raw_path = raw_meta.get("path", "") if isinstance(raw_meta, dict) else "" + ret.append(RAGChunk( + id=results["ids"][0][i], + document=results["documents"][0][i], + path=raw_path, + score=1.0 - float(distance), + metadata=Metadata.from_dict(raw_meta) if isinstance(raw_meta, dict) else Metadata(), + )) return ret def delete_documents(self, ids: List[str]): diff --git a/tests/test_rag_engine.py b/tests/test_rag_engine.py index eaa6293e..d3c883c1 100644 --- a/tests/test_rag_engine.py +++ b/tests/test_rag_engine.py @@ -58,7 +58,7 @@ def test_rag_engine_chroma(mock_get_chroma, mock_embed): results = engine.search("hello", top_k=1) assert len(results) == 1 - assert results[0]["id"] == "doc1" + assert results[0].id == "doc1" engine.delete_documents(["doc1"]) mock_collection.delete.assert_called_once_with(ids=["doc1"])