feat(mma): Implement HITL execution clutch and step-mode
This commit is contained in:
77
ai_client.py
77
ai_client.py
@@ -668,7 +668,8 @@ def _get_gemini_history_list(chat):
|
|||||||
|
|
||||||
def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
||||||
file_items: list[dict] | None = None,
|
file_items: list[dict] | None = None,
|
||||||
discussion_history: str = "") -> str:
|
discussion_history: str = "",
|
||||||
|
pre_tool_callback = None) -> str:
|
||||||
global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at
|
global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -747,7 +748,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
_gemini_cache = None
|
_gemini_cache = None
|
||||||
_gemini_cache_created_at = None
|
_gemini_cache_created_at = None
|
||||||
_append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} — falling back to inline system_instruction"})
|
_append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} \u2014 falling back to inline system_instruction"})
|
||||||
|
|
||||||
kwargs = {"model": _model, "config": chat_config}
|
kwargs = {"model": _model, "config": chat_config}
|
||||||
if old_history:
|
if old_history:
|
||||||
@@ -767,7 +768,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
|||||||
_cumulative_tool_bytes = 0
|
_cumulative_tool_bytes = 0
|
||||||
|
|
||||||
# Strip stale file refreshes and truncate old tool outputs ONCE before
|
# Strip stale file refreshes and truncate old tool outputs ONCE before
|
||||||
# entering the tool loop (not per-round — history entries don't change).
|
# entering the tool loop (not per-round \u2014 history entries don't change).
|
||||||
if _gemini_chat and _get_gemini_history_list(_gemini_chat):
|
if _gemini_chat and _get_gemini_history_list(_gemini_chat):
|
||||||
for msg in _get_gemini_history_list(_gemini_chat):
|
for msg in _get_gemini_history_list(_gemini_chat):
|
||||||
if msg.role == "user" and hasattr(msg, "parts"):
|
if msg.role == "user" and hasattr(msg, "parts"):
|
||||||
@@ -830,6 +831,16 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
|||||||
f_resps, log = [], []
|
f_resps, log = [], []
|
||||||
for i, fc in enumerate(calls):
|
for i, fc in enumerate(calls):
|
||||||
name, args = fc.name, dict(fc.args)
|
name, args = fc.name, dict(fc.args)
|
||||||
|
|
||||||
|
# Check for tool confirmation if callback is provided
|
||||||
|
if pre_tool_callback:
|
||||||
|
payload_str = json.dumps({"tool": name, "args": args})
|
||||||
|
if not pre_tool_callback(payload_str):
|
||||||
|
out = "USER REJECTED: tool execution cancelled"
|
||||||
|
f_resps.append(types.Part.from_function_response(name=name, response={"output": out}))
|
||||||
|
log.append({"tool_use_id": name, "content": out})
|
||||||
|
continue
|
||||||
|
|
||||||
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
|
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
|
||||||
if name in mcp_client.TOOL_NAMES:
|
if name in mcp_client.TOOL_NAMES:
|
||||||
_append_comms("OUT", "tool_call", {"name": name, "args": args})
|
_append_comms("OUT", "tool_call", {"name": name, "args": args})
|
||||||
@@ -868,7 +879,8 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
|||||||
|
|
||||||
def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||||
file_items: list[dict] | None = None,
|
file_items: list[dict] | None = None,
|
||||||
discussion_history: str = "") -> str:
|
discussion_history: str = "",
|
||||||
|
pre_tool_callback = None) -> str:
|
||||||
global _gemini_cli_adapter
|
global _gemini_cli_adapter
|
||||||
try:
|
try:
|
||||||
if _gemini_cli_adapter is None:
|
if _gemini_cli_adapter is None:
|
||||||
@@ -951,6 +963,20 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
|||||||
args = fc.get("args", {})
|
args = fc.get("args", {})
|
||||||
call_id = fc.get("id")
|
call_id = fc.get("id")
|
||||||
|
|
||||||
|
# Check for tool confirmation if callback is provided
|
||||||
|
if pre_tool_callback:
|
||||||
|
payload_str = json.dumps({"tool": name, "args": args})
|
||||||
|
if not pre_tool_callback(payload_str):
|
||||||
|
out = "USER REJECTED: tool execution cancelled"
|
||||||
|
tool_results_for_cli.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"name": name,
|
||||||
|
"content": out
|
||||||
|
})
|
||||||
|
_append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
|
||||||
|
continue
|
||||||
|
|
||||||
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
|
events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
|
||||||
if name in mcp_client.TOOL_NAMES:
|
if name in mcp_client.TOOL_NAMES:
|
||||||
_append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
|
_append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
|
||||||
@@ -1251,13 +1277,13 @@ def _repair_anthropic_history(history: list[dict]):
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_items: list[dict] | None = None, discussion_history: str = "") -> str:
|
def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_items: list[dict] | None = None, discussion_history: str = "", pre_tool_callback = None) -> str:
|
||||||
try:
|
try:
|
||||||
_ensure_anthropic_client()
|
_ensure_anthropic_client()
|
||||||
mcp_client.configure(file_items or [], [base_dir])
|
mcp_client.configure(file_items or [], [base_dir])
|
||||||
|
|
||||||
# Split system into two cache breakpoints:
|
# Split system into two cache breakpoints:
|
||||||
# 1. Stable system prompt (never changes — always a cache hit)
|
# 1. Stable system prompt (never changes \u2014 always a cache hit)
|
||||||
# 2. Dynamic file context (invalidated only when files change)
|
# 2. Dynamic file context (invalidated only when files change)
|
||||||
stable_prompt = _get_combined_system_prompt()
|
stable_prompt = _get_combined_system_prompt()
|
||||||
stable_blocks = [{"type": "text", "text": stable_prompt, "cache_control": {"type": "ephemeral"}}]
|
stable_blocks = [{"type": "text", "text": stable_prompt, "cache_control": {"type": "ephemeral"}}]
|
||||||
@@ -1382,6 +1408,19 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
|
|||||||
b_name = getattr(block, "name", None)
|
b_name = getattr(block, "name", None)
|
||||||
b_id = getattr(block, "id", "")
|
b_id = getattr(block, "id", "")
|
||||||
b_input = getattr(block, "input", {})
|
b_input = getattr(block, "input", {})
|
||||||
|
|
||||||
|
# Check for tool confirmation if callback is provided
|
||||||
|
if pre_tool_callback:
|
||||||
|
payload_str = json.dumps({"tool": b_name, "args": b_input})
|
||||||
|
if not pre_tool_callback(payload_str):
|
||||||
|
output = "USER REJECTED: tool execution cancelled"
|
||||||
|
tool_results.append({
|
||||||
|
"type": "tool_result",
|
||||||
|
"tool_use_id": b_id,
|
||||||
|
"content": output,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx})
|
events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx})
|
||||||
if b_name in mcp_client.TOOL_NAMES:
|
if b_name in mcp_client.TOOL_NAMES:
|
||||||
_append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input})
|
_append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input})
|
||||||
@@ -1480,7 +1519,8 @@ def _ensure_deepseek_client():
|
|||||||
def _send_deepseek(md_content: str, user_message: str, base_dir: str,
|
def _send_deepseek(md_content: str, user_message: str, base_dir: str,
|
||||||
file_items: list[dict] | None = None,
|
file_items: list[dict] | None = None,
|
||||||
discussion_history: str = "",
|
discussion_history: str = "",
|
||||||
stream: bool = False) -> str:
|
stream: bool = False,
|
||||||
|
pre_tool_callback = None) -> str:
|
||||||
"""
|
"""
|
||||||
Sends a message to the DeepSeek API, handling tool calls and history.
|
Sends a message to the DeepSeek API, handling tool calls and history.
|
||||||
Supports streaming responses.
|
Supports streaming responses.
|
||||||
@@ -1652,6 +1692,19 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
|
|||||||
except:
|
except:
|
||||||
tool_args = {}
|
tool_args = {}
|
||||||
|
|
||||||
|
# Check for tool confirmation if callback is provided
|
||||||
|
if pre_tool_callback:
|
||||||
|
payload_str = json.dumps({"tool": tool_name, "args": tool_args})
|
||||||
|
if not pre_tool_callback(payload_str):
|
||||||
|
tool_output = "USER REJECTED: tool execution cancelled"
|
||||||
|
tool_results_for_history.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tool_id,
|
||||||
|
"content": tool_output,
|
||||||
|
})
|
||||||
|
_append_comms("IN", "tool_result", {"name": tool_name, "id": tool_id, "output": tool_output})
|
||||||
|
continue
|
||||||
|
|
||||||
events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx})
|
events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx})
|
||||||
|
|
||||||
if tool_name in mcp_client.TOOL_NAMES:
|
if tool_name in mcp_client.TOOL_NAMES:
|
||||||
@@ -1720,6 +1773,7 @@ def send(
|
|||||||
file_items: list[dict] | None = None,
|
file_items: list[dict] | None = None,
|
||||||
discussion_history: str = "",
|
discussion_history: str = "",
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
|
pre_tool_callback = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Send a message to the active provider.
|
Send a message to the active provider.
|
||||||
@@ -1733,16 +1787,17 @@ def send(
|
|||||||
discussion_history : discussion history text (used by Gemini to inject as
|
discussion_history : discussion history text (used by Gemini to inject as
|
||||||
conversation message instead of caching it)
|
conversation message instead of caching it)
|
||||||
stream : Whether to use streaming (supported by DeepSeek)
|
stream : Whether to use streaming (supported by DeepSeek)
|
||||||
|
pre_tool_callback : Optional callback (payload: str) -> bool called before tool execution
|
||||||
"""
|
"""
|
||||||
with _send_lock:
|
with _send_lock:
|
||||||
if _provider == "gemini":
|
if _provider == "gemini":
|
||||||
return _send_gemini(md_content, user_message, base_dir, file_items, discussion_history)
|
return _send_gemini(md_content, user_message, base_dir, file_items, discussion_history, pre_tool_callback)
|
||||||
elif _provider == "gemini_cli":
|
elif _provider == "gemini_cli":
|
||||||
return _send_gemini_cli(md_content, user_message, base_dir, file_items, discussion_history)
|
return _send_gemini_cli(md_content, user_message, base_dir, file_items, discussion_history, pre_tool_callback)
|
||||||
elif _provider == "anthropic":
|
elif _provider == "anthropic":
|
||||||
return _send_anthropic(md_content, user_message, base_dir, file_items, discussion_history)
|
return _send_anthropic(md_content, user_message, base_dir, file_items, discussion_history, pre_tool_callback)
|
||||||
elif _provider == "deepseek":
|
elif _provider == "deepseek":
|
||||||
return _send_deepseek(md_content, user_message, base_dir, file_items, discussion_history, stream=stream)
|
return _send_deepseek(md_content, user_message, base_dir, file_items, discussion_history, stream=stream, pre_tool_callback=pre_tool_callback)
|
||||||
raise ValueError(f"unknown provider: {_provider}")
|
raise ValueError(f"unknown provider: {_provider}")
|
||||||
|
|
||||||
def get_history_bleed_stats(md_content: str | None = None) -> dict:
|
def get_history_bleed_stats(md_content: str | None = None) -> dict:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class Ticket:
|
|||||||
assigned_to: str
|
assigned_to: str
|
||||||
depends_on: List[str] = field(default_factory=list)
|
depends_on: List[str] = field(default_factory=list)
|
||||||
blocked_reason: Optional[str] = None
|
blocked_reason: Optional[str] = None
|
||||||
|
step_mode: bool = False
|
||||||
|
|
||||||
def mark_blocked(self, reason: str):
|
def mark_blocked(self, reason: str):
|
||||||
"""Sets the ticket status to 'blocked' and records the reason."""
|
"""Sets the ticket status to 'blocked' and records the reason."""
|
||||||
|
|||||||
@@ -36,6 +36,13 @@ class ConductorEngine:
|
|||||||
)
|
)
|
||||||
run_worker_lifecycle(ticket, context)
|
run_worker_lifecycle(ticket, context)
|
||||||
|
|
||||||
|
def confirm_execution(payload: str) -> bool:
|
||||||
|
"""
|
||||||
|
Placeholder for external confirmation function.
|
||||||
|
In a real scenario, this might trigger a UI prompt.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] = None):
|
def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: List[str] = None):
|
||||||
"""
|
"""
|
||||||
Simulates the lifecycle of a single agent working on a ticket.
|
Simulates the lifecycle of a single agent working on a ticket.
|
||||||
@@ -74,7 +81,8 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
|
|||||||
response = ai_client.send(
|
response = ai_client.send(
|
||||||
md_content="",
|
md_content="",
|
||||||
user_message=user_message,
|
user_message=user_message,
|
||||||
base_dir="."
|
base_dir=".",
|
||||||
|
pre_tool_callback=confirm_execution if ticket.step_mode else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if "BLOCKED" in response.upper():
|
if "BLOCKED" in response.upper():
|
||||||
|
|||||||
@@ -138,3 +138,59 @@ def test_run_worker_lifecycle_handles_blocked_response():
|
|||||||
|
|
||||||
assert ticket.status == "blocked"
|
assert ticket.status == "blocked"
|
||||||
assert "BLOCKED" in ticket.blocked_reason
|
assert "BLOCKED" in ticket.blocked_reason
|
||||||
|
|
||||||
|
def test_run_worker_lifecycle_step_mode_confirmation():
|
||||||
|
"""
|
||||||
|
Test that run_worker_lifecycle passes confirm_execution to ai_client.send when step_mode is True.
|
||||||
|
Verify that if confirm_execution is called (simulated by mocking ai_client.send to call its callback),
|
||||||
|
the flow works as expected.
|
||||||
|
"""
|
||||||
|
ticket = Ticket(id="T1", description="Task 1", status="todo", assigned_to="worker1", step_mode=True)
|
||||||
|
context = WorkerContext(ticket_id="T1", model_name="test-model", messages=[])
|
||||||
|
|
||||||
|
from multi_agent_conductor import run_worker_lifecycle, confirm_execution
|
||||||
|
|
||||||
|
with patch("ai_client.send") as mock_send, \
|
||||||
|
patch("multi_agent_conductor.confirm_execution") as mock_confirm:
|
||||||
|
|
||||||
|
# We simulate ai_client.send by making it call the pre_tool_callback it received
|
||||||
|
def mock_send_side_effect(*args, **kwargs):
|
||||||
|
callback = kwargs.get("pre_tool_callback")
|
||||||
|
if callback:
|
||||||
|
# Simulate calling it with some payload
|
||||||
|
callback('{"tool": "read_file", "args": {"path": "test.txt"}}')
|
||||||
|
return "Success"
|
||||||
|
|
||||||
|
mock_send.side_effect = mock_send_side_effect
|
||||||
|
mock_confirm.return_value = True
|
||||||
|
|
||||||
|
run_worker_lifecycle(ticket, context)
|
||||||
|
|
||||||
|
# Verify confirm_execution was called
|
||||||
|
mock_confirm.assert_called_once()
|
||||||
|
assert ticket.status == "completed"
|
||||||
|
|
||||||
|
def test_run_worker_lifecycle_step_mode_rejection():
|
||||||
|
"""
|
||||||
|
Verify that if confirm_execution returns False, the logic (in ai_client, which we simulate here)
|
||||||
|
would prevent execution. In run_worker_lifecycle, we just check if it's passed.
|
||||||
|
"""
|
||||||
|
ticket = Ticket(id="T1", description="Task 1", status="todo", assigned_to="worker1", step_mode=True)
|
||||||
|
context = WorkerContext(ticket_id="T1", model_name="test-model", messages=[])
|
||||||
|
|
||||||
|
from multi_agent_conductor import run_worker_lifecycle
|
||||||
|
|
||||||
|
with patch("ai_client.send") as mock_send, \
|
||||||
|
patch("multi_agent_conductor.confirm_execution") as mock_confirm:
|
||||||
|
|
||||||
|
mock_confirm.return_value = False
|
||||||
|
mock_send.return_value = "Task failed because tool execution was rejected."
|
||||||
|
|
||||||
|
run_worker_lifecycle(ticket, context)
|
||||||
|
|
||||||
|
# Verify it was passed to send
|
||||||
|
args, kwargs = mock_send.call_args
|
||||||
|
assert kwargs["pre_tool_callback"] is not None
|
||||||
|
|
||||||
|
# Since we've already tested ai_client's implementation of pre_tool_callback (mentally or via other tests),
|
||||||
|
# here we just verify the wiring.
|
||||||
|
|||||||
Reference in New Issue
Block a user