63 lines
2.3 KiB
Python
63 lines
2.3 KiB
Python
import asyncio
|
|
import json
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch, AsyncMock
|
|
from src import ai_client
|
|
from src import mcp_client
|
|
from src import models
|
|
|
|
@pytest.mark.asyncio
async def test_external_mcp_hitl_approval():
    """An approved external MCP tool call runs on the server and returns its output.

    The pre-tool callback must be consulted exactly once. The description it
    receives must identify the call as an external MCP tool and list the call
    arguments.
    """
    # 1. A real manager holding a single mocked external server.
    manager = mcp_client.ExternalMCPManager()
    mock_server = AsyncMock()
    mock_server.name = "test-server"
    mock_server.tools = {"ext_tool": {"name": "ext_tool", "description": "desc"}}
    mock_server.call_tool.return_value = "Success"
    manager.servers["test-server"] = mock_server

    # 2. Approval callback. The module-level hook is patched rather than
    # assigned, so the mock is restored when the test ends and cannot leak
    # into other tests. create=True keeps the original's tolerance for the
    # attribute not existing yet.
    mock_pre_tool = MagicMock(return_value="Approved")

    with patch("src.mcp_client.get_external_mcp_manager", return_value=manager):
        with patch.object(
            ai_client, "confirm_and_run_callback", mock_pre_tool, create=True
        ):
            # 3. Invoke the executor; the callback is also passed explicitly.
            name = "ext_tool"
            args = {"arg1": "val1"}
            call_id = "call_123"
            base_dir = "."
            _name, _cid, out, _orig_name = (
                await ai_client._execute_single_tool_call_async(
                    name, args, call_id, base_dir, mock_pre_tool, None, 0
                )
            )

    # 4. The server's result is returned, and the server was actually invoked.
    assert out == "Success"
    mock_server.call_tool.assert_awaited_once()

    # The callback ran once. Its description must flag the external MCP tool
    # and include the arguments.
    mock_pre_tool.assert_called_once()
    description = mock_pre_tool.call_args.args[0]
    assert "EXTERNAL MCP TOOL: ext_tool" in description
    assert "arg1: 'val1'" in description
|
|
|
|
@pytest.mark.asyncio
async def test_external_mcp_hitl_rejection():
    """A rejected external MCP tool call is never sent to the server.

    When the pre-tool callback returns ``None``, the executor must report the
    cancellation instead of invoking the external server.
    """
    # A real manager holding a single mocked external server.
    manager = mcp_client.ExternalMCPManager()
    mock_server = AsyncMock()
    mock_server.name = "test-server"
    mock_server.tools = {"ext_tool": {"name": "ext_tool"}}
    manager.servers["test-server"] = mock_server

    # Returning None from the callback signals rejection.
    mock_pre_tool = MagicMock(return_value=None)

    with patch("src.mcp_client.get_external_mcp_manager", return_value=manager):
        name = "ext_tool"
        args = {"arg1": "val1"}
        _name, _cid, out, _orig_name = (
            await ai_client._execute_single_tool_call_async(
                name, args, "id", ".", mock_pre_tool, None, 0
            )
        )

    assert out == "USER REJECTED: tool execution cancelled"
    # The rejection must come from the callback, not from some other early
    # exit, and the external server must never be reached.
    mock_pre_tool.assert_called_once()
    mock_server.call_tool.assert_not_called()
|