Compare commits

...

5 Commits

5 changed files with 228 additions and 220 deletions

View File

@@ -41,6 +41,7 @@
- **ApiHookClient:** A dedicated IPC client for automated GUI interaction and state inspection.
- **mma-exec / mma.ps1:** Python-based execution engine and PowerShell wrapper for managing the 4-Tier MMA hierarchy and automated documentation mapping.
- **dag_engine.py:** A native Python utility implementing `TrackDAG` and `ExecutionEngine` for dependency resolution, cycle detection, transitive blocking propagation, and programmable task execution loops.
- **Thread-Local Context Isolation:** Utilizes `threading.local()` for managing per-thread AI client context (e.g., source tier tagging), ensuring thread safety during concurrent multi-agent execution.
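The thread-local isolation described above can be sketched in a few lines. This is an illustrative minimal version, assuming the described pattern; the helper names (`set_current_tier`, `get_current_tier`) are hypothetical, not the project's actual API.

```python
import threading

# Per-thread storage: each thread sees only its own 'tier' attribute.
_local_context = threading.local()

def set_current_tier(tier):
    _local_context.tier = tier

def get_current_tier(default=None):
    # getattr with a default so threads that never called the setter don't crash
    return getattr(_local_context, "tier", default)

results = {}

def worker(tier):
    set_current_tier(tier)
    # Each thread reads back its own value, even while the other thread runs
    results[tier] = get_current_tier()

threads = [threading.Thread(target=worker, args=(t,)) for t in (3, 4)]
for t in threads: t.start()
for t in threads: t.join()
print(results)  # {3: 3, 4: 4} — no cross-thread overwrites
```

The key property is that `_local_context.tier` is not shared state: assignment in one thread is invisible to every other thread, which is what makes concurrent per-tier tagging safe without locks.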
## Architectural Patterns

View File

@@ -20,18 +20,18 @@ This file tracks all major tracks for the project. Each track has its own detail
4. [x] **Track: Robust JSON Parsing for Tech Lead**
   *Link: [./tracks/robust_json_parsing_tech_lead_20260302/](./tracks/robust_json_parsing_tech_lead_20260302/)*
5. [x] **Track: Concurrent Tier Source Isolation**
   *Link: [./tracks/concurrent_tier_source_tier_20260302/](./tracks/concurrent_tier_source_tier_20260302/)*
6. [ ] **Track: Asynchronous Tool Execution Engine**
   *Link: [./tracks/async_tool_execution_20260303/](./tracks/async_tool_execution_20260303/)*
7. [ ] **Track: Simulation Fidelity Enhancement**
   *Link: [./tracks/simulation_fidelity_enhancement_20260305/](./tracks/simulation_fidelity_enhancement_20260305/)*
8. [ ] **Track: Manual UX Validation & Polish**
   *Link: [./tracks/manual_ux_validation_20260302/](./tracks/manual_ux_validation_20260302/)*

---

## Completed / Archived

View File

@@ -4,30 +4,30 @@
## Phase 1: Thread-Local Context Refactoring
- [x] Task: Initialize MMA Environment `activate_skill mma-orchestrator`
- [x] Task: Refactor `ai_client` to `threading.local()` (684a6d1)
    - [x] WHERE: `ai_client.py`
    - [x] WHAT: Replace `current_tier = None` with `_local_context = threading.local()`. Implement safe getters/setters for the tier.
    - [x] HOW: Use standard `threading.local` attributes.
    - [x] SAFETY: Provide defaults (e.g., `getattr(_local_context, 'tier', None)`) so uninitialized threads don't crash.
- [x] Task: Update Lifecycle Callers (684a6d1)
    - [x] WHERE: `multi_agent_conductor.py`, `conductor_tech_lead.py`
    - [x] WHAT: Update how they set the current tier around `send()` calls.
    - [x] HOW: Use the new setter/getter functions from `ai_client`.
    - [x] SAFETY: Ensure `finally` blocks clean up the thread-local state.
- [x] Task: Conductor - User Manual Verification 'Phase 1: Refactoring' (Protocol in workflow.md)

## Phase 2: Testing Concurrency
- [x] Task: Write Concurrent Execution Test (684a6d1)
    - [x] WHERE: `tests/test_ai_client_concurrency.py` (New)
    - [x] WHAT: Spawn two threads. Thread A sets Tier 3 and calls a mock `send`. Thread B sets Tier 4 and calls mock `send`.
    - [x] HOW: Assert that the resulting `comms_log` correctly maps the entries to Tier 3 and Tier 4 respectively without race condition overwrites.
    - [x] SAFETY: Use `threading.Barrier` to force race conditions in the test to ensure the isolation holds.
- [x] Task: Conductor - User Manual Verification 'Phase 2: Testing Concurrency' (Protocol in workflow.md)

## Phase 3: Final Validation
- [x] Task: Full Suite Validation & Warning Cleanup (684a6d1)
    - [x] WHERE: Project root
    - [x] WHAT: `uv run pytest`
    - [x] HOW: Ensure 100% pass rate.
    - [x] SAFETY: None.
- [x] Task: Conductor - User Manual Verification 'Phase 3: Final Validation' (Protocol in workflow.md)
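The Phase 2 race test described above could look roughly like this. It is a minimal sketch under stated assumptions: `fake_send`, the in-memory `comms_log`, and the lock are stand-ins for the project's real objects, and the `threading.Barrier` forces both threads to race through the send path at the same instant.

```python
import threading

comms_log = []
_log_lock = threading.Lock()
_local = threading.local()
barrier = threading.Barrier(2)  # both threads must arrive before either proceeds

def fake_send(message):
    # Read the tier from thread-local storage, then deliberately synchronize
    # so the two threads overlap as much as possible
    tier = getattr(_local, "tier", None)
    barrier.wait()
    with _log_lock:
        comms_log.append({"tier": tier, "message": message})

def run(tier):
    _local.tier = tier  # Thread A sets 3, Thread B sets 4
    fake_send(f"hello from tier {tier}")

threads = [threading.Thread(target=run, args=(t,)) for t in (3, 4)]
for t in threads: t.start()
for t in threads: t.join()

# Each log entry must carry the tier of the thread that produced it
assert {e["tier"] for e in comms_log} == {3, 4}
for e in comms_log:
    assert e["message"] == f"hello from tier {e['tier']}"
```

If the tier were a plain module-level variable instead of a `threading.local()`, the barrier would make one thread's tier overwrite the other's before either logged, and the final assertion would fail intermittently.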

View File

@@ -13,6 +13,7 @@ during chat creation to avoid massive history bloat.
""" """
# ai_client.py # ai_client.py
import tomllib import tomllib
import asyncio
import json import json
import sys import sys
import time import time
@@ -473,6 +474,84 @@ def _gemini_tool_declaration() -> Optional[types.Tool]:
))
return types.Tool(function_declarations=declarations) if declarations else None
async def _execute_tool_calls_concurrently(
    calls: list[Any],
    base_dir: str,
    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
    qa_callback: Optional[Callable[[str], str]],
    r_idx: int,
    provider: str
) -> list[tuple[str, str, str, str]]:  # tool_name, call_id, output, original_name
    """
    Executes multiple tool calls concurrently using asyncio.gather.
    Returns a list of (tool_name, call_id, output, original_name).
    """
    tasks = []
    for fc in calls:
        if provider == "gemini":
            name, args, call_id = fc.name, dict(fc.args), fc.name  # Gemini 1.0.0 doesn't have call IDs in types.Part
        elif provider == "gemini_cli":
            name, args, call_id = cast(str, fc.get("name")), cast(dict[str, Any], fc.get("args", {})), cast(str, fc.get("id"))
        elif provider == "anthropic":
            name, args, call_id = cast(str, getattr(fc, "name")), cast(dict[str, Any], getattr(fc, "input")), cast(str, getattr(fc, "id"))
        elif provider == "deepseek":
            tool_info = fc.get("function", {})
            name = cast(str, tool_info.get("name"))
            tool_args_str = cast(str, tool_info.get("arguments", "{}"))
            call_id = cast(str, fc.get("id"))
            try:
                args = json.loads(tool_args_str)
            except json.JSONDecodeError:
                args = {}
        else:
            continue
        tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx))
    results = await asyncio.gather(*tasks)
    return results

async def _execute_single_tool_call_async(
    name: str,
    args: dict[str, Any],
    call_id: str,
    base_dir: str,
    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
    qa_callback: Optional[Callable[[str], str]],
    r_idx: int
) -> tuple[str, str, str, str]:
    out = ""
    tool_executed = False
    events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
    # Check for run_powershell
    if name == TOOL_NAME and pre_tool_callback:
        scr = cast(str, args.get("script", ""))
        _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
        # pre_tool_callback is synchronous and might block for HITL
        res = await asyncio.to_thread(pre_tool_callback, scr, base_dir, qa_callback)
        if res is None:
            out = "USER REJECTED: tool execution cancelled"
        else:
            out = res
        tool_executed = True
    if not tool_executed:
        if name and name in mcp_client.TOOL_NAMES:
            _append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
            if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
                desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
                _res = await asyncio.to_thread(pre_tool_callback, desc, base_dir, qa_callback)
                out = "USER REJECTED: tool execution cancelled" if _res is None else await mcp_client.async_dispatch(name, args)
            else:
                out = await mcp_client.async_dispatch(name, args)
        elif name == TOOL_NAME:
            scr = cast(str, args.get("script", ""))
            _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
            out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback)
        else:
            out = f"ERROR: unknown tool '{name}'"
    return (name, call_id, out, name)
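The new `_execute_tool_calls_concurrently` is essentially a fan-out over `asyncio.gather`, with each blocking callback pushed onto a worker thread. A minimal self-contained sketch of that pattern follows; the `run_tool` coroutine and its tool names are hypothetical stand-ins, not the project's code.

```python
import asyncio
import time

async def run_tool(name, args):
    # Simulate a blocking tool body by moving the sleep to a worker thread,
    # mirroring how the real code wraps sync callbacks in asyncio.to_thread
    await asyncio.to_thread(time.sleep, 0.1)
    return (name, f"output of {name}")

async def run_all(calls):
    # gather returns results in input order, even though the
    # underlying tasks overlap in time
    return await asyncio.gather(*(run_tool(n, a) for n, a in calls))

calls = [("read_file", {}), ("list_directory", {}), ("web_search", {})]
start = time.perf_counter()
results = asyncio.run(run_all(calls))
elapsed = time.perf_counter() - start

print(results[0])   # ('read_file', 'output of read_file')
assert elapsed < 0.25  # ~0.1s total, not 0.3s, because the three sleeps overlap
```

The ordering guarantee of `gather` is what lets the callers below zip results back to the provider-specific `tool_use_id`s without any extra bookkeeping.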
def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str:
    if confirm_and_run_callback is None:
        return "ERROR: no confirmation handler registered"
@@ -762,48 +841,33 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
if not calls or r_idx > MAX_TOOL_ROUNDS: break
f_resps: list[types.Part] = []
log: list[dict[str, Any]] = []

# Execute tools concurrently
try:
    loop = asyncio.get_running_loop()
    results = asyncio.run_coroutine_threadsafe(
        _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"),
        loop
    ).result()
except RuntimeError:
    results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"))

for i, (name, call_id, out, _) in enumerate(results):
    # Check if this is the last tool to trigger file refresh
    if i == len(results) - 1:
        if file_items:
            file_items, changed = _reread_file_items(file_items)
            ctx = _build_file_diff_text(changed)
            if ctx:
                out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
        if r_idx == MAX_TOOL_ROUNDS: out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
    out = _truncate_tool_output(out)
    _cumulative_tool_bytes += len(out)
    f_resps.append(types.Part(function_response=types.FunctionResponse(name=cast(str, name), response={"output": out})))
    log.append({"tool_use_id": name, "content": out})
    events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
    f_resps.append(types.Part(text=
        f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
@@ -877,42 +941,21 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
})
if not calls or r_idx > MAX_TOOL_ROUNDS:
    break

# Execute tools concurrently
try:
    loop = asyncio.get_running_loop()
    results = asyncio.run_coroutine_threadsafe(
        _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"),
        loop
    ).result()
except RuntimeError:
    results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"))

tool_results_for_cli: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results):
    # Check if this is the last tool to trigger file refresh
    if i == len(results) - 1:
        if file_items:
            file_items, changed = _reread_file_items(file_items)
            ctx = _build_file_diff_text(changed)
@@ -920,6 +963,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}" out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
if r_idx == MAX_TOOL_ROUNDS: if r_idx == MAX_TOOL_ROUNDS:
out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]" out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
out = _truncate_tool_output(out) out = _truncate_tool_output(out)
_cumulative_tool_bytes += len(out) _cumulative_tool_bytes += len(out)
tool_results_for_cli.append({ tool_results_for_cli.append({
@@ -930,6 +974,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
    })
    _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
    events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
payload = tool_results_for_cli
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
    _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
@@ -1217,61 +1262,29 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
    break
if round_idx > MAX_TOOL_ROUNDS:
    break

# Execute tools concurrently
try:
    loop = asyncio.get_running_loop()
    results = asyncio.run_coroutine_threadsafe(
        _execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic"),
        loop
    ).result()
except RuntimeError:
    results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic"))

tool_results: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results):
    truncated = _truncate_tool_output(out)
    _cumulative_tool_bytes += len(truncated)
    tool_results.append({
        "type": "tool_result",
        "tool_use_id": call_id,
        "content": truncated,
    })
    _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
    events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
    tool_results.append({
        "type": "text",
@@ -1448,62 +1461,39 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
    break
if round_idx > MAX_TOOL_ROUNDS:
    break

# Execute tools concurrently
try:
    loop = asyncio.get_running_loop()
    results = asyncio.run_coroutine_threadsafe(
        _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"),
        loop
    ).result()
except RuntimeError:
    results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"))

tool_results_for_history: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results):
    if i == len(results) - 1:
        if file_items:
            file_items, changed = _reread_file_items(file_items)
            ctx = _build_file_diff_text(changed)
            if ctx:
                out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
        if round_idx == MAX_TOOL_ROUNDS:
            out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
    truncated = _truncate_tool_output(out)
    _cumulative_tool_bytes += len(truncated)
    tool_results_for_history.append({
        "role": "tool",
        "tool_call_id": call_id,
        "name": name,
        "content": truncated,
    })
    _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
    events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
    tool_results_for_history.append({
        "role": "user",
@@ -1524,6 +1514,7 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
except Exception as e:
    raise _classify_deepseek_error(e) from e

def run_tier4_analysis(stderr: str) -> str:
    if not stderr or not stderr.strip():
        return ""

View File

@@ -30,6 +30,7 @@ so the AI doesn't wander outside the project workspace.
#
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Optional, Callable, Any, cast
import os
@@ -858,75 +859,90 @@ def get_ui_performance() -> str:
TOOL_NAMES: set[str] = {"read_file", "list_directory", "search_files", "get_file_summary", "py_get_skeleton", "py_get_code_outline", "py_get_definition", "get_git_diff", "web_search", "fetch_url", "get_ui_performance", "get_file_slice", "set_file_slice", "edit_file", "py_update_definition", "py_get_signature", "py_set_signature", "py_get_class_summary", "py_get_var_declaration", "py_set_var_declaration", "py_find_usages", "py_get_imports", "py_check_syntax", "py_get_hierarchy", "py_get_docstring", "get_tree"}
async def async_dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
    """
    Dispatch an MCP tool call by name asynchronously. Returns the result as a string.
    """
    # Handle aliases
    path = str(tool_input.get("path", tool_input.get("file_path", tool_input.get("dir_path", ""))))
    if tool_name == "read_file":
        return await asyncio.to_thread(read_file, path)
    if tool_name == "list_directory":
        return await asyncio.to_thread(list_directory, path)
    if tool_name == "search_files":
        return await asyncio.to_thread(search_files, path, str(tool_input.get("pattern", "*")))
    if tool_name == "get_file_summary":
        return await asyncio.to_thread(get_file_summary, path)
    if tool_name == "py_get_skeleton":
        return await asyncio.to_thread(py_get_skeleton, path)
    if tool_name == "py_get_code_outline":
        return await asyncio.to_thread(py_get_code_outline, path)
    if tool_name == "py_get_definition":
        return await asyncio.to_thread(py_get_definition, path, str(tool_input.get("name", "")))
    if tool_name == "py_update_definition":
        return await asyncio.to_thread(py_update_definition, path, str(tool_input.get("name", "")), str(tool_input.get("new_content", "")))
    if tool_name == "py_get_signature":
        return await asyncio.to_thread(py_get_signature, path, str(tool_input.get("name", "")))
    if tool_name == "py_set_signature":
        return await asyncio.to_thread(py_set_signature, path, str(tool_input.get("name", "")), str(tool_input.get("new_signature", "")))
    if tool_name == "py_get_class_summary":
        return await asyncio.to_thread(py_get_class_summary, path, str(tool_input.get("name", "")))
    if tool_name == "py_get_var_declaration":
        return await asyncio.to_thread(py_get_var_declaration, path, str(tool_input.get("name", "")))
    if tool_name == "py_set_var_declaration":
        return await asyncio.to_thread(py_set_var_declaration, path, str(tool_input.get("name", "")), str(tool_input.get("new_declaration", "")))
    if tool_name == "get_file_slice":
        return await asyncio.to_thread(get_file_slice, path, int(tool_input.get("start_line", 1)), int(tool_input.get("end_line", 1)))
    if tool_name == "set_file_slice":
        return await asyncio.to_thread(set_file_slice, path, int(tool_input.get("start_line", 1)), int(tool_input.get("end_line", 1)), str(tool_input.get("new_content", "")))
    if tool_name == "get_git_diff":
        return await asyncio.to_thread(get_git_diff,
            path,
            str(tool_input.get("base_rev", "HEAD")),
            str(tool_input.get("head_rev", ""))
        )
    if tool_name == "edit_file":
        return await asyncio.to_thread(edit_file,
            path,
            str(tool_input.get("old_string", "")),
            str(tool_input.get("new_string", "")),
            bool(tool_input.get("replace_all", False))
        )
    if tool_name == "web_search":
        return await asyncio.to_thread(web_search, str(tool_input.get("query", "")))
    if tool_name == "fetch_url":
        return await asyncio.to_thread(fetch_url, str(tool_input.get("url", "")))
    if tool_name == "get_ui_performance":
        return await asyncio.to_thread(get_ui_performance)
    if tool_name == "py_find_usages":
        return await asyncio.to_thread(py_find_usages, path, str(tool_input.get("name", "")))
    if tool_name == "py_get_imports":
        return await asyncio.to_thread(py_get_imports, path)
    if tool_name == "py_check_syntax":
        return await asyncio.to_thread(py_check_syntax, path)
    if tool_name == "py_get_hierarchy":
        return await asyncio.to_thread(py_get_hierarchy, path, str(tool_input.get("class_name", "")))
    if tool_name == "py_get_docstring":
        return await asyncio.to_thread(py_get_docstring, path, str(tool_input.get("name", "")))
    if tool_name == "get_tree":
        return await asyncio.to_thread(get_tree, path, int(tool_input.get("max_depth", 2)))
    return f"ERROR: unknown MCP tool '{tool_name}'"
def dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
    """
    Dispatch an MCP tool call by name. Returns the result as a string.
    """
    try:
        loop = asyncio.get_running_loop()
        # A loop is already running in this thread; schedule the coroutine on it.
        # Blocking on .result() is only safe when the caller is a different
        # thread than the one running the loop.
        return asyncio.run_coroutine_threadsafe(async_dispatch(tool_name, tool_input), loop).result()
    except RuntimeError:
        # No running loop in this thread: drive the coroutine with asyncio.run
        return asyncio.run(async_dispatch(tool_name, tool_input))
def get_tool_schemas() -> list[dict[str, Any]]:
    """Returns the list of tool specifications for the AI."""
    return list(MCP_TOOL_SPECS)