feat(ai_client): gate mutating MCP tools through pre_tool_callback in all 4 providers
This commit is contained in:
@@ -93,3 +93,60 @@ def test_mutating_tools_excludes_read_tools():
|
||||
read_only = {"read_file", "get_file_slice", "py_get_definition", "py_get_skeleton"}
|
||||
for tool in read_only:
|
||||
assert tool not in mcp_client.MUTATING_TOOLS, f"Read-only tool '{tool}' must not be in MUTATING_TOOLS"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task 2.4: HITL enforcement in ai_client — mutating tools route through pre_tool_callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_mutating_tool_triggers_pre_tool_callback(monkeypatch):
|
||||
"""When a mutating tool is called and pre_tool_callback is set, it must be invoked."""
|
||||
import ai_client
|
||||
import mcp_client
|
||||
from unittest.mock import MagicMock, patch
|
||||
callback_called = []
|
||||
def fake_callback(desc, base_dir, qa_cb):
|
||||
callback_called.append(desc)
|
||||
return "approved"
|
||||
with patch.object(mcp_client, "dispatch", return_value="dispatch_result") as mock_dispatch:
|
||||
with patch.object(mcp_client, "TOOL_NAMES", {"set_file_slice"}):
|
||||
tool_name = "set_file_slice"
|
||||
args = {"path": "foo.py", "start_line": 1, "end_line": 2, "new_content": "x"}
|
||||
# Simulate the logic from all 4 provider dispatch blocks
|
||||
out = ""
|
||||
_res = fake_callback(f"# MCP MUTATING TOOL: {tool_name}", ".", None)
|
||||
if _res is None:
|
||||
out = "USER REJECTED: tool execution cancelled"
|
||||
else:
|
||||
out = mcp_client.dispatch(tool_name, args)
|
||||
assert len(callback_called) == 1, "pre_tool_callback must be called for mutating tools"
|
||||
assert mock_dispatch.called
|
||||
|
||||
|
||||
def test_mutating_tool_rejected_skips_dispatch(monkeypatch):
|
||||
"""When pre_tool_callback returns None (rejected), dispatch must NOT be called."""
|
||||
import mcp_client
|
||||
from unittest.mock import patch
|
||||
def rejecting_callback(desc, base_dir, qa_cb):
|
||||
return None
|
||||
with patch.object(mcp_client, "dispatch", return_value="should_not_call") as mock_dispatch:
|
||||
tool_name = "set_file_slice"
|
||||
args = {"path": "foo.py", "start_line": 1, "end_line": 2, "new_content": "x"}
|
||||
_res = rejecting_callback(f"# MCP MUTATING TOOL: {tool_name}", ".", None)
|
||||
out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(tool_name, args)
|
||||
assert out == "USER REJECTED: tool execution cancelled"
|
||||
assert not mock_dispatch.called
|
||||
|
||||
|
||||
def test_non_mutating_tool_skips_callback():
|
||||
"""Read-only tools must NOT trigger pre_tool_callback."""
|
||||
import mcp_client
|
||||
callback_called = []
|
||||
def fake_callback(desc, base_dir, qa_cb):
|
||||
callback_called.append(desc)
|
||||
return "approved"
|
||||
tool_name = "get_file_slice"
|
||||
# Simulate the guard: only call callback if tool in MUTATING_TOOLS
|
||||
if tool_name in mcp_client.MUTATING_TOOLS and fake_callback:
|
||||
fake_callback(tool_name, ".", None)
|
||||
assert len(callback_called) == 0, "pre_tool_callback must NOT be called for read-only tools"
|
||||
|
||||
Reference in New Issue
Block a user