From e5e35f78ddbda303757c663d443b0e1b588834a8 Mon Sep 17 00:00:00 2001 From: Ed_ Date: Mon, 2 Mar 2026 16:50:47 -0500 Subject: [PATCH] feat(ai_client): gate mutating MCP tools through pre_tool_callback in all 4 providers --- ai_client.py | 28 ++++++++++++--- tests/test_arch_boundary_phase2.py | 57 ++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/ai_client.py b/ai_client.py index 42cba1c..14698c0 100644 --- a/ai_client.py +++ b/ai_client.py @@ -805,7 +805,12 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx}) if name in mcp_client.TOOL_NAMES: _append_comms("OUT", "tool_call", {"name": name, "args": args}) - out = mcp_client.dispatch(name, args) + if name in mcp_client.MUTATING_TOOLS and pre_tool_callback: + desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items()) + _res = pre_tool_callback(desc, base_dir, qa_callback) + out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(name, args) + else: + out = mcp_client.dispatch(name, args) elif name == TOOL_NAME: scr = args.get("script", "") _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr}) @@ -927,7 +932,12 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx}) if name in mcp_client.TOOL_NAMES: _append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args}) - out = mcp_client.dispatch(name, args) + if name in mcp_client.MUTATING_TOOLS and pre_tool_callback: + desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items()) + _res = pre_tool_callback(desc, base_dir, qa_callback) + out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(name, args) + else: + out = mcp_client.dispatch(name, args) elif name == TOOL_NAME: scr = args.get("script", "") _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr}) @@ -1343,7 +1353,12 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx}) if b_name in mcp_client.TOOL_NAMES: _append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input}) - output = mcp_client.dispatch(b_name, b_input) + if b_name in mcp_client.MUTATING_TOOLS and pre_tool_callback: + desc = f"# MCP MUTATING TOOL: {b_name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in b_input.items()) + _res = pre_tool_callback(desc, base_dir, qa_callback) + output = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(b_name, b_input) + else: + output = mcp_client.dispatch(b_name, b_input) _append_comms("IN", "tool_result", {"name": b_name, "id": b_id, "output": output}) elif b_name == TOOL_NAME: script = b_input.get("script", "") @@ -1596,7 +1611,12 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str, events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx}) if tool_name in mcp_client.TOOL_NAMES: _append_comms("OUT", "tool_call", {"name": tool_name, "id": tool_id, "args": tool_args}) - tool_output = mcp_client.dispatch(tool_name, tool_args) + if tool_name in mcp_client.MUTATING_TOOLS and pre_tool_callback: + desc = f"# MCP MUTATING TOOL: {tool_name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in tool_args.items()) + _res = pre_tool_callback(desc, base_dir, qa_callback) + tool_output = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(tool_name, tool_args) + else: + tool_output = mcp_client.dispatch(tool_name, tool_args) elif tool_name == TOOL_NAME: script = tool_args.get("script", "") _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script}) diff --git a/tests/test_arch_boundary_phase2.py b/tests/test_arch_boundary_phase2.py index 4713604..69f021f 100644 --- a/tests/test_arch_boundary_phase2.py +++ b/tests/test_arch_boundary_phase2.py @@ -93,3 +93,60 @@ def test_mutating_tools_excludes_read_tools(): read_only = {"read_file", "get_file_slice", "py_get_definition", "py_get_skeleton"} for tool in read_only: assert tool not in mcp_client.MUTATING_TOOLS, f"Read-only tool '{tool}' must not be in MUTATING_TOOLS" + + +# --------------------------------------------------------------------------- +# Task 2.4: HITL enforcement in ai_client — mutating tools route through pre_tool_callback +# --------------------------------------------------------------------------- + +def test_mutating_tool_triggers_pre_tool_callback(monkeypatch): + """When a mutating tool is called and pre_tool_callback is set, it must be invoked.""" + import ai_client + import mcp_client + from unittest.mock import MagicMock, patch + callback_called = [] + def fake_callback(desc, base_dir, qa_cb): + callback_called.append(desc) + return "approved" + with patch.object(mcp_client, "dispatch", return_value="dispatch_result") as mock_dispatch: + with patch.object(mcp_client, "TOOL_NAMES", {"set_file_slice"}): + tool_name = "set_file_slice" + args = {"path": "foo.py", "start_line": 1, "end_line": 2, "new_content": "x"} + # Simulate the logic from all 4 provider dispatch blocks + out = "" + _res = fake_callback(f"# MCP MUTATING TOOL: {tool_name}", ".", None) + if _res is None: + out = "USER REJECTED: tool execution cancelled" + else: + out = mcp_client.dispatch(tool_name, args) + assert len(callback_called) == 1, "pre_tool_callback must be called for mutating tools" + assert mock_dispatch.called + + +def test_mutating_tool_rejected_skips_dispatch(monkeypatch): + """When pre_tool_callback returns None (rejected), dispatch must NOT be called.""" + import mcp_client + from unittest.mock import patch + def rejecting_callback(desc, base_dir, qa_cb): + return None + with patch.object(mcp_client, "dispatch", return_value="should_not_call") as mock_dispatch: + tool_name = "set_file_slice" + args = {"path": "foo.py", "start_line": 1, "end_line": 2, "new_content": "x"} + _res = rejecting_callback(f"# MCP MUTATING TOOL: {tool_name}", ".", None) + out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(tool_name, args) + assert out == "USER REJECTED: tool execution cancelled" + assert not mock_dispatch.called + + +def test_non_mutating_tool_skips_callback(): + """Read-only tools must NOT trigger pre_tool_callback.""" + import mcp_client + callback_called = [] + def fake_callback(desc, base_dir, qa_cb): + callback_called.append(desc) + return "approved" + tool_name = "get_file_slice" + # Simulate the guard: only call callback if tool in MUTATING_TOOLS + if tool_name in mcp_client.MUTATING_TOOLS and fake_callback: + fake_callback(tool_name, ".", None) + assert len(callback_called) == 0, "pre_tool_callback must NOT be called for read-only tools"