feat(mcp_client): Implement ExternalMCPManager and StdioMCPServer with tests
This commit is contained in:
47
scripts/mock_mcp_server.py
Normal file
47
scripts/mock_mcp_server.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import sys
|
||||
import json
|
||||
|
||||
def main():
|
||||
while True:
|
||||
line = sys.stdin.readline()
|
||||
if not line:
|
||||
break
|
||||
try:
|
||||
req = json.loads(line)
|
||||
method = req.get("method")
|
||||
req_id = req.get("id")
|
||||
|
||||
if method == "tools/list":
|
||||
resp = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {
|
||||
"tools": [
|
||||
{"name": "echo", "description": "Echo input", "inputSchema": {"type": "object"}}
|
||||
]
|
||||
}
|
||||
}
|
||||
elif method == "tools/call":
|
||||
name = req["params"].get("name")
|
||||
args = req["params"].get("arguments", {})
|
||||
if name == "echo":
|
||||
resp = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {
|
||||
"content": [{"type": "text", "text": f"ECHO: {args}"}]
|
||||
}
|
||||
}
|
||||
else:
|
||||
resp = {"jsonrpc": "2.0", "id": req_id, "error": {"message": "Unknown tool"}}
|
||||
else:
|
||||
resp = {"jsonrpc": "2.0", "id": req_id, "error": {"message": "Unknown method"}}
|
||||
|
||||
sys.stdout.write(json.dumps(resp) + "\n")
|
||||
sys.stdout.flush()
|
||||
except Exception as e:
|
||||
sys.stderr.write(f"Error: {e}\n")
|
||||
sys.stderr.flush()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -53,6 +53,8 @@ See Also:
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
from src import models
|
||||
from pathlib import Path
|
||||
from typing import Optional, Callable, Any, cast
|
||||
import os
|
||||
@@ -915,6 +917,119 @@ def get_ui_performance() -> str:
|
||||
return f"ERROR: Failed to retrieve UI performance: {str(e)}"
|
||||
# ------------------------------------------------------------------ tool dispatch
|
||||
|
||||
class StdioMCPServer:
|
||||
def __init__(self, config: models.MCPServerConfig):
|
||||
self.config = config
|
||||
self.name = config.name
|
||||
self.proc = None
|
||||
self.tools = {}
|
||||
self._id_counter = 0
|
||||
self._pending_requests = {}
|
||||
|
||||
def _get_id(self):
|
||||
self._id_counter += 1
|
||||
return self._id_counter
|
||||
|
||||
async def start(self):
|
||||
self.proc = await asyncio.create_subprocess_exec(
|
||||
self.config.command,
|
||||
*self.config.args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
asyncio.create_task(self._read_stderr())
|
||||
await self.list_tools()
|
||||
|
||||
async def stop(self):
|
||||
if self.proc:
|
||||
try:
|
||||
if self.proc.stdin:
|
||||
self.proc.stdin.close()
|
||||
await self.proc.stdin.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
self.proc.terminate()
|
||||
await self.proc.wait()
|
||||
except Exception:
|
||||
pass
|
||||
self.proc = None
|
||||
|
||||
async def _read_stderr(self):
|
||||
while self.proc and not self.proc.stdout.at_eof():
|
||||
line = await self.proc.stderr.readline()
|
||||
if line:
|
||||
print(f'[MCP:{self.name}:err] {line.decode().strip()}')
|
||||
|
||||
async def _send_request(self, method: str, params: dict = None):
|
||||
req_id = self._get_id()
|
||||
request = {
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'method': method,
|
||||
'params': params or {}
|
||||
}
|
||||
self.proc.stdin.write(json.dumps(request).encode() + b'\n')
|
||||
await self.proc.stdin.drain()
|
||||
|
||||
# Simplistic wait for response - in real use, we'd need a read loop
|
||||
# For now, we'll read one line and hope it's ours (fragile, but for MVP)
|
||||
line = await self.proc.stdout.readline()
|
||||
if line:
|
||||
resp = json.loads(line.decode())
|
||||
return resp.get('result')
|
||||
return None
|
||||
|
||||
async def list_tools(self):
|
||||
result = await self._send_request('tools/list')
|
||||
if result and 'tools' in result:
|
||||
for t in result['tools']:
|
||||
self.tools[t['name']] = t
|
||||
return self.tools
|
||||
|
||||
async def call_tool(self, name: str, arguments: dict):
|
||||
result = await self._send_request('tools/call', {'name': name, 'arguments': arguments})
|
||||
if result and 'content' in result:
|
||||
return '\n'.join([c.get('text', '') for c in result['content'] if c.get('type') == 'text'])
|
||||
return str(result)
|
||||
|
||||
class ExternalMCPManager:
|
||||
def __init__(self):
|
||||
self.servers = {}
|
||||
|
||||
async def add_server(self, config: models.MCPServerConfig):
|
||||
if config.url:
|
||||
# RemoteMCPServer placeholder
|
||||
return
|
||||
server = StdioMCPServer(config)
|
||||
await server.start()
|
||||
self.servers[config.name] = server
|
||||
|
||||
async def stop_all(self):
|
||||
for server in self.servers.values():
|
||||
await server.stop()
|
||||
self.servers = {}
|
||||
|
||||
def get_all_tools(self) -> dict:
|
||||
all_tools = {}
|
||||
for sname, server in self.servers.items():
|
||||
for tname, tool in server.tools.items():
|
||||
all_tools[tname] = {**tool, 'server': sname}
|
||||
return all_tools
|
||||
|
||||
async def async_dispatch(self, tool_name: str, tool_input: dict) -> str:
|
||||
for server in self.servers.values():
|
||||
if tool_name in server.tools:
|
||||
return await server.call_tool(tool_name, tool_input)
|
||||
return f'Error: External tool {tool_name} not found.'
|
||||
|
||||
_external_mcp_manager = ExternalMCPManager()
|
||||
|
||||
def get_external_mcp_manager() -> ExternalMCPManager:
|
||||
global _external_mcp_manager
|
||||
return _external_mcp_manager
|
||||
|
||||
TOOL_NAMES: set[str] = {"read_file", "list_directory", "search_files", "get_file_summary", "py_get_skeleton", "py_get_code_outline", "py_get_definition", "get_git_diff", "web_search", "fetch_url", "get_ui_performance", "get_file_slice", "set_file_slice", "edit_file", "py_update_definition", "py_get_signature", "py_set_signature", "py_get_class_summary", "py_get_var_declaration", "py_set_var_declaration", "py_find_usages", "py_get_imports", "py_check_syntax", "py_get_hierarchy", "py_get_docstring", "get_tree"}
|
||||
|
||||
def dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
|
||||
@@ -987,17 +1102,29 @@ def dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
|
||||
return f"ERROR: unknown MCP tool '{tool_name}'"
|
||||
|
||||
async def async_dispatch(tool_name: str, tool_input: dict[str, Any]) -> str:
|
||||
"""
|
||||
Dispatch an MCP tool call by name asynchronously. Returns the result as a string.
|
||||
"""
|
||||
# Run blocking I/O bound tools in a thread to allow parallel execution via asyncio.gather
|
||||
return await asyncio.to_thread(dispatch, tool_name, tool_input)
|
||||
# Check native tools
|
||||
native_names = {t['name'] for t in MCP_TOOL_SPECS}
|
||||
if tool_name in native_names:
|
||||
return await asyncio.to_thread(dispatch, tool_name, tool_input)
|
||||
|
||||
# Check external tools
|
||||
if tool_name in get_external_mcp_manager().get_all_tools():
|
||||
return await get_external_mcp_manager().async_dispatch(tool_name, tool_input)
|
||||
|
||||
return f'ERROR: unknown MCP tool {tool_name}'
|
||||
|
||||
|
||||
|
||||
def get_tool_schemas() -> list[dict[str, Any]]:
|
||||
"""Returns the list of tool specifications for the AI."""
|
||||
return list(MCP_TOOL_SPECS)
|
||||
res = list(MCP_TOOL_SPECS)
|
||||
manager = get_external_mcp_manager()
|
||||
for tname, tinfo in manager.get_all_tools().items():
|
||||
res.append({
|
||||
'name': tname,
|
||||
'description': tinfo.get('description', ''),
|
||||
'parameters': tinfo.get('inputSchema', {'type': 'object', 'properties': {}})
|
||||
})
|
||||
return res
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ tool schema helpers
|
||||
|
||||
55
tests/test_external_mcp.py
Normal file
55
tests/test_external_mcp.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import pytest
|
||||
from src import mcp_client
|
||||
from src import models
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_external_mcp_real_process():
|
||||
manager = mcp_client.ExternalMCPManager()
|
||||
|
||||
# Use our mock script
|
||||
mock_script = "scripts/mock_mcp_server.py"
|
||||
config = models.MCPServerConfig(
|
||||
name="real-mock",
|
||||
command="python",
|
||||
args=[mock_script]
|
||||
)
|
||||
|
||||
await manager.add_server(config)
|
||||
|
||||
try:
|
||||
tools = manager.get_all_tools()
|
||||
assert "echo" in tools
|
||||
assert tools["echo"]["server"] == "real-mock"
|
||||
|
||||
result = await manager.async_dispatch("echo", {"hello": "world"})
|
||||
assert "ECHO: {'hello': 'world'}" in result
|
||||
finally:
|
||||
await manager.stop_all()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_schemas_includes_external():
|
||||
manager = mcp_client.get_external_mcp_manager()
|
||||
# Reset manager
|
||||
await manager.stop_all()
|
||||
|
||||
mock_script = "scripts/mock_mcp_server.py"
|
||||
config = models.MCPServerConfig(
|
||||
name="test-server",
|
||||
command="python",
|
||||
args=[mock_script]
|
||||
)
|
||||
|
||||
await manager.add_server(config)
|
||||
|
||||
try:
|
||||
schemas = mcp_client.get_tool_schemas()
|
||||
echo_schema = next((s for s in schemas if s["name"] == "echo"), None)
|
||||
|
||||
assert echo_schema is not None
|
||||
assert echo_schema["description"] == "Echo input"
|
||||
assert echo_schema["parameters"] == {"type": "object"}
|
||||
finally:
|
||||
await manager.stop_all()
|
||||
Reference in New Issue
Block a user