feat(ai_client): gate mutating MCP tools through pre_tool_callback in all 4 providers

This commit is contained in:
2026-03-02 16:50:47 -05:00
parent 8e6462d10b
commit e5e35f78dd
2 changed files with 81 additions and 4 deletions

View File

@@ -805,6 +805,11 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx}) events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
if name in mcp_client.TOOL_NAMES: if name in mcp_client.TOOL_NAMES:
_append_comms("OUT", "tool_call", {"name": name, "args": args}) _append_comms("OUT", "tool_call", {"name": name, "args": args})
if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
_res = pre_tool_callback(desc, base_dir, qa_callback)
out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(name, args)
else:
out = mcp_client.dispatch(name, args) out = mcp_client.dispatch(name, args)
elif name == TOOL_NAME: elif name == TOOL_NAME:
scr = args.get("script", "") scr = args.get("script", "")
@@ -927,6 +932,11 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx}) events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
if name in mcp_client.TOOL_NAMES: if name in mcp_client.TOOL_NAMES:
_append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args}) _append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
_res = pre_tool_callback(desc, base_dir, qa_callback)
out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(name, args)
else:
out = mcp_client.dispatch(name, args) out = mcp_client.dispatch(name, args)
elif name == TOOL_NAME: elif name == TOOL_NAME:
scr = args.get("script", "") scr = args.get("script", "")
@@ -1343,6 +1353,11 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx}) events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx})
if b_name in mcp_client.TOOL_NAMES: if b_name in mcp_client.TOOL_NAMES:
_append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input}) _append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input})
if b_name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
desc = f"# MCP MUTATING TOOL: {b_name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in b_input.items())
_res = pre_tool_callback(desc, base_dir, qa_callback)
output = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(b_name, b_input)
else:
output = mcp_client.dispatch(b_name, b_input) output = mcp_client.dispatch(b_name, b_input)
_append_comms("IN", "tool_result", {"name": b_name, "id": b_id, "output": output}) _append_comms("IN", "tool_result", {"name": b_name, "id": b_id, "output": output})
elif b_name == TOOL_NAME: elif b_name == TOOL_NAME:
@@ -1596,6 +1611,11 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx}) events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx})
if tool_name in mcp_client.TOOL_NAMES: if tool_name in mcp_client.TOOL_NAMES:
_append_comms("OUT", "tool_call", {"name": tool_name, "id": tool_id, "args": tool_args}) _append_comms("OUT", "tool_call", {"name": tool_name, "id": tool_id, "args": tool_args})
if tool_name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
desc = f"# MCP MUTATING TOOL: {tool_name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in tool_args.items())
_res = pre_tool_callback(desc, base_dir, qa_callback)
tool_output = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(tool_name, tool_args)
else:
tool_output = mcp_client.dispatch(tool_name, tool_args) tool_output = mcp_client.dispatch(tool_name, tool_args)
elif tool_name == TOOL_NAME: elif tool_name == TOOL_NAME:
script = tool_args.get("script", "") script = tool_args.get("script", "")

View File

@@ -93,3 +93,60 @@ def test_mutating_tools_excludes_read_tools():
read_only = {"read_file", "get_file_slice", "py_get_definition", "py_get_skeleton"} read_only = {"read_file", "get_file_slice", "py_get_definition", "py_get_skeleton"}
for tool in read_only: for tool in read_only:
assert tool not in mcp_client.MUTATING_TOOLS, f"Read-only tool '{tool}' must not be in MUTATING_TOOLS" assert tool not in mcp_client.MUTATING_TOOLS, f"Read-only tool '{tool}' must not be in MUTATING_TOOLS"
# ---------------------------------------------------------------------------
# Task 2.4: HITL enforcement in ai_client — mutating tools route through pre_tool_callback
# ---------------------------------------------------------------------------
def test_mutating_tool_triggers_pre_tool_callback(monkeypatch):
"""When a mutating tool is called and pre_tool_callback is set, it must be invoked."""
import ai_client
import mcp_client
from unittest.mock import MagicMock, patch
callback_called = []
def fake_callback(desc, base_dir, qa_cb):
callback_called.append(desc)
return "approved"
with patch.object(mcp_client, "dispatch", return_value="dispatch_result") as mock_dispatch:
with patch.object(mcp_client, "TOOL_NAMES", {"set_file_slice"}):
tool_name = "set_file_slice"
args = {"path": "foo.py", "start_line": 1, "end_line": 2, "new_content": "x"}
# Simulate the logic from all 4 provider dispatch blocks
out = ""
_res = fake_callback(f"# MCP MUTATING TOOL: {tool_name}", ".", None)
if _res is None:
out = "USER REJECTED: tool execution cancelled"
else:
out = mcp_client.dispatch(tool_name, args)
assert len(callback_called) == 1, "pre_tool_callback must be called for mutating tools"
assert mock_dispatch.called
def test_mutating_tool_rejected_skips_dispatch(monkeypatch):
"""When pre_tool_callback returns None (rejected), dispatch must NOT be called."""
import mcp_client
from unittest.mock import patch
def rejecting_callback(desc, base_dir, qa_cb):
return None
with patch.object(mcp_client, "dispatch", return_value="should_not_call") as mock_dispatch:
tool_name = "set_file_slice"
args = {"path": "foo.py", "start_line": 1, "end_line": 2, "new_content": "x"}
_res = rejecting_callback(f"# MCP MUTATING TOOL: {tool_name}", ".", None)
out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(tool_name, args)
assert out == "USER REJECTED: tool execution cancelled"
assert not mock_dispatch.called
def test_non_mutating_tool_skips_callback():
"""Read-only tools must NOT trigger pre_tool_callback."""
import mcp_client
callback_called = []
def fake_callback(desc, base_dir, qa_cb):
callback_called.append(desc)
return "approved"
tool_name = "get_file_slice"
# Simulate the guard: only call callback if tool in MUTATING_TOOLS
if tool_name in mcp_client.MUTATING_TOOLS and fake_callback:
fake_callback(tool_name, ".", None)
assert len(callback_called) == 0, "pre_tool_callback must NOT be called for read-only tools"