feat(mcp_client): Add async_dispatch and support for concurrent tool execution.

This commit is contained in:
2026-03-06 13:11:48 -05:00
parent a960f3b3d0
commit 60e1dce2b6

View File

@@ -30,6 +30,7 @@ so the AI doesn't wander outside the project workspace.
# #
from __future__ import annotations from __future__ import annotations
import asyncio
from pathlib import Path from pathlib import Path
from typing import Optional, Callable, Any, cast from typing import Optional, Callable, Any, cast
import os import os
@@ -858,75 +859,90 @@ def get_ui_performance() -> str:
TOOL_NAMES: set[str] = {"read_file", "list_directory", "search_files", "get_file_summary", "py_get_skeleton", "py_get_code_outline", "py_get_definition", "get_git_diff", "web_search", "fetch_url", "get_ui_performance", "get_file_slice", "set_file_slice", "edit_file", "py_update_definition", "py_get_signature", "py_set_signature", "py_get_class_summary", "py_get_var_declaration", "py_set_var_declaration", "py_find_usages", "py_get_imports", "py_check_syntax", "py_get_hierarchy", "py_get_docstring", "get_tree"} TOOL_NAMES: set[str] = {"read_file", "list_directory", "search_files", "get_file_summary", "py_get_skeleton", "py_get_code_outline", "py_get_definition", "get_git_diff", "web_search", "fetch_url", "get_ui_performance", "get_file_slice", "set_file_slice", "edit_file", "py_update_definition", "py_get_signature", "py_set_signature", "py_get_class_summary", "py_get_var_declaration", "py_set_var_declaration", "py_find_usages", "py_get_imports", "py_check_syntax", "py_get_hierarchy", "py_get_docstring", "get_tree"}
def dispatch(tool_name: str, tool_input: dict[str, Any]) -> str: async def async_dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
"""
Dispatch an MCP tool call by name asynchronously. Returns the result as a string.
""" """
Dispatch an MCP tool call by name. Returns the result as a string.
"""
# Handle aliases # Handle aliases
path = str(tool_input.get("path", tool_input.get("file_path", tool_input.get("dir_path", "")))) path = str(tool_input.get("path", tool_input.get("file_path", tool_input.get("dir_path", ""))))
if tool_name == "read_file": if tool_name == "read_file":
return read_file(path) return await asyncio.to_thread(read_file, path)
if tool_name == "list_directory": if tool_name == "list_directory":
return list_directory(path) return await asyncio.to_thread(list_directory, path)
if tool_name == "search_files": if tool_name == "search_files":
return search_files(path, str(tool_input.get("pattern", "*"))) return await asyncio.to_thread(search_files, path, str(tool_input.get("pattern", "*")))
if tool_name == "get_file_summary": if tool_name == "get_file_summary":
return get_file_summary(path) return await asyncio.to_thread(get_file_summary, path)
if tool_name == "py_get_skeleton": if tool_name == "py_get_skeleton":
return py_get_skeleton(path) return await asyncio.to_thread(py_get_skeleton, path)
if tool_name == "py_get_code_outline": if tool_name == "py_get_code_outline":
return py_get_code_outline(path) return await asyncio.to_thread(py_get_code_outline, path)
if tool_name == "py_get_definition": if tool_name == "py_get_definition":
return py_get_definition(path, str(tool_input.get("name", ""))) return await asyncio.to_thread(py_get_definition, path, str(tool_input.get("name", "")))
if tool_name == "py_update_definition": if tool_name == "py_update_definition":
return py_update_definition(path, str(tool_input.get("name", "")), str(tool_input.get("new_content", ""))) return await asyncio.to_thread(py_update_definition, path, str(tool_input.get("name", "")), str(tool_input.get("new_content", "")))
if tool_name == "py_get_signature": if tool_name == "py_get_signature":
return py_get_signature(path, str(tool_input.get("name", ""))) return await asyncio.to_thread(py_get_signature, path, str(tool_input.get("name", "")))
if tool_name == "py_set_signature": if tool_name == "py_set_signature":
return py_set_signature(path, str(tool_input.get("name", "")), str(tool_input.get("new_signature", ""))) return await asyncio.to_thread(py_set_signature, path, str(tool_input.get("name", "")), str(tool_input.get("new_signature", "")))
if tool_name == "py_get_class_summary": if tool_name == "py_get_class_summary":
return py_get_class_summary(path, str(tool_input.get("name", ""))) return await asyncio.to_thread(py_get_class_summary, path, str(tool_input.get("name", "")))
if tool_name == "py_get_var_declaration": if tool_name == "py_get_var_declaration":
return py_get_var_declaration(path, str(tool_input.get("name", ""))) return await asyncio.to_thread(py_get_var_declaration, path, str(tool_input.get("name", "")))
if tool_name == "py_set_var_declaration": if tool_name == "py_set_var_declaration":
return py_set_var_declaration(path, str(tool_input.get("name", "")), str(tool_input.get("new_declaration", ""))) return await asyncio.to_thread(py_set_var_declaration, path, str(tool_input.get("name", "")), str(tool_input.get("new_declaration", "")))
if tool_name == "get_file_slice": if tool_name == "get_file_slice":
return get_file_slice(path, int(tool_input.get("start_line", 1)), int(tool_input.get("end_line", 1))) return await asyncio.to_thread(get_file_slice, path, int(tool_input.get("start_line", 1)), int(tool_input.get("end_line", 1)))
if tool_name == "set_file_slice": if tool_name == "set_file_slice":
return set_file_slice(path, int(tool_input.get("start_line", 1)), int(tool_input.get("end_line", 1)), str(tool_input.get("new_content", ""))) return await asyncio.to_thread(set_file_slice, path, int(tool_input.get("start_line", 1)), int(tool_input.get("end_line", 1)), str(tool_input.get("new_content", "")))
if tool_name == "get_git_diff": if tool_name == "get_git_diff":
return get_git_diff( return await asyncio.to_thread(get_git_diff,
path, path,
str(tool_input.get("base_rev", "HEAD")), str(tool_input.get("base_rev", "HEAD")),
str(tool_input.get("head_rev", "")) str(tool_input.get("head_rev", ""))
) )
if tool_name == "edit_file": if tool_name == "edit_file":
return edit_file( return await asyncio.to_thread(edit_file,
path, path,
str(tool_input.get("old_string", "")), str(tool_input.get("old_string", "")),
str(tool_input.get("new_string", "")), str(tool_input.get("new_string", "")),
bool(tool_input.get("replace_all", False)) bool(tool_input.get("replace_all", False))
) )
if tool_name == "web_search": if tool_name == "web_search":
return web_search(str(tool_input.get("query", ""))) return await asyncio.to_thread(web_search, str(tool_input.get("query", "")))
if tool_name == "fetch_url": if tool_name == "fetch_url":
return fetch_url(str(tool_input.get("url", ""))) return await asyncio.to_thread(fetch_url, str(tool_input.get("url", "")))
if tool_name == "get_ui_performance": if tool_name == "get_ui_performance":
return get_ui_performance() return await asyncio.to_thread(get_ui_performance)
if tool_name == "py_find_usages": if tool_name == "py_find_usages":
return py_find_usages(path, str(tool_input.get("name", ""))) return await asyncio.to_thread(py_find_usages, path, str(tool_input.get("name", "")))
if tool_name == "py_get_imports": if tool_name == "py_get_imports":
return py_get_imports(path) return await asyncio.to_thread(py_get_imports, path)
if tool_name == "py_check_syntax": if tool_name == "py_check_syntax":
return py_check_syntax(path) return await asyncio.to_thread(py_check_syntax, path)
if tool_name == "py_get_hierarchy": if tool_name == "py_get_hierarchy":
return py_get_hierarchy(path, str(tool_input.get("class_name", ""))) return await asyncio.to_thread(py_get_hierarchy, path, str(tool_input.get("class_name", "")))
if tool_name == "py_get_docstring": if tool_name == "py_get_docstring":
return py_get_docstring(path, str(tool_input.get("name", ""))) return await asyncio.to_thread(py_get_docstring, path, str(tool_input.get("name", "")))
if tool_name == "get_tree": if tool_name == "get_tree":
return get_tree(path, int(tool_input.get("max_depth", 2))) return await asyncio.to_thread(get_tree, path, int(tool_input.get("max_depth", 2)))
return f"ERROR: unknown MCP tool '{tool_name}'" return f"ERROR: unknown MCP tool '{tool_name}'"
def dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
"""
Dispatch an MCP tool call by name. Returns the result as a string.
"""
try:
loop = asyncio.get_running_loop()
# If we are in a running loop, we can't use asyncio.run
# But we are in a synchronous function.
# This is tricky. If we are in a thread, we might not have a loop.
return asyncio.run_coroutine_threadsafe(async_dispatch(tool_name, tool_input), loop).result()
except RuntimeError:
# No running loop, use asyncio.run
return asyncio.run(async_dispatch(tool_name, tool_input))
def get_tool_schemas() -> list[dict[str, Any]]: def get_tool_schemas() -> list[dict[str, Any]]:
"""Returns the list of tool specifications for the AI.""" """Returns the list of tool specifications for the AI."""
return list(MCP_TOOL_SPECS) return list(MCP_TOOL_SPECS)