feat(ai_client): Support external MCP tools and HITL approval
This commit is contained in:
@@ -535,7 +535,7 @@ def get_bias_profile() -> Optional[str]:
|
||||
|
||||
def _build_anthropic_tools() -> list[dict[str, Any]]:
|
||||
raw_tools: list[dict[str, Any]] = []
|
||||
for spec in mcp_client.MCP_TOOL_SPECS:
|
||||
for spec in mcp_client.get_tool_schemas():
|
||||
if _agent_tools.get(spec["name"], True):
|
||||
raw_tools.append({
|
||||
"name": spec["name"],
|
||||
@@ -579,7 +579,7 @@ def _get_anthropic_tools() -> list[dict[str, Any]]:
|
||||
|
||||
def _gemini_tool_declaration() -> Optional[types.Tool]:
|
||||
raw_tools: list[dict[str, Any]] = []
|
||||
for spec in mcp_client.MCP_TOOL_SPECS:
|
||||
for spec in mcp_client.get_tool_schemas():
|
||||
if _agent_tools.get(spec["name"], True):
|
||||
raw_tools.append({
|
||||
"name": spec["name"],
|
||||
@@ -715,10 +715,15 @@ async def _execute_single_tool_call_async(
|
||||
tool_executed = True
|
||||
|
||||
if not tool_executed:
|
||||
if name and name in mcp_client.TOOL_NAMES:
|
||||
is_native = name in mcp_client.TOOL_NAMES
|
||||
ext_tools = mcp_client.get_external_mcp_manager().get_all_tools()
|
||||
is_external = name in ext_tools
|
||||
if name and (is_native or is_external):
|
||||
_append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
|
||||
if name in mcp_client.MUTATING_TOOLS and approval_mode != "auto" and pre_tool_callback:
|
||||
desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
|
||||
should_approve = (name in mcp_client.MUTATING_TOOLS or is_external) and approval_mode != "auto" and pre_tool_callback
|
||||
if should_approve:
|
||||
label = "MCP MUTATING" if is_native else "EXTERNAL MCP"
|
||||
desc = f"# {label} TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
|
||||
_res = await asyncio.to_thread(pre_tool_callback, desc, base_dir, qa_callback)
|
||||
out = "USER REJECTED: tool execution cancelled" if _res is None else await mcp_client.async_dispatch(name, args)
|
||||
else:
|
||||
@@ -816,7 +821,7 @@ def _build_file_diff_text(changed_items: list[dict[str, Any]]) -> str:
|
||||
|
||||
def _build_deepseek_tools() -> list[dict[str, Any]]:
|
||||
raw_tools: list[dict[str, Any]] = []
|
||||
for spec in mcp_client.MCP_TOOL_SPECS:
|
||||
for spec in mcp_client.get_tool_schemas():
|
||||
if _agent_tools.get(spec["name"], True):
|
||||
raw_tools.append({
|
||||
"name": spec["name"],
|
||||
|
||||
67
tests/test_external_mcp_e2e.py
Normal file
67
tests/test_external_mcp_e2e.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pytest
|
||||
from src.app_controller import AppController
|
||||
from src import mcp_client
|
||||
from src import ai_client
|
||||
from src import models
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_external_mcp_e2e_refresh_and_call(tmp_path, monkeypatch):
|
||||
# 1. Setup mock config and mock server script
|
||||
config_file = tmp_path / "config.toml"
|
||||
monkeypatch.setattr(models, "CONFIG_PATH", str(config_file))
|
||||
|
||||
mock_script = Path("scripts/mock_mcp_server.py").absolute()
|
||||
|
||||
mcp_config_file = tmp_path / "mcp_config.json"
|
||||
mcp_data = {
|
||||
"mcpServers": {
|
||||
"e2e-server": {
|
||||
"command": "python",
|
||||
"args": [str(mock_script)],
|
||||
"auto_start": True
|
||||
}
|
||||
}
|
||||
}
|
||||
mcp_config_file.write_text(json.dumps(mcp_data))
|
||||
|
||||
config_content = f"""
|
||||
[ai]
|
||||
mcp_config_path = "{mcp_config_file.as_posix()}"
|
||||
[projects]
|
||||
paths = []
|
||||
active = ""
|
||||
"""
|
||||
config_file.write_text(config_content)
|
||||
|
||||
# 2. Initialize AppController
|
||||
ctrl = AppController()
|
||||
monkeypatch.setattr(ctrl, "_load_active_project", lambda: None)
|
||||
ctrl.project = {}
|
||||
|
||||
# We need to mock start_services or just manually call what we need
|
||||
ctrl.init_state()
|
||||
|
||||
# Trigger refresh event manually (since we don't have the background thread running in unit test)
|
||||
await ctrl.refresh_external_mcps()
|
||||
|
||||
# 3. Verify tools are discovered
|
||||
manager = mcp_client.get_external_mcp_manager()
|
||||
tools = manager.get_all_tools()
|
||||
assert "echo" in tools
|
||||
|
||||
# 4. Mock pre_tool_callback to auto-approve
|
||||
mock_pre_tool = lambda desc, base, qa: "Approved"
|
||||
|
||||
# 5. Call execute_single_tool_call_async (via ai_client)
|
||||
name, cid, out, orig = await ai_client._execute_single_tool_call_async(
|
||||
"echo", {"message": "hello"}, "id1", ".", mock_pre_tool, None, 0
|
||||
)
|
||||
|
||||
assert "ECHO: {'message': 'hello'}" in out
|
||||
|
||||
# Cleanup
|
||||
await manager.stop_all()
|
||||
62
tests/test_external_mcp_hitl.py
Normal file
62
tests/test_external_mcp_hitl.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from src import ai_client
|
||||
from src import mcp_client
|
||||
from src import models
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_external_mcp_hitl_approval():
|
||||
# 1. Setup mock manager and server
|
||||
mock_manager = mcp_client.ExternalMCPManager()
|
||||
mock_server = AsyncMock()
|
||||
mock_server.name = "test-server"
|
||||
mock_server.tools = {"ext_tool": {"name": "ext_tool", "description": "desc"}}
|
||||
mock_server.call_tool.return_value = "Success"
|
||||
mock_manager.servers["test-server"] = mock_server
|
||||
|
||||
with patch("src.mcp_client.get_external_mcp_manager", return_value=mock_manager):
|
||||
# 2. Setup ai_client callbacks
|
||||
mock_pre_tool = MagicMock(return_value="Approved")
|
||||
ai_client.confirm_and_run_callback = mock_pre_tool
|
||||
|
||||
# 3. Call _execute_single_tool_call_async
|
||||
name = "ext_tool"
|
||||
args = {"arg1": "val1"}
|
||||
call_id = "call_123"
|
||||
base_dir = "."
|
||||
|
||||
# We need to pass the callback to the function
|
||||
name, cid, out, orig_name = await ai_client._execute_single_tool_call_async(
|
||||
name, args, call_id, base_dir, mock_pre_tool, None, 0
|
||||
)
|
||||
|
||||
# 4. Assertions
|
||||
assert out == "Success"
|
||||
mock_pre_tool.assert_called_once()
|
||||
# Check description contains EXTERNAL MCP
|
||||
call_args = mock_pre_tool.call_args[0]
|
||||
assert "EXTERNAL MCP TOOL: ext_tool" in call_args[0]
|
||||
assert "arg1: 'val1'" in call_args[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_external_mcp_hitl_rejection():
|
||||
mock_manager = mcp_client.ExternalMCPManager()
|
||||
mock_server = AsyncMock()
|
||||
mock_server.name = "test-server"
|
||||
mock_server.tools = {"ext_tool": {"name": "ext_tool"}}
|
||||
mock_manager.servers["test-server"] = mock_server
|
||||
|
||||
with patch("src.mcp_client.get_external_mcp_manager", return_value=mock_manager):
|
||||
mock_pre_tool = MagicMock(return_value=None) # Rejection
|
||||
|
||||
name = "ext_tool"
|
||||
args = {"arg1": "val1"}
|
||||
|
||||
name, cid, out, orig_name = await ai_client._execute_single_tool_call_async(
|
||||
name, args, "id", ".", mock_pre_tool, None, 0
|
||||
)
|
||||
|
||||
assert out == "USER REJECTED: tool execution cancelled"
|
||||
mock_server.call_tool.assert_not_called()
|
||||
Reference in New Issue
Block a user