test(ai_client): Add tests for concurrent tool execution.
This commit is contained in:
100
tests/test_async_tools.py
Normal file
100
tests/test_async_tools.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from src import ai_client
|
||||||
|
from src import mcp_client
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_tool_calls_concurrently_timing():
|
||||||
|
"""
|
||||||
|
Verifies that _execute_tool_calls_concurrently runs tools in parallel.
|
||||||
|
Total time should be approx 0.5s for 3 tools each taking 0.5s.
|
||||||
|
"""
|
||||||
|
# 1. Setup mock tool calls (Gemini style)
|
||||||
|
class MockGeminiCall:
|
||||||
|
def __init__(self, name, args):
|
||||||
|
self.name = name
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
calls = [
|
||||||
|
MockGeminiCall("read_file", {"path": "file1.txt"}),
|
||||||
|
MockGeminiCall("read_file", {"path": "file2.txt"}),
|
||||||
|
MockGeminiCall("read_file", {"path": "file3.txt"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 2. Mock async_dispatch to sleep for 0.5s
|
||||||
|
async def mocked_async_dispatch(name, args):
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
return f"Content of {args.get('path')}"
|
||||||
|
|
||||||
|
# 3. Mock components in ai_client/mcp_client
|
||||||
|
with (
|
||||||
|
patch("src.mcp_client.async_dispatch", side_effect=mocked_async_dispatch),
|
||||||
|
patch("src.mcp_client.TOOL_NAMES", ["read_file"]),
|
||||||
|
patch("src.ai_client._append_comms"),
|
||||||
|
patch("src.ai_client.events.emit")
|
||||||
|
):
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
results = await ai_client._execute_tool_calls_concurrently(
|
||||||
|
calls=calls,
|
||||||
|
base_dir=".",
|
||||||
|
pre_tool_callback=None,
|
||||||
|
qa_callback=None,
|
||||||
|
r_idx=0,
|
||||||
|
provider="gemini"
|
||||||
|
)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
duration = end_time - start_time
|
||||||
|
|
||||||
|
# 4. Assertions
|
||||||
|
assert len(results) == 3
|
||||||
|
# Parallel execution: duration should be < 1.0s (0.5s + overhead)
|
||||||
|
# Serial execution: duration would be >= 1.5s
|
||||||
|
print(f"Concurrent tool execution took {duration:.4f} seconds")
|
||||||
|
assert duration < 1.0, f"Tools executed serially? Took {duration:.4f}s"
|
||||||
|
assert duration >= 0.5, "Tools didn't even sleep?"
|
||||||
|
|
||||||
|
# Verify results content
|
||||||
|
for i, (name, call_id, out, orig_name) in enumerate(results):
|
||||||
|
assert name == "read_file"
|
||||||
|
assert out == f"Content of file{i+1}.txt"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_tool_calls_concurrently_exception_handling():
|
||||||
|
"""
|
||||||
|
Verifies that if one tool call fails, it doesn't crash the whole group if caught,
|
||||||
|
but currently gather is used WITHOUT return_exceptions=True, so it should re-raise.
|
||||||
|
"""
|
||||||
|
class MockGeminiCall:
|
||||||
|
def __init__(self, name, args):
|
||||||
|
self.name = name
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
calls = [
|
||||||
|
MockGeminiCall("read_file", {"path": "success.txt"}),
|
||||||
|
MockGeminiCall("read_file", {"path": "fail.txt"}),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def mocked_async_dispatch(name, args):
|
||||||
|
if args.get("path") == "fail.txt":
|
||||||
|
raise ValueError("Simulated tool failure")
|
||||||
|
return "Success"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("src.mcp_client.async_dispatch", side_effect=mocked_async_dispatch),
|
||||||
|
patch("src.mcp_client.TOOL_NAMES", ["read_file"]),
|
||||||
|
patch("src.ai_client._append_comms"),
|
||||||
|
patch("src.ai_client.events.emit")
|
||||||
|
):
|
||||||
|
with pytest.raises(ValueError, match="Simulated tool failure"):
|
||||||
|
await ai_client._execute_tool_calls_concurrently(
|
||||||
|
calls=calls,
|
||||||
|
base_dir=".",
|
||||||
|
pre_tool_callback=None,
|
||||||
|
qa_callback=None,
|
||||||
|
r_idx=0,
|
||||||
|
provider="gemini"
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user