diff --git a/tests/test_async_tools.py b/tests/test_async_tools.py new file mode 100644 index 0000000..a7b3bb8 --- /dev/null +++ b/tests/test_async_tools.py @@ -0,0 +1,100 @@ +import asyncio +import time +import pytest +from unittest.mock import patch +from src import ai_client +from src import mcp_client + +@pytest.mark.asyncio +async def test_execute_tool_calls_concurrently_timing(): + """ + Verifies that _execute_tool_calls_concurrently runs tools in parallel. + Total time should be approx 0.5s for 3 tools each taking 0.5s. + """ + # 1. Setup mock tool calls (Gemini style) + class MockGeminiCall: + def __init__(self, name, args): + self.name = name + self.args = args + + calls = [ + MockGeminiCall("read_file", {"path": "file1.txt"}), + MockGeminiCall("read_file", {"path": "file2.txt"}), + MockGeminiCall("read_file", {"path": "file3.txt"}), + ] + + # 2. Mock async_dispatch to sleep for 0.5s + async def mocked_async_dispatch(name, args): + await asyncio.sleep(0.5) + return f"Content of {args.get('path')}" + + # 3. Mock components in ai_client/mcp_client + with ( + patch("src.mcp_client.async_dispatch", side_effect=mocked_async_dispatch), + patch("src.mcp_client.TOOL_NAMES", ["read_file"]), + patch("src.ai_client._append_comms"), + patch("src.ai_client.events.emit") + ): + start_time = time.time() + + results = await ai_client._execute_tool_calls_concurrently( + calls=calls, + base_dir=".", + pre_tool_callback=None, + qa_callback=None, + r_idx=0, + provider="gemini" + ) + + end_time = time.time() + duration = end_time - start_time + + # 4. Assertions + assert len(results) == 3 + # Parallel execution: duration should be < 1.0s (0.5s + overhead) + # Serial execution: duration would be >= 1.5s + print(f"Concurrent tool execution took {duration:.4f} seconds") + assert duration < 1.0, f"Tools executed serially? Took {duration:.4f}s" + assert duration >= 0.5, "Tools didn't even sleep?" + + # Verify results content + for i, (name, call_id, out, orig_name) in enumerate(results): + assert name == "read_file" + assert out == f"Content of file{i+1}.txt" + +@pytest.mark.asyncio +async def test_execute_tool_calls_concurrently_exception_handling(): + """ + Verifies that if one tool call fails, it doesn't crash the whole group if caught, + but currently gather is used WITHOUT return_exceptions=True, so it should re-raise. + """ + class MockGeminiCall: + def __init__(self, name, args): + self.name = name + self.args = args + + calls = [ + MockGeminiCall("read_file", {"path": "success.txt"}), + MockGeminiCall("read_file", {"path": "fail.txt"}), + ] + + async def mocked_async_dispatch(name, args): + if args.get("path") == "fail.txt": + raise ValueError("Simulated tool failure") + return "Success" + + with ( + patch("src.mcp_client.async_dispatch", side_effect=mocked_async_dispatch), + patch("src.mcp_client.TOOL_NAMES", ["read_file"]), + patch("src.ai_client._append_comms"), + patch("src.ai_client.events.emit") + ): + with pytest.raises(ValueError, match="Simulated tool failure"): + await ai_client._execute_tool_calls_concurrently( + calls=calls, + base_dir=".", + pre_tool_callback=None, + qa_callback=None, + r_idx=0, + provider="gemini" + )