Compare commits

5 commits: 684a6d1d3b ... 87dbfc5958

| Author | SHA1 | Date |
|---|---|---|
| | 87dbfc5958 | |
| | 60e1dce2b6 | |
| | a960f3b3d0 | |
| | c01f1ea2c8 | |
| | 7eaed9c78a | |
@@ -41,6 +41,7 @@
 - **ApiHookClient:** A dedicated IPC client for automated GUI interaction and state inspection.
 - **mma-exec / mma.ps1:** Python-based execution engine and PowerShell wrapper for managing the 4-Tier MMA hierarchy and automated documentation mapping.
 - **dag_engine.py:** A native Python utility implementing `TrackDAG` and `ExecutionEngine` for dependency resolution, cycle detection, transitive blocking propagation, and programmable task execution loops.
+- **Thread-Local Context Isolation:** Utilizes `threading.local()` for managing per-thread AI client context (e.g., source tier tagging), ensuring thread safety during concurrent multi-agent execution.

 ## Architectural Patterns
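The thread-local isolation pattern named in the added bullet can be sketched as follows. This is a minimal illustration, not the project's actual API: the names `set_tier`, `get_tier`, and `clear_tier` are hypothetical stand-ins for whatever getters/setters `ai_client` exposes.

```python
import threading

# Per-thread storage: each thread sees its own independent 'tier' attribute.
_local_context = threading.local()

def set_tier(tier):
    _local_context.tier = tier

def get_tier(default=None):
    # getattr with a default keeps uninitialized threads from raising AttributeError.
    return getattr(_local_context, "tier", default)

def clear_tier():
    # Intended for finally blocks, so state never leaks on a reused thread.
    if hasattr(_local_context, "tier"):
        del _local_context.tier

def worker(tier, results):
    set_tier(tier)
    try:
        results[tier] = get_tier()
    finally:
        clear_tier()

results = {}
threads = [threading.Thread(target=worker, args=(t, results)) for t in (3, 4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)  # each thread observed only the tier it set itself
```

Because `threading.local()` gives every thread its own attribute namespace, neither worker can clobber the other's tier even though both touch the same module-level object.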
@@ -20,18 +20,18 @@ This file tracks all major tracks for the project. Each track has its own detail
 4. [x] **Track: Robust JSON Parsing for Tech Lead**
    *Link: [./tracks/robust_json_parsing_tech_lead_20260302/](./tracks/robust_json_parsing_tech_lead_20260302/)*

-5. [~] **Track: Concurrent Tier Source Isolation**
+5. [x] **Track: Concurrent Tier Source Isolation**
    *Link: [./tracks/concurrent_tier_source_tier_20260302/](./tracks/concurrent_tier_source_tier_20260302/)*

-6. [ ] **Track: Manual UX Validation & Polish**
-   *Link: [./tracks/manual_ux_validation_20260302/](./tracks/manual_ux_validation_20260302/)*
-
-7. [ ] **Track: Asynchronous Tool Execution Engine**
+6. [ ] **Track: Asynchronous Tool Execution Engine**
    *Link: [./tracks/async_tool_execution_20260303/](./tracks/async_tool_execution_20260303/)*

-8. [ ] **Track: Simulation Fidelity Enhancement**
+7. [ ] **Track: Simulation Fidelity Enhancement**
    *Link: [./tracks/simulation_fidelity_enhancement_20260305/](./tracks/simulation_fidelity_enhancement_20260305/)*

+8. [ ] **Track: Manual UX Validation & Polish**
+   *Link: [./tracks/manual_ux_validation_20260302/](./tracks/manual_ux_validation_20260302/)*
+
 ---

 ## Completed / Archived
@@ -4,30 +4,30 @@

 ## Phase 1: Thread-Local Context Refactoring
 - [x] Task: Initialize MMA Environment `activate_skill mma-orchestrator`
-- [~] Task: Refactor `ai_client` to `threading.local()`
+- [x] Task: Refactor `ai_client` to `threading.local()` (684a6d1)
-- [ ] WHERE: `ai_client.py`
+- [x] WHERE: `ai_client.py`
-- [ ] WHAT: Replace `current_tier = None` with `_local_context = threading.local()`. Implement safe getters/setters for the tier.
+- [x] WHAT: Replace `current_tier = None` with `_local_context = threading.local()`. Implement safe getters/setters for the tier.
-- [ ] HOW: Use standard `threading.local` attributes.
+- [x] HOW: Use standard `threading.local` attributes.
-- [ ] SAFETY: Provide defaults (e.g., `getattr(_local_context, 'tier', None)`) so uninitialized threads don't crash.
+- [x] SAFETY: Provide defaults (e.g., `getattr(_local_context, 'tier', None)`) so uninitialized threads don't crash.
-- [ ] Task: Update Lifecycle Callers
+- [x] Task: Update Lifecycle Callers (684a6d1)
-- [ ] WHERE: `multi_agent_conductor.py`, `conductor_tech_lead.py`
+- [x] WHERE: `multi_agent_conductor.py`, `conductor_tech_lead.py`
-- [ ] WHAT: Update how they set the current tier around `send()` calls.
+- [x] WHAT: Update how they set the current tier around `send()` calls.
-- [ ] HOW: Use the new setter/getter functions from `ai_client`.
+- [x] HOW: Use the new setter/getter functions from `ai_client`.
-- [ ] SAFETY: Ensure `finally` blocks clean up the thread-local state.
+- [x] SAFETY: Ensure `finally` blocks clean up the thread-local state.
-- [ ] Task: Conductor - User Manual Verification 'Phase 1: Refactoring' (Protocol in workflow.md)
+- [x] Task: Conductor - User Manual Verification 'Phase 1: Refactoring' (Protocol in workflow.md)

 ## Phase 2: Testing Concurrency
-- [ ] Task: Write Concurrent Execution Test
+- [x] Task: Write Concurrent Execution Test (684a6d1)
-- [ ] WHERE: `tests/test_ai_client_concurrency.py` (New)
+- [x] WHERE: `tests/test_ai_client_concurrency.py` (New)
-- [ ] WHAT: Spawn two threads. Thread A sets Tier 3 and calls a mock `send`. Thread B sets Tier 4 and calls mock `send`.
+- [x] WHAT: Spawn two threads. Thread A sets Tier 3 and calls a mock `send`. Thread B sets Tier 4 and calls mock `send`.
-- [ ] HOW: Assert that the resulting `comms_log` correctly maps the entries to Tier 3 and Tier 4 respectively without race condition overwrites.
+- [x] HOW: Assert that the resulting `comms_log` correctly maps the entries to Tier 3 and Tier 4 respectively without race condition overwrites.
-- [ ] SAFETY: Use `threading.Barrier` to force race conditions in the test to ensure the isolation holds.
+- [x] SAFETY: Use `threading.Barrier` to force race conditions in the test to ensure the isolation holds.
-- [ ] Task: Conductor - User Manual Verification 'Phase 2: Testing Concurrency' (Protocol in workflow.md)
+- [x] Task: Conductor - User Manual Verification 'Phase 2: Testing Concurrency' (Protocol in workflow.md)

 ## Phase 3: Final Validation
-- [ ] Task: Full Suite Validation & Warning Cleanup
+- [x] Task: Full Suite Validation & Warning Cleanup (684a6d1)
-- [ ] WHERE: Project root
+- [x] WHERE: Project root
-- [ ] WHAT: `uv run pytest`
+- [x] WHAT: `uv run pytest`
-- [ ] HOW: Ensure 100% pass rate.
+- [x] HOW: Ensure 100% pass rate.
-- [ ] SAFETY: None.
+- [x] SAFETY: None.
-- [ ] Task: Conductor - User Manual Verification 'Phase 3: Final Validation' (Protocol in workflow.md)
+- [x] Task: Conductor - User Manual Verification 'Phase 3: Final Validation' (Protocol in workflow.md)
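The Phase 2 test strategy (two threads, a `threading.Barrier` to force the race, then assertions that neither tier overwrites the other) can be sketched like this. `_ctx`, `mock_send`, and `comms_log` are illustrative stand-ins for the real `ai_client` internals, not the project's actual names.

```python
import threading

_ctx = threading.local()        # stand-in for the refactored ai_client context
barrier = threading.Barrier(2)  # both threads reach the critical section together
comms_log = {}
log_lock = threading.Lock()

def mock_send(tier):
    # Thread A passes tier 3, thread B passes tier 4.
    _ctx.tier = tier
    barrier.wait()              # maximize the chance of exposing a shared-global race
    with log_lock:
        comms_log[tier] = getattr(_ctx, "tier", None)

t_a = threading.Thread(target=mock_send, args=(3,))
t_b = threading.Thread(target=mock_send, args=(4,))
t_a.start(); t_b.start()
t_a.join(); t_b.join()

# With a module-level global instead of threading.local, one tier would
# typically overwrite the other here.
assert comms_log == {3: 3, 4: 4}, "tiers must not overwrite each other"
```

The barrier is the important part of the design: without it, thread scheduling often serializes the two sends and a broken implementation can still pass by luck.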
315
src/ai_client.py
@@ -13,6 +13,7 @@ during chat creation to avoid massive history bloat.
 """
 # ai_client.py
 import tomllib
+import asyncio
 import json
 import sys
 import time
@@ -473,6 +474,84 @@ def _gemini_tool_declaration() -> Optional[types.Tool]:
     ))
     return types.Tool(function_declarations=declarations) if declarations else None


+async def _execute_tool_calls_concurrently(
+        calls: list[Any],
+        base_dir: str,
+        pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
+        qa_callback: Optional[Callable[[str], str]],
+        r_idx: int,
+        provider: str
+) -> list[tuple[str, str, str, str]]:  # (tool_name, call_id, output, original_name)
+    """
+    Executes multiple tool calls concurrently using asyncio.gather.
+    Returns a list of (tool_name, call_id, output, original_name).
+    """
+    tasks = []
+    for fc in calls:
+        if provider == "gemini":
+            name, args, call_id = fc.name, dict(fc.args), fc.name  # Gemini 1.0.0 doesn't have call IDs in types.Part
+        elif provider == "gemini_cli":
+            name, args, call_id = cast(str, fc.get("name")), cast(dict[str, Any], fc.get("args", {})), cast(str, fc.get("id"))
+        elif provider == "anthropic":
+            name, args, call_id = cast(str, getattr(fc, "name")), cast(dict[str, Any], getattr(fc, "input")), cast(str, getattr(fc, "id"))
+        elif provider == "deepseek":
+            tool_info = fc.get("function", {})
+            name = cast(str, tool_info.get("name"))
+            tool_args_str = cast(str, tool_info.get("arguments", "{}"))
+            call_id = cast(str, fc.get("id"))
+            try:
+                args = json.loads(tool_args_str)
+            except (json.JSONDecodeError, TypeError):
+                args = {}
+        else:
+            continue
+
+        tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx))
+
+    results = await asyncio.gather(*tasks)
+    return results
+
+
+async def _execute_single_tool_call_async(
+        name: str,
+        args: dict[str, Any],
+        call_id: str,
+        base_dir: str,
+        pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
+        qa_callback: Optional[Callable[[str], str]],
+        r_idx: int
+) -> tuple[str, str, str, str]:
+    out = ""
+    tool_executed = False
+    events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
+
+    # Check for run_powershell
+    if name == TOOL_NAME and pre_tool_callback:
+        scr = cast(str, args.get("script", ""))
+        _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
+        # pre_tool_callback is synchronous and might block for HITL
+        res = await asyncio.to_thread(pre_tool_callback, scr, base_dir, qa_callback)
+        if res is None:
+            out = "USER REJECTED: tool execution cancelled"
+        else:
+            out = res
+        tool_executed = True
+
+    if not tool_executed:
+        if name and name in mcp_client.TOOL_NAMES:
+            _append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
+            if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
+                desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
+                _res = await asyncio.to_thread(pre_tool_callback, desc, base_dir, qa_callback)
+                out = "USER REJECTED: tool execution cancelled" if _res is None else await mcp_client.async_dispatch(name, args)
+            else:
+                out = await mcp_client.async_dispatch(name, args)
+        elif name == TOOL_NAME:
+            scr = cast(str, args.get("script", ""))
+            _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
+            out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback)
+        else:
+            out = f"ERROR: unknown tool '{name}'"
+
+    return (name, call_id, out, name)
+
+
 def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str:
     if confirm_and_run_callback is None:
         return "ERROR: no confirmation handler registered"
@@ -762,48 +841,33 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
         if not calls or r_idx > MAX_TOOL_ROUNDS: break
         f_resps: list[types.Part] = []
         log: list[dict[str, Any]] = []
-        for i, fc in enumerate(calls):
-            name, args = fc.name, dict(fc.args)
-            out = ""
-            tool_executed = False
-            events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
-            if name == TOOL_NAME and pre_tool_callback:
-                scr = cast(str, args.get("script", ""))
-                _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr})
-                res = pre_tool_callback(scr, base_dir, qa_callback)
-                if res is None:
-                    out = "USER REJECTED: tool execution cancelled"
-                else:
-                    out = res
-                tool_executed = True
-
-            if not tool_executed:
-                if name and name in mcp_client.TOOL_NAMES:
-                    _append_comms("OUT", "tool_call", {"name": name, "args": args})
-                    if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
-                        desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
-                        _res = pre_tool_callback(desc, base_dir, qa_callback)
-                        out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(name, args)
-                    else:
-                        out = mcp_client.dispatch(cast(str, name), args)
-                elif name == TOOL_NAME:
-                    scr = cast(str, args.get("script", ""))
-                    _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr})
-                    out = _run_script(scr, base_dir, qa_callback)
-                else: out = f"ERROR: unknown tool '{name}'"
-
-            if i == len(calls) - 1:
+        # Execute tools concurrently
+        try:
+            loop = asyncio.get_running_loop()
+            results = asyncio.run_coroutine_threadsafe(
+                _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"),
+                loop
+            ).result()
+        except RuntimeError:
+            results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"))
+
+        for i, (name, call_id, out, _) in enumerate(results):
+            # Check if this is the last tool to trigger file refresh
+            if i == len(results) - 1:
                 if file_items:
                     file_items, changed = _reread_file_items(file_items)
                     ctx = _build_file_diff_text(changed)
                     if ctx:
                         out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                 if r_idx == MAX_TOOL_ROUNDS: out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"

             out = _truncate_tool_output(out)
             _cumulative_tool_bytes += len(out)
             f_resps.append(types.Part(function_response=types.FunctionResponse(name=cast(str, name), response={"output": out})))
             log.append({"tool_use_id": name, "content": out})
             events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})

         if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
             f_resps.append(types.Part(text=
                 f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
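The hunks in this file all bridge from synchronous send loops into the async engine the same way: try `asyncio.get_running_loop()` and submit via `asyncio.run_coroutine_threadsafe(...).result()`, falling back to `asyncio.run()` on `RuntimeError`. A minimal sketch of that fallback, with `gather_outputs` as a hypothetical stand-in for `_execute_tool_calls_concurrently`:

```python
import asyncio

async def gather_outputs(names):
    # Stand-in for the concurrent tool executor: fan out, await all, keep order.
    async def one(n):
        await asyncio.sleep(0)  # simulate awaitable tool work
        return f"{n}:ok"
    return await asyncio.gather(*(one(n) for n in names))

def run_from_sync(names):
    # Mirrors the diff's pattern: prefer an existing loop, else start one.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop in this thread: asyncio.run() creates and tears one down.
        return asyncio.run(gather_outputs(names))
    # Note (general asyncio constraint, not project-specific): blocking on
    # .result() from the loop's own thread would deadlock, so the threadsafe
    # path only makes sense when the loop runs in a *different* thread.
    raise RuntimeError("call run_from_sync from outside the event loop thread")

print(run_from_sync(["read_file", "list_directory"]))
```

`asyncio.gather` returns results in submission order, which is why the diff can safely index `results` to find the last tool and trigger the file refresh.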
@@ -877,42 +941,21 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
         })
         if not calls or r_idx > MAX_TOOL_ROUNDS:
             break
-        tool_results_for_cli: list[dict[str, Any]] = []
-        for i, fc in enumerate(calls):
-            name = cast(str, fc.get("name"))
-            args = cast(dict[str, Any], fc.get("args", {}))
-            call_id = cast(str, fc.get("id"))
-            out = ""
-            tool_executed = False
-            events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
-            if name == TOOL_NAME and pre_tool_callback:
-                scr = cast(str, args.get("script", ""))
-                _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
-                res = pre_tool_callback(scr, base_dir, qa_callback)
-                if res is None:
-                    out = "USER REJECTED: tool execution cancelled"
-                else:
-                    out = res
-                tool_executed = True
-
-            if not tool_executed:
-                if name and name in mcp_client.TOOL_NAMES:
-                    _append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
-                    if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
-                        desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
-                        _res = pre_tool_callback(desc, base_dir, qa_callback)
-                        out = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(name, args)
-                    else:
-                        out = mcp_client.dispatch(name, args)
-                elif name == TOOL_NAME:
-                    scr = cast(str, args.get("script", ""))
-                    _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
-                    out = _run_script(scr, base_dir, qa_callback)
-                else:
-                    out = f"ERROR: unknown tool '{name}'"
-
-            if i == len(calls) - 1:
+        # Execute tools concurrently
+        try:
+            loop = asyncio.get_running_loop()
+            results = asyncio.run_coroutine_threadsafe(
+                _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"),
+                loop
+            ).result()
+        except RuntimeError:
+            results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"))
+
+        tool_results_for_cli: list[dict[str, Any]] = []
+        for i, (name, call_id, out, _) in enumerate(results):
+            # Check if this is the last tool to trigger file refresh
+            if i == len(results) - 1:
                 if file_items:
                     file_items, changed = _reread_file_items(file_items)
                     ctx = _build_file_diff_text(changed)
@@ -920,6 +963,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
                         out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                 if r_idx == MAX_TOOL_ROUNDS:
                     out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"

             out = _truncate_tool_output(out)
             _cumulative_tool_bytes += len(out)
             tool_results_for_cli.append({
@@ -930,6 +974,7 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
             })
             _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
             events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})

         payload = tool_results_for_cli
         if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
             _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
@@ -1217,61 +1262,29 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
             break
         if round_idx > MAX_TOOL_ROUNDS:
             break
-        tool_results: list[dict[str, Any]] = []
-        for block in response.content:
-            if getattr(block, "type", None) != "tool_use":
-                continue
-            b_name = cast(str, getattr(block, "name"))
-            b_id = cast(str, getattr(block, "id"))
-            b_input = cast(dict[str, Any], getattr(block, "input"))
-            output = ""
-            tool_executed = False
-            events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx})
-            if b_name == TOOL_NAME and pre_tool_callback:
-                script = cast(str, b_input.get("script", ""))
-                _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": b_id, "script": script})
-                res = pre_tool_callback(script, base_dir, qa_callback)
-                if res is None:
-                    output = "USER REJECTED: tool execution cancelled"
-                else:
-                    output = res
-                tool_executed = True
-
-            if not tool_executed:
-                if b_name and b_name in mcp_client.TOOL_NAMES:
-                    _append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input})
-                    if b_name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
-                        desc = f"# MCP MUTATING TOOL: {b_name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in b_input.items())
-                        _res = pre_tool_callback(desc, base_dir, qa_callback)
-                        output = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(b_name, b_input)
-                    else:
-                        output = mcp_client.dispatch(b_name, b_input)
-                    _append_comms("IN", "tool_result", {"name": b_name, "id": b_id, "output": output})
-                elif b_name == TOOL_NAME:
-                    script = cast(str, b_input.get("script", ""))
-                    _append_comms("OUT", "tool_call", {
-                        "name": TOOL_NAME,
-                        "id": b_id,
-                        "script": script,
-                    })
-                    output = _run_script(script, base_dir, qa_callback)
-                    _append_comms("IN", "tool_result", {
-                        "name": TOOL_NAME,
-                        "id": b_id,
-                        "output": output,
-                    })
-                else:
-                    output = f"ERROR: unknown tool '{b_name}'"
-
-            truncated = _truncate_tool_output(output)
+        # Execute tools concurrently
+        try:
+            loop = asyncio.get_running_loop()
+            results = asyncio.run_coroutine_threadsafe(
+                _execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic"),
+                loop
+            ).result()
+        except RuntimeError:
+            results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic"))
+
+        tool_results: list[dict[str, Any]] = []
+        for i, (name, call_id, out, _) in enumerate(results):
+            truncated = _truncate_tool_output(out)
             _cumulative_tool_bytes += len(truncated)
             tool_results.append({
                 "type": "tool_result",
-                "tool_use_id": b_id,
+                "tool_use_id": call_id,
                 "content": truncated,
             })
-            events.emit("tool_execution", payload={"status": "completed", "tool": b_name, "result": output, "round": round_idx})
+            _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
+            events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})

         if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
             tool_results.append({
                 "type": "text",
@@ -1448,62 +1461,39 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
             break
         if round_idx > MAX_TOOL_ROUNDS:
             break

+        # Execute tools concurrently
+        try:
+            loop = asyncio.get_running_loop()
+            results = asyncio.run_coroutine_threadsafe(
+                _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"),
+                loop
+            ).result()
+        except RuntimeError:
+            results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"))
+
         tool_results_for_history: list[dict[str, Any]] = []
-        for i, tc_raw in enumerate(tool_calls_raw):
-            tool_info = tc_raw.get("function", {})
-            tool_name = cast(str, tool_info.get("name"))
-            tool_args_str = cast(str, tool_info.get("arguments", "{}"))
-            tool_id = cast(str, tc_raw.get("id"))
-            try:
-                tool_args = json.loads(tool_args_str)
-            except:
-                tool_args = {}
-            tool_output = ""
-            tool_executed = False
-            events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx})
-            if tool_name == TOOL_NAME and pre_tool_callback:
-                script = cast(str, tool_args.get("script", ""))
-                _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script})
-                res = pre_tool_callback(script, base_dir, qa_callback)
-                if res is None:
-                    tool_output = "USER REJECTED: tool execution cancelled"
-                else:
-                    tool_output = res
-                tool_executed = True
-
-            if not tool_executed:
-                if tool_name in mcp_client.TOOL_NAMES:
-                    _append_comms("OUT", "tool_call", {"name": tool_name, "id": tool_id, "args": tool_args})
-                    if tool_name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
-                        desc = f"# MCP MUTATING TOOL: {tool_name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in tool_args.items())
-                        _res = pre_tool_callback(desc, base_dir, qa_callback)
-                        tool_output = "USER REJECTED: tool execution cancelled" if _res is None else mcp_client.dispatch(tool_name, tool_args)
-                    else:
-                        tool_output = mcp_client.dispatch(tool_name, tool_args)
-                elif tool_name == TOOL_NAME:
-                    script = cast(str, tool_args.get("script", ""))
-                    _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script})
-                    tool_output = _run_script(script, base_dir, qa_callback)
-                else:
-                    tool_output = f"ERROR: unknown tool '{tool_name}'"
-
-            if i == len(tool_calls_raw) - 1:
+        for i, (name, call_id, out, _) in enumerate(results):
+            if i == len(results) - 1:
                 if file_items:
                     file_items, changed = _reread_file_items(file_items)
                     ctx = _build_file_diff_text(changed)
                     if ctx:
-                        tool_output += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
+                        out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                 if round_idx == MAX_TOOL_ROUNDS:
-                    tool_output += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
-            tool_output = _truncate_tool_output(tool_output)
-            _cumulative_tool_bytes += len(tool_output)
+                    out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
+            truncated = _truncate_tool_output(out)
+            _cumulative_tool_bytes += len(truncated)
             tool_results_for_history.append({
                 "role": "tool",
-                "tool_call_id": tool_id,
-                "content": tool_output,
+                "tool_call_id": call_id,
+                "name": name,
+                "content": truncated,
             })
-            _append_comms("IN", "tool_result", {"name": tool_name, "id": tool_id, "output": tool_output})
-            events.emit("tool_execution", payload={"status": "completed", "tool": tool_name, "result": tool_output, "round": round_idx})
+            _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
+            events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})

         if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
             tool_results_for_history.append({
                 "role": "user",
@@ -1524,6 +1514,7 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
     except Exception as e:
         raise _classify_deepseek_error(e) from e


 def run_tier4_analysis(stderr: str) -> str:
     if not stderr or not stderr.strip():
         return ""
@@ -30,6 +30,7 @@ so the AI doesn't wander outside the project workspace.
 #

 from __future__ import annotations
+import asyncio
 from pathlib import Path
 from typing import Optional, Callable, Any, cast
 import os
@@ -858,75 +859,90 @@ def get_ui_performance() -> str:
 
 TOOL_NAMES: set[str] = {"read_file", "list_directory", "search_files", "get_file_summary", "py_get_skeleton", "py_get_code_outline", "py_get_definition", "get_git_diff", "web_search", "fetch_url", "get_ui_performance", "get_file_slice", "set_file_slice", "edit_file", "py_update_definition", "py_get_signature", "py_set_signature", "py_get_class_summary", "py_get_var_declaration", "py_set_var_declaration", "py_find_usages", "py_get_imports", "py_check_syntax", "py_get_hierarchy", "py_get_docstring", "get_tree"}
 
-def dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
-    """
-    Dispatch an MCP tool call by name. Returns the result as a string.
-    """
+async def async_dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
+    """
+    Dispatch an MCP tool call by name asynchronously. Returns the result as a string.
+    """
     # Handle aliases
     path = str(tool_input.get("path", tool_input.get("file_path", tool_input.get("dir_path", ""))))
     if tool_name == "read_file":
-        return read_file(path)
+        return await asyncio.to_thread(read_file, path)
     if tool_name == "list_directory":
-        return list_directory(path)
+        return await asyncio.to_thread(list_directory, path)
     if tool_name == "search_files":
-        return search_files(path, str(tool_input.get("pattern", "*")))
+        return await asyncio.to_thread(search_files, path, str(tool_input.get("pattern", "*")))
     if tool_name == "get_file_summary":
-        return get_file_summary(path)
+        return await asyncio.to_thread(get_file_summary, path)
     if tool_name == "py_get_skeleton":
-        return py_get_skeleton(path)
+        return await asyncio.to_thread(py_get_skeleton, path)
     if tool_name == "py_get_code_outline":
-        return py_get_code_outline(path)
+        return await asyncio.to_thread(py_get_code_outline, path)
     if tool_name == "py_get_definition":
-        return py_get_definition(path, str(tool_input.get("name", "")))
+        return await asyncio.to_thread(py_get_definition, path, str(tool_input.get("name", "")))
     if tool_name == "py_update_definition":
-        return py_update_definition(path, str(tool_input.get("name", "")), str(tool_input.get("new_content", "")))
+        return await asyncio.to_thread(py_update_definition, path, str(tool_input.get("name", "")), str(tool_input.get("new_content", "")))
     if tool_name == "py_get_signature":
-        return py_get_signature(path, str(tool_input.get("name", "")))
+        return await asyncio.to_thread(py_get_signature, path, str(tool_input.get("name", "")))
     if tool_name == "py_set_signature":
-        return py_set_signature(path, str(tool_input.get("name", "")), str(tool_input.get("new_signature", "")))
+        return await asyncio.to_thread(py_set_signature, path, str(tool_input.get("name", "")), str(tool_input.get("new_signature", "")))
     if tool_name == "py_get_class_summary":
-        return py_get_class_summary(path, str(tool_input.get("name", "")))
+        return await asyncio.to_thread(py_get_class_summary, path, str(tool_input.get("name", "")))
     if tool_name == "py_get_var_declaration":
-        return py_get_var_declaration(path, str(tool_input.get("name", "")))
+        return await asyncio.to_thread(py_get_var_declaration, path, str(tool_input.get("name", "")))
     if tool_name == "py_set_var_declaration":
-        return py_set_var_declaration(path, str(tool_input.get("name", "")), str(tool_input.get("new_declaration", "")))
+        return await asyncio.to_thread(py_set_var_declaration, path, str(tool_input.get("name", "")), str(tool_input.get("new_declaration", "")))
     if tool_name == "get_file_slice":
-        return get_file_slice(path, int(tool_input.get("start_line", 1)), int(tool_input.get("end_line", 1)))
+        return await asyncio.to_thread(get_file_slice, path, int(tool_input.get("start_line", 1)), int(tool_input.get("end_line", 1)))
     if tool_name == "set_file_slice":
-        return set_file_slice(path, int(tool_input.get("start_line", 1)), int(tool_input.get("end_line", 1)), str(tool_input.get("new_content", "")))
+        return await asyncio.to_thread(set_file_slice, path, int(tool_input.get("start_line", 1)), int(tool_input.get("end_line", 1)), str(tool_input.get("new_content", "")))
     if tool_name == "get_git_diff":
-        return get_git_diff(
+        return await asyncio.to_thread(get_git_diff,
             path,
             str(tool_input.get("base_rev", "HEAD")),
             str(tool_input.get("head_rev", ""))
         )
     if tool_name == "edit_file":
-        return edit_file(
+        return await asyncio.to_thread(edit_file,
             path,
             str(tool_input.get("old_string", "")),
             str(tool_input.get("new_string", "")),
             bool(tool_input.get("replace_all", False))
         )
     if tool_name == "web_search":
-        return web_search(str(tool_input.get("query", "")))
+        return await asyncio.to_thread(web_search, str(tool_input.get("query", "")))
     if tool_name == "fetch_url":
-        return fetch_url(str(tool_input.get("url", "")))
+        return await asyncio.to_thread(fetch_url, str(tool_input.get("url", "")))
     if tool_name == "get_ui_performance":
-        return get_ui_performance()
+        return await asyncio.to_thread(get_ui_performance)
    if tool_name == "py_find_usages":
-        return py_find_usages(path, str(tool_input.get("name", "")))
+        return await asyncio.to_thread(py_find_usages, path, str(tool_input.get("name", "")))
     if tool_name == "py_get_imports":
-        return py_get_imports(path)
+        return await asyncio.to_thread(py_get_imports, path)
     if tool_name == "py_check_syntax":
-        return py_check_syntax(path)
+        return await asyncio.to_thread(py_check_syntax, path)
     if tool_name == "py_get_hierarchy":
-        return py_get_hierarchy(path, str(tool_input.get("class_name", "")))
+        return await asyncio.to_thread(py_get_hierarchy, path, str(tool_input.get("class_name", "")))
     if tool_name == "py_get_docstring":
-        return py_get_docstring(path, str(tool_input.get("name", "")))
+        return await asyncio.to_thread(py_get_docstring, path, str(tool_input.get("name", "")))
     if tool_name == "get_tree":
-        return get_tree(path, int(tool_input.get("max_depth", 2)))
+        return await asyncio.to_thread(get_tree, path, int(tool_input.get("max_depth", 2)))
     return f"ERROR: unknown MCP tool '{tool_name}'"
 
+def dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
+    """
+    Dispatch an MCP tool call by name. Returns the result as a string.
+    """
+    try:
+        loop = asyncio.get_running_loop()
+        # If we are in a running loop, we can't use asyncio.run
+        # But we are in a synchronous function.
+        # This is tricky. If we are in a thread, we might not have a loop.
+        return asyncio.run_coroutine_threadsafe(async_dispatch(tool_name, tool_input), loop).result()
+    except RuntimeError:
+        # No running loop, use asyncio.run
+        return asyncio.run(async_dispatch(tool_name, tool_input))
+
+
 def get_tool_schemas() -> list[dict[str, Any]]:
     """Returns the list of tool specifications for the AI."""
     return list(MCP_TOOL_SPECS)
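The new synchronous `dispatch()` wrapper in the hunk above bridges legacy callers into the async path using two primitives: `asyncio.run` when the calling thread has no event loop, and `asyncio.run_coroutine_threadsafe` when a loop exists. A minimal self-contained sketch of both branches (the `async_op`/`sync_bridge` names are illustrative, not from this codebase):

```python
import asyncio
import threading

async def async_op(x: int) -> int:
    # Stand-in for async_dispatch: any awaitable unit of work.
    await asyncio.sleep(0.01)
    return x * 2

def sync_bridge(x: int) -> int:
    # Mirrors the dispatch() wrapper: prefer a running loop, else asyncio.run.
    try:
        loop = asyncio.get_running_loop()
        return asyncio.run_coroutine_threadsafe(async_op(x), loop).result()
    except RuntimeError:
        # No running loop in this thread.
        return asyncio.run(async_op(x))

# Branch 1: plain thread with no event loop -> falls back to asyncio.run.
assert sync_bridge(21) == 42

# Branch 2: submit to a loop running in *another* thread; the Future from
# run_coroutine_threadsafe can be awaited synchronously via .result().
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()
fut = asyncio.run_coroutine_threadsafe(async_op(5), loop)
assert fut.result(timeout=2) == 10
loop.call_soon_threadsafe(loop.stop)
```

Note that `run_coroutine_threadsafe(...).result()` is only safe when the caller is *not* the loop's own thread; calling `.result()` from inside the running loop's thread would block the loop on itself.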