# --- file-viewer metadata residue (not Python source), preserved as comments ---
# Files
# manual_slop/src/ai_client.py
# 2270 lines
# 86 KiB
# Python
# ai_client.py
from __future__ import annotations
"""
Note(Gemini):
Acts as the unified interface for multiple LLM providers (Anthropic, Gemini).
Abstracts away the differences in how they handle tool schemas, history, and caching.
For Anthropic: aggressively manages the ~200k token limit by manually culling
stale [FILES UPDATED] entries and dropping the oldest message pairs.
For Gemini: injects the initial context directly into system_instruction
during chat creation to avoid massive history bloat.
"""
# ai_client.py
import tomllib
import asyncio
import json
import sys
import time
import datetime
import hashlib
import difflib
import threading
import requests # type: ignore[import-untyped]
from typing import Optional, Callable, Any, List, Union, cast, Iterable
import os
from pathlib import Path
from src import project_manager
from src import file_cache
from src import mcp_client
from src import mma_prompts
import anthropic
from src.gemini_cli_adapter import GeminiCliAdapter as GeminiCliAdapter
from google import genai
from google.genai import types
from src.events import EventEmitter
_provider: str = "gemini"
_model: str = "gemini-2.5-flash-lite"
_temperature: float = 0.0
_max_tokens: int = 8192
_history_trunc_limit: int = 8000
# Global event emitter for API lifecycle events
events: EventEmitter = EventEmitter()
def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000) -> None:
    """Update the module-wide temperature, max-token, and truncation settings."""
    global _temperature, _max_tokens, _history_trunc_limit
    _temperature, _max_tokens, _history_trunc_limit = temp, max_tok, trunc_limit
def get_history_trunc_limit() -> int:
    """Return the current tool-output truncation limit, in characters."""
    return _history_trunc_limit
def set_history_trunc_limit(val: int) -> None:
    """Set the tool-output truncation limit (characters; values <= 0 disable it)."""
    global _history_trunc_limit
    _history_trunc_limit = val
_gemini_client: Optional[genai.Client] = None
_gemini_chat: Any = None
_gemini_cache: Any = None
_gemini_cache_md_hash: Optional[str] = None
_gemini_cache_created_at: Optional[float] = None
_gemini_cached_file_paths: list[str] = []
# Gemini cache TTL in seconds. Caches are created with this TTL and
# proactively rebuilt at 90% of this value to avoid stale-reference errors.
_GEMINI_CACHE_TTL: int = 3600
_anthropic_client: Optional[anthropic.Anthropic] = None
_anthropic_history: list[dict[str, Any]] = []
_anthropic_history_lock: threading.Lock = threading.Lock()
_deepseek_client: Any = None
_deepseek_history: list[dict[str, Any]] = []
_deepseek_history_lock: threading.Lock = threading.Lock()
_minimax_client: Any = None
_minimax_history: list[dict[str, Any]] = []
_minimax_history_lock: threading.Lock = threading.Lock()
_send_lock: threading.Lock = threading.Lock()
_gemini_cli_adapter: Optional[GeminiCliAdapter] = None
# Injected by gui.py - called when AI wants to run a command.
confirm_and_run_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]], Optional[Callable[[str, str], Optional[str]]]], Optional[str]]] = None
# Injected by gui.py - called whenever a comms entry is appended.
comms_log_callback: Optional[Callable[[dict[str, Any]], None]] = None
# Injected by gui.py - called whenever a tool call completes.
tool_log_callback: Optional[Callable[[str, str], None]] = None
_local_storage = threading.local()
def get_current_tier() -> Optional[str]:
    """Read the tier tag for the calling thread (None when unset)."""
    return getattr(_local_storage, "current_tier", None)
def set_current_tier(tier: Optional[str]) -> None:
    """Tag the calling thread with a tier label (recorded in comms entries)."""
    _local_storage.current_tier = tier
# Increased to allow thorough code exploration before forcing a summary
MAX_TOOL_ROUNDS: int = 10
# Maximum cumulative bytes of tool output allowed per send() call.
_MAX_TOOL_OUTPUT_BYTES: int = 500_000
# Maximum characters per text chunk sent to Anthropic.
_ANTHROPIC_CHUNK_SIZE: int = 120_000
_SYSTEM_PROMPT: str = (
"You are a helpful coding assistant with access to a PowerShell tool (run_powershell) and MCP tools (file access: read_file, list_directory, search_files, get_file_summary, web access: web_search, fetch_url). "
"When calling file/directory tools, always use the 'path' parameter for the target path. "
"When asked to create or edit files, prefer targeted edits over full rewrites. "
"Always explain what you are doing before invoking the tool.\n\n"
"When writing or rewriting large files (especially those containing quotes, backticks, or special characters), "
"avoid python -c with inline strings. Instead: (1) write a .py helper script to disk using a PS here-string "
"(@'...'@ for literal content), (2) run it with `python <script>`, (3) delete the helper. "
"For small targeted edits, use PowerShell's (Get-Content) / .Replace() / Set-Content or Add-Content directly.\n\n"
"When making function calls using tools that accept array or object parameters "
"ensure those are structured using JSON. For example:\n"
"When you need to verify a change, rely on the exit code and stdout/stderr from the tool \u2014 "
"the user's context files are automatically refreshed after every tool call, so you do NOT "
"need to re-read files that are already provided in the <context> block."
)
_custom_system_prompt: str = ""
def set_custom_system_prompt(prompt: str) -> None:
    """Store the user-supplied prompt that gets appended to the built-in one."""
    global _custom_system_prompt
    _custom_system_prompt = prompt
def _get_combined_system_prompt() -> str:
    """Return the base system prompt, plus the user section when non-blank."""
    if not _custom_system_prompt.strip():
        return _SYSTEM_PROMPT
    # Note: the original (unstripped) custom prompt is appended.
    return f"{_SYSTEM_PROMPT}\n\n[USER SYSTEM PROMPT]\n{_custom_system_prompt}"
from collections import deque
_comms_log: deque[dict[str, Any]] = deque(maxlen=1000)
COMMS_CLAMP_CHARS: int = 300
def _append_comms(direction: str, kind: str, payload: dict[str, Any]) -> None:
    """Record one comms-log entry and forward it to the GUI callback if set."""
    entry: dict[str, Any] = {
        "ts": datetime.datetime.now().strftime("%H:%M:%S"),
        "direction": direction,
        "kind": kind,
        "provider": _provider,
        "model": _model,
        "payload": payload,
        "source_tier": get_current_tier(),  # thread-local tier attribution
        "local_ts": time.time(),
    }
    _comms_log.append(entry)
    if comms_log_callback is not None:
        comms_log_callback(entry)
def get_comms_log() -> list[dict[str, Any]]:
    """Snapshot the comms log as a plain list (oldest entry first)."""
    return list(_comms_log)
def clear_comms_log() -> None:
    """Drop every recorded comms entry."""
    _comms_log.clear()
def get_credentials_path() -> Path:
    """Resolve the credentials.toml path; SLOP_CREDENTIALS overrides the default."""
    default_path = Path(__file__).parent.parent / "credentials.toml"
    return Path(os.environ.get("SLOP_CREDENTIALS", str(default_path)))
def _load_credentials() -> dict[str, Any]:
    """Parse credentials.toml into a dict; raise FileNotFoundError with setup help."""
    cred_path = get_credentials_path()
    try:
        with open(cred_path, "rb") as fh:
            return tomllib.load(fh)
    except FileNotFoundError:
        hint = (
            f"Credentials file not found: {cred_path}\n"
            "Create a credentials.toml with:\n"
            " [gemini]\n api_key = \"your-key\"\n"
            " [anthropic]\n api_key = \"your-key\"\n"
            " [deepseek]\n api_key = \"your-key\"\n"
            " [minimax]\n api_key = \"your-key\"\n"
            "Or set SLOP_CREDENTIALS env var to a custom path."
        )
        raise FileNotFoundError(hint)
class ProviderError(Exception):
    """Normalized wrapper for provider API failures.

    Carries a coarse classification (`kind`), the provider name, and the
    original exception so callers can render a consistent UI banner.
    """

    def __init__(self, kind: str, provider: str, original: Exception) -> None:
        self.kind = kind
        self.provider = provider
        self.original = original
        super().__init__(str(original))

    def ui_message(self) -> str:
        """Render a banner like ``[GEMINI QUOTA EXHAUSTED]`` plus the detail text."""
        kind_labels = {
            "quota": "QUOTA EXHAUSTED",
            "rate_limit": "RATE LIMITED",
            "auth": "AUTH / API KEY ERROR",
            "balance": "BALANCE / BILLING ERROR",
            "network": "NETWORK / CONNECTION ERROR",
            "unknown": "API ERROR",
        }
        label = kind_labels.get(self.kind, "API ERROR")
        return f"[{self.provider.upper()} {label}]\n\n{self.original}"
def _classify_anthropic_error(exc: Exception) -> ProviderError:
    """Classify an Anthropic SDK exception into a ProviderError kind."""
    try:
        # Exact SDK exception types first, then status codes, then keywords.
        if isinstance(exc, anthropic.RateLimitError):
            return ProviderError("rate_limit", "anthropic", exc)
        if isinstance(exc, (anthropic.AuthenticationError, anthropic.PermissionDeniedError)):
            return ProviderError("auth", "anthropic", exc)
        if isinstance(exc, anthropic.APIConnectionError):
            return ProviderError("network", "anthropic", exc)
        if isinstance(exc, anthropic.APIStatusError):
            status = getattr(exc, "status_code", 0)
            text = str(exc).lower()
            if status == 429:
                return ProviderError("rate_limit", "anthropic", exc)
            if status in (401, 403):
                return ProviderError("auth", "anthropic", exc)
            if status == 402:
                return ProviderError("balance", "anthropic", exc)
            if any(word in text for word in ("credit", "balance", "billing")):
                return ProviderError("balance", "anthropic", exc)
            if any(word in text for word in ("quota", "limit", "exceeded")):
                return ProviderError("quota", "anthropic", exc)
    except ImportError:
        # Defensive: if the anthropic SDK types are unavailable, fall through.
        pass
    return ProviderError("unknown", "anthropic", exc)
def _classify_gemini_error(exc: Exception) -> ProviderError:
    """Classify a Gemini / google-api-core error into a ProviderError kind."""
    text = str(exc).lower()
    try:
        from google.api_core import exceptions as gac
        # Exact google.api_core exception types first.
        if isinstance(exc, gac.ResourceExhausted):
            return ProviderError("quota", "gemini", exc)
        if isinstance(exc, gac.TooManyRequests):
            return ProviderError("rate_limit", "gemini", exc)
        if isinstance(exc, (gac.Unauthenticated, gac.PermissionDenied)):
            return ProviderError("auth", "gemini", exc)
        if isinstance(exc, gac.ServiceUnavailable):
            return ProviderError("network", "gemini", exc)
    except ImportError:
        # google.api_core not installed: rely on string heuristics below.
        pass
    if "429" in text or "quota" in text or "resource exhausted" in text:
        return ProviderError("quota", "gemini", exc)
    if "rate" in text and "limit" in text:
        return ProviderError("rate_limit", "gemini", exc)
    if "401" in text or "403" in text or "api key" in text or "unauthenticated" in text:
        return ProviderError("auth", "gemini", exc)
    if "402" in text or "billing" in text or "balance" in text or "payment" in text:
        return ProviderError("balance", "gemini", exc)
    if "connection" in text or "timeout" in text or "unreachable" in text:
        return ProviderError("network", "gemini", exc)
    return ProviderError("unknown", "gemini", exc)
def _classify_deepseek_error(exc: Exception) -> ProviderError:
    """Map an arbitrary DeepSeek failure onto a classified ProviderError.

    Prefers the structured JSON error message from the HTTP response body
    when available, then falls back to keyword heuristics on the text.
    """
    body = ""
    if isinstance(exc, requests.exceptions.HTTPError) and exc.response is not None:
        try:
            # Try to get the detailed error from DeepSeek's JSON response.
            err_data = exc.response.json()
            if "error" in err_data:
                body = str(err_data["error"].get("message", exc.response.text))
            else:
                body = exc.response.text
        except Exception:
            # Malformed / non-JSON body: fall back to the raw response text.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            body = exc.response.text
    else:
        body = str(exc)
    body_l = body.lower()
    if "429" in body_l or "rate" in body_l:
        return ProviderError("rate_limit", "deepseek", Exception(body))
    if "401" in body_l or "403" in body_l or "auth" in body_l or "api key" in body_l:
        return ProviderError("auth", "deepseek", Exception(body))
    if "402" in body_l or "balance" in body_l or "billing" in body_l:
        return ProviderError("balance", "deepseek", Exception(body))
    if "quota" in body_l or "limit exceeded" in body_l:
        return ProviderError("quota", "deepseek", Exception(body))
    if "connection" in body_l or "timeout" in body_l or "network" in body_l:
        return ProviderError("network", "deepseek", Exception(body))
    # If we have a body for a 400 error, wrap it so the detail survives.
    if "400" in body_l or "bad request" in body_l:
        return ProviderError("unknown", "deepseek", Exception(f"DeepSeek Bad Request: {body}"))
    return ProviderError("unknown", "deepseek", Exception(body))
def _classify_minimax_error(exc: Exception) -> ProviderError:
    """Map an arbitrary MiniMax failure onto a classified ProviderError.

    Mirrors _classify_deepseek_error: structured JSON error message when
    available, keyword heuristics otherwise.
    """
    body = ""
    if isinstance(exc, requests.exceptions.HTTPError) and exc.response is not None:
        try:
            err_data = exc.response.json()
            if "error" in err_data:
                body = str(err_data["error"].get("message", exc.response.text))
            else:
                body = exc.response.text
        except Exception:
            # Malformed / non-JSON body: fall back to the raw response text.
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
            body = exc.response.text
    else:
        body = str(exc)
    body_l = body.lower()
    if "429" in body_l or "rate" in body_l:
        return ProviderError("rate_limit", "minimax", Exception(body))
    if "401" in body_l or "403" in body_l or "auth" in body_l or "api key" in body_l:
        return ProviderError("auth", "minimax", Exception(body))
    if "402" in body_l or "balance" in body_l or "billing" in body_l:
        return ProviderError("balance", "minimax", Exception(body))
    if "quota" in body_l or "limit exceeded" in body_l:
        return ProviderError("quota", "minimax", Exception(body))
    if "connection" in body_l or "timeout" in body_l or "network" in body_l:
        return ProviderError("network", "minimax", Exception(body))
    if "400" in body_l or "bad request" in body_l:
        return ProviderError("unknown", "minimax", Exception(f"MiniMax Bad Request: {body}"))
    return ProviderError("unknown", "minimax", Exception(body))
def set_provider(provider: str, model: str) -> None:
    """Select the active provider and model.

    Falls back to a known-good default model when the requested one is
    invalid for gemini_cli or minimax; other providers accept any model.
    """
    global _provider, _model
    _provider = provider
    if provider == "gemini_cli":
        known = _list_gemini_cli_models()
        # "mock" is always allowed; deepseek-* names are never valid here.
        invalid = model != "mock" and (model not in known or model.startswith("deepseek"))
        _model = "gemini-3-flash-preview" if invalid else model
    elif provider == "minimax":
        _model = model if model in _list_minimax_models("") else "MiniMax-M2.5"
    else:
        _model = model
def get_provider() -> str:
    """Return the currently selected provider name."""
    return _provider
def cleanup() -> None:
    """Best-effort deletion of the server-side Gemini cache on shutdown."""
    global _gemini_client, _gemini_cache, _gemini_cached_file_paths
    if _gemini_client and _gemini_cache:
        try:
            _gemini_client.caches.delete(name=_gemini_cache.name)
        except Exception:
            # Cache may already be expired or deleted server-side; ignore.
            pass
    _gemini_cached_file_paths = []
def reset_session() -> None:
    """Drop all provider clients, histories, and caches for a fresh session."""
    global _gemini_client, _gemini_chat, _gemini_cache
    global _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths
    global _anthropic_client, _anthropic_history
    global _deepseek_client, _deepseek_history
    global _minimax_client, _minimax_history
    global _CACHED_ANTHROPIC_TOOLS
    global _gemini_cli_adapter
    # Best-effort removal of the server-side Gemini cache before dropping state.
    if _gemini_client and _gemini_cache:
        try:
            _gemini_client.caches.delete(name=_gemini_cache.name)
        except Exception:
            pass
    _gemini_client = None
    _gemini_chat = None
    _gemini_cache = None
    _gemini_cache_md_hash = None
    _gemini_cache_created_at = None
    _gemini_cached_file_paths = []
    # Preserve binary_path if an adapter already exists.
    previous_binary = _gemini_cli_adapter.binary_path if _gemini_cli_adapter else "gemini"
    _gemini_cli_adapter = GeminiCliAdapter(binary_path=previous_binary)
    _anthropic_client = None
    with _anthropic_history_lock:
        _anthropic_history = []
    _deepseek_client = None
    with _deepseek_history_lock:
        _deepseek_history = []
    _minimax_client = None
    with _minimax_history_lock:
        _minimax_history = []
    _CACHED_ANTHROPIC_TOOLS = None
    file_cache.reset_client()
def get_gemini_cache_stats() -> dict[str, Any]:
    """Summarize server-side Gemini caches: count, total bytes, cached files."""
    _ensure_gemini_client()
    if not _gemini_client:
        return {"cache_count": 0, "total_size_bytes": 0, "cached_files": []}
    caches = list(_gemini_client.caches.list())
    return {
        "cache_count": len(caches),
        "total_size_bytes": sum(getattr(c, 'size_bytes', 0) for c in caches),
        "cached_files": _gemini_cached_file_paths,
    }
def list_models(provider: str) -> list[str]:
    """List available model names for `provider`; [] for unknown providers.

    Credentials are loaded lazily and only for providers that actually need
    an API key, so providers with static model lists (gemini_cli) work even
    when credentials.toml is missing. (Previously credentials were loaded
    unconditionally, which raised FileNotFoundError for every provider.)
    """
    if provider == "gemini":
        return _list_gemini_models(_load_credentials()["gemini"]["api_key"])
    if provider == "anthropic":
        # _list_anthropic_models loads its own credentials.
        return _list_anthropic_models()
    if provider == "deepseek":
        return _list_deepseek_models(_load_credentials()["deepseek"]["api_key"])
    if provider == "gemini_cli":
        return _list_gemini_cli_models()
    if provider == "minimax":
        return _list_minimax_models(_load_credentials()["minimax"]["api_key"])
    return []
def _list_gemini_cli_models() -> list[str]:
return [
"gemini-3-flash-preview",
"gemini-3.1-pro-preview",
"gemini-2.5-pro",
"gemini-2.5-flash",
"gemini-2.0-flash",
"gemini-2.5-flash-lite",
]
def _list_gemini_models(api_key: str) -> list[str]:
    """Query the Gemini API and return sorted short names of gemini models."""
    try:
        client = genai.Client(api_key=api_key)
        found: list[str] = []
        for model in client.models.list():
            short = model.name or ""
            # Strip the "models/" resource prefix from the API name.
            if short.startswith("models/"):
                short = short[len("models/"):]
            if short and "gemini" in short.lower():
                found.append(short)
        return sorted(found)
    except Exception as exc:
        raise _classify_gemini_error(exc) from exc
def _list_anthropic_models() -> list[str]:
    """Query the Anthropic API for available model IDs, sorted."""
    try:
        creds = _load_credentials()
        client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"])
        return sorted(m.id for m in client.models.list())
    except Exception as exc:
        raise _classify_anthropic_error(exc) from exc
def _list_deepseek_models(api_key: str) -> list[str]:
return ["deepseek-chat", "deepseek-reasoner"]
def _list_minimax_models(api_key: str) -> list[str]:
return ["MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1", "MiniMax-M2.1-highspeed", "MiniMax-M2"]
TOOL_NAME: str = "run_powershell"
_agent_tools: dict[str, bool] = {}
def set_agent_tools(tools: dict[str, bool]) -> None:
    """Replace the enabled-tools map and invalidate the memoized Anthropic tool list."""
    global _agent_tools, _CACHED_ANTHROPIC_TOOLS
    _agent_tools = tools
    _CACHED_ANTHROPIC_TOOLS = None
def _build_anthropic_tools() -> list[dict[str, Any]]:
    """Build the Anthropic tool schema list from the MCP specs plus run_powershell.

    The last tool in the list carries a cache_control breakpoint so the whole
    tool block can be prompt-cached.
    """
    tools: list[dict[str, Any]] = [
        {
            "name": spec["name"],
            "description": spec["description"],
            "input_schema": spec["parameters"],
        }
        for spec in mcp_client.MCP_TOOL_SPECS
        if _agent_tools.get(spec["name"], True)
    ]
    if _agent_tools.get(TOOL_NAME, True):
        tools.append({
            "name": TOOL_NAME,
            "description": (
                "Run a PowerShell script within the project base_dir. "
                "Use this to create, edit, rename, or delete files and directories. "
                "The working directory is set to base_dir automatically. "
                "Always prefer targeted edits over full rewrites where possible. "
                "stdout and stderr are returned to you as the result."
            ),
            "input_schema": {
                "type": "object",
                "properties": {
                    "script": {
                        "type": "string",
                        "description": "The PowerShell script to execute."
                    }
                },
                "required": ["script"]
            },
            "cache_control": {"type": "ephemeral"},
        })
    elif tools:
        # No powershell tool: put the cache breakpoint on the last MCP tool.
        tools[-1]["cache_control"] = {"type": "ephemeral"}
    return tools
# Memoized Anthropic tool list; invalidated by set_agent_tools / reset_session.
_CACHED_ANTHROPIC_TOOLS: Optional[list[dict[str, Any]]] = None

def _get_anthropic_tools() -> list[dict[str, Any]]:
    """Return the Anthropic tool list, building and memoizing it on first use."""
    global _CACHED_ANTHROPIC_TOOLS
    if _CACHED_ANTHROPIC_TOOLS is None:
        _CACHED_ANTHROPIC_TOOLS = _build_anthropic_tools()
    return _CACHED_ANTHROPIC_TOOLS
def _gemini_tool_declaration() -> Optional[types.Tool]:
    """Build the Gemini Tool holding declarations for every enabled tool.

    Returns None when no tools are enabled.
    """
    decls: list[types.FunctionDeclaration] = []
    for spec in mcp_client.MCP_TOOL_SPECS:
        if not _agent_tools.get(spec["name"], True):
            continue
        prop_schemas = {}
        for pname, pdef in spec["parameters"].get("properties", {}).items():
            # Map JSON-schema type strings onto the genai Type enum,
            # defaulting to STRING for anything unrecognized.
            schema_type = getattr(types.Type, pdef.get("type", "string").upper(), types.Type.STRING)
            prop_schemas[pname] = types.Schema(
                type=schema_type,
                description=pdef.get("description", ""),
            )
        decls.append(types.FunctionDeclaration(
            name=spec["name"],
            description=spec["description"],
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties=prop_schemas,
                required=spec["parameters"].get("required", []),
            ),
        ))
    if _agent_tools.get(TOOL_NAME, True):
        decls.append(types.FunctionDeclaration(
            name=TOOL_NAME,
            description=(
                "Run a PowerShell script within the project base_dir. "
                "Use this to create, edit, rename, or delete files and directories. "
                "The working directory is set to base_dir automatically. "
                "stdout and stderr are returned to you as the result."
            ),
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "script": types.Schema(
                        type=types.Type.STRING,
                        description="The PowerShell script to execute."
                    )
                },
                required=["script"]
            ),
        ))
    return types.Tool(function_declarations=decls) if decls else None
async def _execute_tool_calls_concurrently(
    calls: list[Any],
    base_dir: str,
    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
    qa_callback: Optional[Callable[[str], str]],
    r_idx: int,
    provider: str,
    patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
) -> list[tuple[str, str, str, str]]:  # tool_name, call_id, output, original_name
    """
    Executes multiple tool calls concurrently using asyncio.gather.

    Normalizes each provider's tool-call representation into
    (name, args, call_id) before dispatching; unknown providers are skipped.
    Returns a list of (tool_name, call_id, output, original_name).

    Fixes: the deepseek/minimax branches were duplicated verbatim (now
    merged), and the json.loads fallbacks used bare `except:` (now narrowed).
    """
    tasks = []
    for fc in calls:
        if provider == "gemini":
            # Gemini 1.0.0 doesn't expose call IDs in types.Part; reuse the name.
            name, args, call_id = fc.name, dict(fc.args), fc.name
        elif provider == "gemini_cli":
            name, args, call_id = cast(str, fc.get("name")), cast(dict[str, Any], fc.get("args", {})), cast(str, fc.get("id"))
        elif provider == "anthropic":
            name, args, call_id = cast(str, getattr(fc, "name")), cast(dict[str, Any], getattr(fc, "input")), cast(str, getattr(fc, "id"))
        elif provider in ("deepseek", "minimax"):
            # OpenAI-style payload: arguments arrive as a JSON string.
            tool_info = fc.get("function", {})
            name = cast(str, tool_info.get("name"))
            tool_args_str = cast(str, tool_info.get("arguments", "{}"))
            call_id = cast(str, fc.get("id"))
            try:
                args = json.loads(tool_args_str)
            except (json.JSONDecodeError, TypeError):
                args = {}
        else:
            # Unknown provider: nothing we can dispatch.
            continue
        tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx, patch_callback))
    results = await asyncio.gather(*tasks)
    return results
async def _execute_single_tool_call_async(
    name: str,
    args: dict[str, Any],
    call_id: str,
    base_dir: str,
    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
    qa_callback: Optional[Callable[[str], str]],
    r_idx: int,
    patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
) -> tuple[str, str, str, str]:
    """Execute one normalized tool call; return (name, call_id, output, name).

    run_powershell goes through the HITL pre_tool_callback when one is
    provided; mutating MCP tools are likewise gated; everything else
    dispatches straight to the MCP layer.
    """
    events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
    # PowerShell with a human-in-the-loop confirmation hook.
    if name == TOOL_NAME and pre_tool_callback:
        script = cast(str, args.get("script", ""))
        _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": script})
        # pre_tool_callback is synchronous and might block for HITL.
        confirmed = await asyncio.to_thread(pre_tool_callback, script, base_dir, qa_callback)
        output = "USER REJECTED: tool execution cancelled" if confirmed is None else confirmed
        return (name, call_id, output, name)
    if name and name in mcp_client.TOOL_NAMES:
        _append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
        if name in mcp_client.MUTATING_TOOLS and pre_tool_callback:
            # Mutating MCP tools also require user confirmation first.
            desc = f"# MCP MUTATING TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
            approval = await asyncio.to_thread(pre_tool_callback, desc, base_dir, qa_callback)
            output = "USER REJECTED: tool execution cancelled" if approval is None else await mcp_client.async_dispatch(name, args)
        else:
            output = await mcp_client.async_dispatch(name, args)
        return (name, call_id, output, name)
    if name == TOOL_NAME:
        # No confirmation hook registered: go through the module-level runner.
        script = cast(str, args.get("script", ""))
        _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": script})
        output = await asyncio.to_thread(_run_script, script, base_dir, qa_callback, patch_callback)
        return (name, call_id, output, name)
    return (name, call_id, f"ERROR: unknown tool '{name}'", name)
def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Run a PowerShell script via the GUI confirmation handler; log the result."""
    if confirm_and_run_callback is None:
        return "ERROR: no confirmation handler registered"
    result = confirm_and_run_callback(script, base_dir, qa_callback, patch_callback)
    # None signals the user declined the confirmation dialog.
    output = "USER REJECTED: command was not executed" if result is None else result
    if tool_log_callback is not None:
        tool_log_callback(script, output)
    return output
def _truncate_tool_output(output: str) -> str:
    """Clamp tool output to the configured history limit (<= 0 disables clamping)."""
    limit = _history_trunc_limit
    if limit > 0 and len(output) > limit:
        return output[:limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS.]"
    return output
def _reread_file_items(file_items: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
refreshed: list[dict[str, Any]] = []
changed: list[dict[str, Any]] = []
for item in file_items:
path = item.get("path")
if path is None:
refreshed.append(item)
continue
from pathlib import Path as _P
p = path if isinstance(path, _P) else _P(path)
try:
current_mtime = p.stat().st_mtime
prev_mtime = cast(float, item.get("mtime", 0.0))
if current_mtime == prev_mtime:
refreshed.append(item)
continue
content = p.read_text(encoding="utf-8")
new_item = {**item, "old_content": item.get("content", ""), "content": content, "error": False, "mtime": current_mtime}
refreshed.append(new_item)
changed.append(new_item)
except Exception as e:
err_item = {**item, "content": f"ERROR re-reading {p}: {e}", "error": True, "mtime": 0.0}
refreshed.append(err_item)
changed.append(err_item)
return refreshed, changed
def _build_file_context_text(file_items: list[dict[str, Any]]) -> str:
if not file_items:
return ""
parts: list[str] = []
for item in file_items:
path = item.get("path") or item.get("entry", "unknown")
suffix = str(path).rsplit(".", 1)[-1] if "." in str(path) else "text"
content = item.get("content", "")
parts.append(f"### `{path}`\n\n```{suffix}\n{content}\n```")
return "\n\n---\n\n".join(parts)
_DIFF_LINE_THRESHOLD: int = 200
def _build_file_diff_text(changed_items: list[dict[str, Any]]) -> str:
if not changed_items:
return ""
parts: list[str] = []
for item in changed_items:
path = item.get("path") or item.get("entry", "unknown")
content = cast(str, item.get("content", ""))
old_content = cast(str, item.get("old_content", ""))
new_lines = content.splitlines(keepends=True)
if len(new_lines) <= _DIFF_LINE_THRESHOLD or not old_content:
suffix = str(path).rsplit(".", 1)[-1] if "." in str(path) else "text"
parts.append(f"### `{path}` (full)\n\n```{suffix}\n{content}\n```")
else:
old_lines = old_content.splitlines(keepends=True)
diff = difflib.unified_diff(old_lines, new_lines, fromfile=str(path), tofile=str(path), lineterm="")
diff_text = "\n".join(diff)
if diff_text:
parts.append(f"### `{path}` (diff)\n\n```diff\n{diff_text}\n```")
else:
parts.append(f"### `{path}` (no changes detected)")
return "\n\n---\n\n".join(parts)
def _build_deepseek_tools() -> list[dict[str, Any]]:
    """Build OpenAI-style tool schemas (used by DeepSeek/MiniMax) from MCP specs."""
    tools: list[dict[str, Any]] = [
        {
            "type": "function",
            "function": {
                "name": spec["name"],
                "description": spec["description"],
                "parameters": spec["parameters"],
            }
        }
        for spec in mcp_client.MCP_TOOL_SPECS
        if _agent_tools.get(spec["name"], True)
    ]
    if _agent_tools.get(TOOL_NAME, True):
        tools.append({
            "type": "function",
            "function": {
                "name": TOOL_NAME,
                "description": (
                    "Run a PowerShell script within the project base_dir. "
                    "Use this to create, edit, rename, or delete files and directories. "
                    "The working directory is set to base_dir automatically. "
                    "Always prefer targeted edits over full rewrites where possible. "
                    "stdout and stderr are returned to you as the result."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "script": {
                            "type": "string",
                            "description": "The PowerShell script to execute."
                        }
                    },
                    "required": ["script"]
                }
            }
        })
    return tools
# Memoized DeepSeek/OpenAI-style tool list.
_CACHED_DEEPSEEK_TOOLS: Optional[list[dict[str, Any]]] = None

def _get_deepseek_tools() -> list[dict[str, Any]]:
    """Return the DeepSeek tool list, building and memoizing it on first use."""
    global _CACHED_DEEPSEEK_TOOLS
    if _CACHED_DEEPSEEK_TOOLS is None:
        _CACHED_DEEPSEEK_TOOLS = _build_deepseek_tools()
    return _CACHED_DEEPSEEK_TOOLS
def _content_block_to_dict(block: Any) -> dict[str, Any]:
if isinstance(block, dict):
return block
if hasattr(block, "model_dump"):
return cast(dict[str, Any], block.model_dump())
if hasattr(block, "to_dict"):
return cast(dict[str, Any], block.to_dict())
block_type = getattr(block, "type", None)
if block_type == "text":
return {"type": "text", "text": block.text}
if block_type == "tool_use":
return {"type": "tool_use", "id": getattr(block, "id"), "name": getattr(block, "name"), "input": getattr(block, "input")}
return {"type": "text", "text": str(block)}
def _ensure_gemini_client() -> None:
    """Lazily construct the module-level Gemini client from credentials."""
    global _gemini_client
    if _gemini_client is not None:
        return
    creds = _load_credentials()
    _gemini_client = genai.Client(api_key=creds["gemini"]["api_key"])
def _get_gemini_history_list(chat: Any | None) -> list[Any]:
if not chat: return []
if hasattr(chat, "_history"):
return cast(list[Any], chat._history)
if hasattr(chat, "history"):
return cast(list[Any], chat.history)
if hasattr(chat, "get_history"):
return cast(list[Any], chat.get_history())
return []
def _send_gemini(md_content: str, user_message: str, base_dir: str,
                 file_items: list[dict[str, Any]] | None = None,
                 discussion_history: str = "",
                 pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                 qa_callback: Optional[Callable[[str], str]] = None,
                 enable_tools: bool = True,
                 stream_callback: Optional[Callable[[str], None]] = None,
                 patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Send one user turn to Gemini and drive the tool-call loop to completion.

    Maintains a module-level chat session plus an explicit content cache.
    The cache/chat pair is torn down and rebuilt (carrying history across)
    when the context markdown hash changes or the cache is near its TTL.
    Each round, returned function calls are executed concurrently and their
    outputs fed back; the loop ends when no more calls are made or the
    round budget is exhausted. Returns all rounds' text joined by blank
    lines; raises a classified provider error on failure.
    """
    global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths
    try:
        _ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir])
        sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
        td = _gemini_tool_declaration() if enable_tools else None
        tools_decl = [td] if td else None
        current_md_hash = hashlib.md5(md_content.encode()).hexdigest()
        old_history = None
        assert _gemini_client is not None
        # Context changed since the session was built: discard cache + chat,
        # but preserve the conversation history for the replacement session.
        if _gemini_chat and _gemini_cache_md_hash != current_md_hash:
            old_history = list(_get_gemini_history_list(_gemini_chat)) if _get_gemini_history_list(_gemini_chat) else []
            if _gemini_cache:
                try: _gemini_client.caches.delete(name=_gemini_cache.name)
                except Exception as e: _append_comms("OUT", "request", {"message": f"[CACHE DELETE WARN] {e}"})
            _gemini_chat = None
            _gemini_cache = None
            _gemini_cache_created_at = None
            _gemini_cached_file_paths = []
            _append_comms("OUT", "request", {"message": "[CONTEXT CHANGED] Rebuilding cache and chat session..."})
        # Proactively rebuild once 90% of the cache TTL has elapsed.
        if _gemini_chat and _gemini_cache and _gemini_cache_created_at:
            elapsed = time.time() - _gemini_cache_created_at
            if elapsed > _GEMINI_CACHE_TTL * 0.9:
                old_history = list(_get_gemini_history_list(_gemini_chat)) if _get_gemini_history_list(_gemini_chat) else []
                try: _gemini_client.caches.delete(name=_gemini_cache.name)
                except Exception as e: _append_comms("OUT", "request", {"message": f"[CACHE DELETE WARN] {e}"})
                _gemini_chat = None
                _gemini_cache = None
                _gemini_cache_created_at = None
                _gemini_cached_file_paths = []
                _append_comms("OUT", "request", {"message": f"[CACHE TTL] Rebuilding cache (expired after {int(elapsed)}s)..."})
        # (Re)create the chat session, using an explicit content cache when
        # the system instruction is large enough to be worth caching.
        if not _gemini_chat:
            chat_config = types.GenerateContentConfig(
                system_instruction=sys_instr,
                tools=cast(Any, tools_decl),
                temperature=_temperature,
                max_output_tokens=_max_tokens,
                safety_settings=[types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=types.HarmBlockThreshold.BLOCK_ONLY_HIGH)]
            )
            should_cache = False
            try:
                if _gemini_client:
                    count_resp = _gemini_client.models.count_tokens(model=_model, contents=[sys_instr])
                    # 2048 tokens is the minimum size the cache is used for here.
                    if count_resp.total_tokens and count_resp.total_tokens >= 2048:
                        should_cache = True
                    else:
                        _append_comms("OUT", "request", {"message": f"[CACHING SKIPPED] Context too small ({count_resp.total_tokens} tokens < 2048)"})
            except Exception as e:
                _append_comms("OUT", "request", {"message": f"[COUNT FAILED] {e}"})
            if should_cache and _gemini_client:
                try:
                    _gemini_cache = _gemini_client.caches.create(
                        model=_model,
                        config=types.CreateCachedContentConfig(
                            system_instruction=sys_instr,
                            tools=cast(Any, tools_decl),
                            ttl=f"{_GEMINI_CACHE_TTL}s",
                        )
                    )
                    _gemini_cache_created_at = time.time()
                    _gemini_cached_file_paths = [str(item.get("path", "")) for item in (file_items or []) if item.get("path")]
                    # With a cache in place, the chat config references the cache
                    # instead of repeating system_instruction/tools inline.
                    chat_config = types.GenerateContentConfig(
                        cached_content=_gemini_cache.name,
                        temperature=_temperature,
                        max_output_tokens=_max_tokens,
                        safety_settings=[types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=types.HarmBlockThreshold.BLOCK_ONLY_HIGH)]
                    )
                    _append_comms("OUT", "request", {"message": f"[CACHE CREATED] {_gemini_cache.name}"})
                except Exception as e:
                    _gemini_cache = None
                    _gemini_cache_created_at = None
                    _gemini_cached_file_paths = []
                    _append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} \u2014 falling back to inline system_instruction"})
            kwargs: dict[str, Any] = {"model": _model, "config": chat_config}
            if old_history:
                kwargs["history"] = old_history
            if _gemini_client:
                _gemini_chat = _gemini_client.chats.create(**kwargs)
            _gemini_cache_md_hash = current_md_hash
            # Only inject the prior discussion on a truly fresh session
            # (rebuilt sessions already carry it via old_history).
            if discussion_history and not old_history:
                _gemini_chat.send_message(f"[DISCUSSION HISTORY]\n\n{discussion_history}")
                _append_comms("OUT", "request", {"message": f"[HISTORY INJECTED] {len(discussion_history)} chars"})
        _append_comms("OUT", "request", {"message": f"[ctx {len(md_content)} + msg {len(user_message)}]"})
        payload: str | list[types.Part] = user_message
        all_text: list[str] = []
        _cumulative_tool_bytes = 0
        # Shrink old tool outputs already in history: drop stale FILES UPDATED
        # sections and truncate oversized outputs to the configured limit.
        if _gemini_chat and _get_gemini_history_list(_gemini_chat):
            for msg in _get_gemini_history_list(_gemini_chat):
                if msg.role == "user" and hasattr(msg, "parts"):
                    for p in msg.parts:
                        if hasattr(p, "function_response") and p.function_response and hasattr(p.function_response, "response"):
                            r = p.function_response.response
                            if isinstance(r, dict) and "output" in r:
                                val = r["output"]
                                if isinstance(val, str):
                                    if "[SYSTEM: FILES UPDATED]" in val:
                                        val = val.split("[SYSTEM: FILES UPDATED]")[0].strip()
                                    if _history_trunc_limit > 0 and len(val) > _history_trunc_limit:
                                        val = val[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS.]"
                                    r["output"] = val
        # Main request/tool loop: +2 rounds of headroom past MAX_TOOL_ROUNDS
        # so the model can still produce a final answer after the warning.
        for r_idx in range(MAX_TOOL_ROUNDS + 2):
            events.emit("request_start", payload={"provider": "gemini", "model": _model, "round": r_idx})
            # Shared config for this round
            td = _gemini_tool_declaration() if enable_tools else None
            config = types.GenerateContentConfig(
                tools=[td] if td else [],
                temperature=_temperature,
                max_output_tokens=_max_tokens,
            )
            if stream_callback:
                resp = _gemini_chat.send_message_stream(payload, config=config)
                txt_chunks: list[str] = []
                calls = []
                usage = {}
                reason = "STOP"
                final_resp = None
                # Accumulate text, function calls, finish reason and usage
                # across the streamed chunks.
                for chunk in resp:
                    if chunk.text:
                        txt_chunks.append(chunk.text)
                        stream_callback(chunk.text)
                    if chunk.candidates:
                        c = chunk.candidates[0]
                        if c.content and c.content.parts:
                            calls.extend([p.function_call for p in c.content.parts if p.function_call])
                        if hasattr(c, "finish_reason") and c.finish_reason:
                            reason = c.finish_reason.name
                    if chunk.usage_metadata:
                        usage = {
                            "input_tokens": chunk.usage_metadata.prompt_token_count,
                            "output_tokens": chunk.usage_metadata.candidates_token_count,
                            "total_tokens": chunk.usage_metadata.total_token_count,
                            "cache_read_input_tokens": getattr(chunk.usage_metadata, "cached_content_token_count", 0)
                        }
                    final_resp = chunk
                txt = "".join(txt_chunks)
                if txt: all_text.append(txt)
                # Final validation of response object for subsequent code
                # NOTE(review): final_resp stays None if the stream yields no
                # chunks; resp is not referenced again this round, so that is
                # currently harmless.
                resp = final_resp
                events.emit("response_received", payload={"provider": "gemini", "model": _model, "usage": usage, "round": r_idx})
            else:
                resp = _gemini_chat.send_message(payload, config=config)
                txt = resp.text or ""
                if txt: all_text.append(txt)
                calls = [p.function_call for c in resp.candidates if getattr(c, "content", None) for p in c.content.parts if p.function_call]
                usage = {
                    "input_tokens": getattr(resp.usage_metadata, "prompt_token_count", 0),
                    "output_tokens": getattr(resp.usage_metadata, "candidates_token_count", 0),
                    "total_tokens": getattr(resp.usage_metadata, "total_token_count", 0),
                    "cache_read_input_tokens": getattr(resp.usage_metadata, "cached_content_token_count", 0)
                }
                reason = resp.candidates[0].finish_reason.name if (resp.candidates and hasattr(resp.candidates[0], "finish_reason")) else "STOP"
                events.emit("response_received", payload={"provider": "gemini", "model": _model, "usage": usage, "round": r_idx})
            _append_comms("IN", "response", {"round": r_idx, "stop_reason": reason, "text": txt, "tool_calls": [{"name": c.name, "args": dict(c.args)} for c in calls], "usage": usage})
            # Trim oldest history entries in pairs when reported input tokens
            # pass 40% of the budget, until the estimate drops below 30%.
            total_in = usage.get("input_tokens", 0)
            if total_in > _GEMINI_MAX_INPUT_TOKENS * 0.4 and _gemini_chat and _get_gemini_history_list(_gemini_chat):
                hist = _get_gemini_history_list(_gemini_chat)
                dropped = 0
                while len(hist) > 4 and total_in > _GEMINI_MAX_INPUT_TOKENS * 0.3:
                    saved = 0
                    for _ in range(2):
                        if not hist: break
                        for p in hist[0].parts:
                            if hasattr(p, "text") and p.text:
                                saved += int(len(p.text) / _CHARS_PER_TOKEN)
                            elif hasattr(p, "function_response") and p.function_response:
                                r = getattr(p.function_response, "response", {})
                                if isinstance(r, dict):
                                    saved += int(len(str(r.get("output", ""))) / _CHARS_PER_TOKEN)
                        hist.pop(0)
                        dropped += 1
                    # Assume at least ~200 tokens reclaimed per pair so the loop
                    # always makes progress.
                    total_in -= max(saved, 200)
                if dropped > 0:
                    _append_comms("OUT", "request", {"message": f"[GEMINI HISTORY TRIMMED: dropped {dropped} old entries to stay within token budget]"})
            if not calls or r_idx > MAX_TOOL_ROUNDS: break
            f_resps: list[types.Part] = []
            log: list[dict[str, Any]] = []
            # Execute tools concurrently
            try:
                loop = asyncio.get_running_loop()
                results = asyncio.run_coroutine_threadsafe(
                    _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback),
                    loop
                ).result()
            except RuntimeError:
                results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback))
            for i, (name, call_id, out, _) in enumerate(results):
                # Check if this is the last tool to trigger file refresh
                if i == len(results) - 1:
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                    if r_idx == MAX_TOOL_ROUNDS: out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                out = _truncate_tool_output(out)
                _cumulative_tool_bytes += len(out)
                f_resps.append(types.Part(function_response=types.FunctionResponse(name=cast(str, name), response={"output": out})))
                # NOTE(review): logs the tool name under the "tool_use_id" key
                # (not the call id) — confirm that is intentional.
                log.append({"tool_use_id": name, "content": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                f_resps.append(types.Part(text=
                    f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
                ))
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
            _append_comms("OUT", "tool_result_send", {"results": log})
            # Tool results become the next round's payload.
            payload = f_resps
        return "\n\n".join(all_text) if all_text else "(No text returned)"
    except Exception as e: raise _classify_gemini_error(e) from e
def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
                     file_items: list[dict[str, Any]] | None = None,
                     discussion_history: str = "",
                     pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                     qa_callback: Optional[Callable[[str], str]] = None,
                     stream_callback: Optional[Callable[[str], None]] = None,
                     patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Send one user turn via the Gemini CLI adapter, looping over tool calls.

    Creates the module-level GeminiCliAdapter lazily, injects the prior
    discussion only on a brand-new CLI session, and on each round forwards
    the payload (user text, or JSON-encoded tool results) to the adapter.
    Returns the LAST round's text (unlike the API senders, which join all
    rounds); raises ProviderError on any failure.
    """
    global _gemini_cli_adapter
    sys.stderr.write(f"[DEBUG] _send_gemini_cli running in module {__name__}, adapter is {_gemini_cli_adapter}\n")
    sys.stderr.flush()
    try:
        if _gemini_cli_adapter is None:
            _gemini_cli_adapter = GeminiCliAdapter(binary_path="gemini")
        adapter = _gemini_cli_adapter
        mcp_client.configure(file_items or [], [base_dir])
        sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
        safety_settings = [{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_ONLY_HIGH'}]
        payload: Union[str, list[dict[str, Any]]] = user_message
        # Fresh CLI session (no session id yet): prepend the discussion history.
        if adapter.session_id is None:
            if discussion_history:
                payload = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"
        all_text: list[str] = []
        _cumulative_tool_bytes = 0
        # +2 rounds of headroom so a final answer can follow the max-rounds warning.
        for r_idx in range(MAX_TOOL_ROUNDS + 2):
            if adapter is None:
                break
            events.emit("request_start", payload={"provider": "gemini_cli", "model": _model, "round": r_idx})
            _append_comms("OUT", "request", {"message": f"[CLI] [round {r_idx}] [msg {len(payload)}]"})
            send_payload = payload
            # Tool-result payloads are lists; the CLI takes a JSON string.
            if isinstance(payload, list):
                send_payload = json.dumps(payload)
            resp_data = adapter.send(cast(str, send_payload), safety_settings=safety_settings, system_instruction=sys_instr, model=_model, stream_callback=stream_callback)
            cli_stderr = resp_data.get("stderr", "")
            if cli_stderr:
                sys.stderr.write(f"\n--- Gemini CLI stderr ---\n{cli_stderr}\n-------------------------\n")
                sys.stderr.flush()
            txt = cast(str, resp_data.get("text", ""))
            if txt: all_text.append(txt)
            calls = cast(List[dict[str, Any]], resp_data.get("tool_calls", []))
            usage = adapter.last_usage or {}
            latency = adapter.last_latency
            events.emit("response_received", payload={"provider": "gemini_cli", "model": _model, "usage": usage, "latency": latency, "round": r_idx})
            log_calls: list[dict[str, Any]] = []
            for c in calls:
                log_calls.append({"name": c.get("name"), "args": c.get("args"), "id": c.get("id")})
            _append_comms("IN", "response", {
                "round": r_idx,
                "stop_reason": "TOOL_USE" if calls else "STOP",
                "text": txt,
                "tool_calls": log_calls,
                "usage": usage
            })
            # comms_log_callback is a module-level hook defined elsewhere in
            # this file; intermediate text alongside tool calls is surfaced
            # through it so it is not lost (only the last round's text is
            # returned below).
            if txt and calls and comms_log_callback:
                comms_log_callback({
                    "ts": project_manager.now_ts(),
                    "direction": "IN",
                    "kind": "history_add",
                    "payload": {
                        "role": "AI",
                        "content": txt
                    }
                })
            if not calls or r_idx > MAX_TOOL_ROUNDS:
                break
            # Execute tools concurrently
            try:
                loop = asyncio.get_running_loop()
                results = asyncio.run_coroutine_threadsafe(
                    _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback),
                    loop
                ).result()
            except RuntimeError:
                results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback))
            tool_results_for_cli: list[dict[str, Any]] = []
            for i, (name, call_id, out, _) in enumerate(results):
                # Check if this is the last tool to trigger file refresh
                if i == len(results) - 1:
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                    if r_idx == MAX_TOOL_ROUNDS:
                        out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                out = _truncate_tool_output(out)
                _cumulative_tool_bytes += len(out)
                tool_results_for_cli.append({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "name": name,
                    "content": out
                })
                _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
            payload = tool_results_for_cli
            # NOTE(review): unlike the other senders, exceeding the budget here
            # only logs a warning — no message is sent to the model.
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
        final_text = all_text[-1] if all_text else "(No text returned)"
        return final_text
    except Exception as e:
        raise ProviderError("unknown", "gemini_cli", e)
_CHARS_PER_TOKEN: float = 3.5
_ANTHROPIC_MAX_PROMPT_TOKENS: int = 180_000
_GEMINI_MAX_INPUT_TOKENS: int = 900_000
_FILE_REFRESH_MARKER: str = "[FILES UPDATED"
def _estimate_message_tokens(msg: dict[str, Any]) -> int:
cached = msg.get("_est_tokens")
if cached is not None:
return cast(int, cached)
total_chars = 0
content = msg.get("content", "")
if isinstance(content, str):
total_chars += len(content)
elif isinstance(content, list):
for block in content:
if isinstance(block, dict):
text = block.get("text", "") or block.get("content", "")
if isinstance(text, str):
total_chars += len(text)
inp = block.get("input")
if isinstance(inp, dict):
import json as _json
total_chars += len(_json.dumps(inp, ensure_ascii=False))
elif isinstance(block, str):
total_chars += len(block)
est = max(1, int(total_chars / _CHARS_PER_TOKEN))
msg["_est_tokens"] = est
return est
def _invalidate_token_estimate(msg: dict[str, Any]) -> None:
msg.pop("_est_tokens", None)
def _estimate_prompt_tokens(system_blocks: list[dict[str, Any]], history: list[dict[str, Any]]) -> int:
    """Estimate the full prompt size: system blocks + flat overhead + history."""
    system_total = sum(
        max(1, int(len(cast(str, block.get("text", ""))) / _CHARS_PER_TOKEN))
        for block in system_blocks
    )
    history_total = sum(_estimate_message_tokens(msg) for msg in history)
    # Flat 2500-token allowance for request overhead not counted above.
    return system_total + 2500 + history_total
def _strip_stale_file_refreshes(history: list[dict[str, Any]]) -> None:
    """Remove outdated file-refresh text blocks from older user turns.

    The most recent user message keeps its "[FILES UPDATED" block (it is
    still current); every earlier user message with list content has text
    blocks starting with the marker filtered out in place, and its cached
    token estimate is invalidated when anything was removed.
    """
    if len(history) < 2:
        return
    newest_user = next(
        (i for i in range(len(history) - 1, -1, -1) if history[i].get("role") == "user"),
        -1,
    )
    for idx, msg in enumerate(history):
        if msg.get("role") != "user" or idx == newest_user:
            continue
        content = msg.get("content")
        if not isinstance(content, list):
            continue
        kept = [
            block for block in content
            if not (
                isinstance(block, dict)
                and block.get("type") == "text"
                and cast(str, block.get("text", "")).startswith(_FILE_REFRESH_MARKER)
            )
        ]
        if len(kept) < len(content):
            msg["content"] = kept
            _invalidate_token_estimate(msg)
def _trim_anthropic_history(system_blocks: list[dict[str, Any]], history: list[dict[str, Any]]) -> int:
    """Drop the oldest history turns until the estimated prompt fits the budget.

    First strips stale [FILES UPDATED] blocks, then, while the estimate
    exceeds _ANTHROPIC_MAX_PROMPT_TOKENS, removes messages starting at
    index 1 (index 0 is never removed). Assistant/user pairs are removed
    together, including any immediately-following assistant/tool_result
    exchanges, so no tool_result is left orphaned from its tool_use.
    Returns the number of messages dropped.
    """
    _strip_stale_file_refreshes(history)
    est = _estimate_prompt_tokens(system_blocks, history)
    if est <= _ANTHROPIC_MAX_PROMPT_TOKENS:
        return 0
    dropped = 0
    # Always keep at least 3 messages.
    while len(history) > 3 and est > _ANTHROPIC_MAX_PROMPT_TOKENS:
        if history[1].get("role") == "assistant" and len(history) > 2 and history[2].get("role") == "user":
            removed_asst = history.pop(1)
            removed_user = history.pop(1)
            dropped += 2
            est -= _estimate_message_tokens(removed_asst)
            est -= _estimate_message_tokens(removed_user)
            # Chained tool exchanges that followed the removed pair must go
            # too: a user turn whose first block is a tool_result belongs to
            # the assistant turn just removed.
            while len(history) > 2 and history[1].get("role") == "assistant" and history[2].get("role") == "user":
                content = history[2].get("content", [])
                if isinstance(content, list) and content and isinstance(content[0], dict) and content[0].get("type") == "tool_result":
                    r_a = history.pop(1)
                    r_u = history.pop(1)
                    dropped += 2
                    est -= _estimate_message_tokens(r_a)
                    est -= _estimate_message_tokens(r_u)
                else:
                    break
        else:
            # Not a clean assistant/user pair: drop the single oldest message.
            removed = history.pop(1)
            dropped += 1
            est -= _estimate_message_tokens(removed)
    return dropped
def _ensure_anthropic_client() -> None:
    """Instantiate the shared Anthropic client on first use (idempotent)."""
    global _anthropic_client
    if _anthropic_client is not None:
        return
    creds = _load_credentials()
    # Every request from this client opts in to the prompt-caching beta.
    _anthropic_client = anthropic.Anthropic(
        api_key=creds["anthropic"]["api_key"],
        default_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
    )
def _chunk_text(text: str, chunk_size: int) -> list[str]:
return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
def _build_chunked_context_blocks(md_content: str) -> list[dict[str, Any]]:
    """Split the context into Anthropic text blocks, marking the final chunk.

    Returns one ``{"type": "text", ...}`` block per chunk; the last block
    carries an ephemeral cache_control marker. Empty input yields [].
    """
    blocks: list[dict[str, Any]] = [
        {"type": "text", "text": chunk}
        for chunk in _chunk_text(md_content, _ANTHROPIC_CHUNK_SIZE)
    ]
    if blocks:
        blocks[-1]["cache_control"] = {"type": "ephemeral"}
    return blocks
def _strip_cache_controls(history: list[dict[str, Any]]) -> None:
for msg in history:
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
block.pop("cache_control", None)
def _add_history_cache_breakpoint(history: list[dict[str, Any]]) -> None:
user_indices = [i for i, m in enumerate(history) if m.get("role") == "user"]
if len(user_indices) < 2:
return
target_idx = user_indices[-2]
content = history[target_idx].get("content")
if isinstance(content, list) and content:
last_block = content[-1]
if isinstance(last_block, dict):
last_block["cache_control"] = {"type": "ephemeral"}
elif isinstance(content, str):
history[target_idx]["content"] = [
{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
]
def _repair_anthropic_history(history: list[dict[str, Any]]) -> None:
if not history:
return
last = history[-1]
if last.get("role") != "assistant":
return
content = last.get("content", [])
tool_use_ids: list[str] = []
for block in content:
if isinstance(block, dict):
if block.get("type") == "tool_use":
tool_use_ids.append(cast(str, block["id"]))
if not tool_use_ids:
return
history.append({
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": tid,
"content": "Tool call was not completed (session interrupted).",
}
for tid in tool_use_ids
],
})
def _send_anthropic(md_content: str, user_message: str, base_dir: str,
                    file_items: list[dict[str, Any]] | None = None,
                    discussion_history: str = "",
                    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                    qa_callback: Optional[Callable[[str], str]] = None,
                    stream_callback: Optional[Callable[[str], None]] = None,
                    patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Send one user turn to Anthropic and drive the tool-use loop to completion.

    Maintains the module-level _anthropic_history: prior tool outputs are
    truncated, stale cache markers stripped, dangling tool_use blocks
    repaired, and the oldest turns dropped to stay under the prompt-token
    budget. Each round, tool_use blocks are executed concurrently and fed
    back as tool_result turns until the model stops calling tools or the
    round budget is exhausted.

    Bug fix vs. previous revision: ``patch_callback`` was referenced in the
    asyncio.run fallback without being a parameter (NameError); it is now a
    backward-compatible keyword parameter passed on both execution paths,
    matching the other provider senders.

    Returns all rounds' text joined by blank lines; raises a classified
    ProviderError on failure.
    """
    try:
        _ensure_anthropic_client()
        mcp_client.configure(file_items or [], [base_dir])
        stable_prompt = _get_combined_system_prompt()
        stable_blocks: list[dict[str, Any]] = [{"type": "text", "text": stable_prompt, "cache_control": {"type": "ephemeral"}}]
        context_text = f"\n\n<context>\n{md_content}\n</context>"
        context_blocks = _build_chunked_context_blocks(context_text)
        system_blocks = stable_blocks + context_blocks
        # Inject the prior discussion only on the very first turn of a session.
        if discussion_history and not _anthropic_history:
            user_content: list[dict[str, Any]] = [{"type": "text", "text": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"}]
        else:
            user_content = [{"type": "text", "text": user_message}]
        # Truncate oversized tool outputs already stored in history.
        for msg in _anthropic_history:
            if msg.get("role") == "user" and isinstance(msg.get("content"), list):
                modified = False
                for block in cast(List[dict[str, Any]], msg["content"]):
                    if isinstance(block, dict) and block.get("type") == "tool_result":
                        t_content = block.get("content", "")
                        if _history_trunc_limit > 0 and isinstance(t_content, str) and len(t_content) > _history_trunc_limit:
                            block["content"] = t_content[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS. Original output was too large.]"
                            modified = True
                if modified:
                    _invalidate_token_estimate(msg)
        _strip_cache_controls(_anthropic_history)
        _repair_anthropic_history(_anthropic_history)
        _anthropic_history.append({"role": "user", "content": user_content})
        _add_history_cache_breakpoint(_anthropic_history)
        n_chunks = len(system_blocks)
        _append_comms("OUT", "request", {
            "message": (
                f"[system {n_chunks} chunk(s), {len(md_content)} chars context] "
                f"{user_message[:200]}{'...' if len(user_message) > 200 else ''}"
            ),
        })
        all_text_parts: list[str] = []
        _cumulative_tool_bytes = 0
        def _strip_private_keys(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
            # Drop bookkeeping keys (e.g. _est_tokens) before sending to the API.
            return [{k: v for k, v in m.items() if not k.startswith("_")} for m in history]
        # +2 rounds of headroom so a final answer can follow the max-rounds warning.
        for round_idx in range(MAX_TOOL_ROUNDS + 2):
            response: Any = None
            dropped = _trim_anthropic_history(system_blocks, _anthropic_history)
            if dropped > 0:
                est_tokens = _estimate_prompt_tokens(system_blocks, _anthropic_history)
                _append_comms("OUT", "request", {
                    "message": (
                        f"[HISTORY TRIMMED: dropped {dropped} old messages to fit token budget. "
                        f"Estimated {est_tokens} tokens remaining. {len(_anthropic_history)} messages in history.]"
                    ),
                })
            events.emit("request_start", payload={"provider": "anthropic", "model": _model, "round": round_idx})
            assert _anthropic_client is not None
            if stream_callback:
                with _anthropic_client.messages.stream(
                    model=_model,
                    max_tokens=_max_tokens,
                    temperature=_temperature,
                    system=cast(Iterable[anthropic.types.TextBlockParam], system_blocks),
                    tools=cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()),
                    messages=cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)),
                ) as stream:
                    for event in stream:
                        if isinstance(event, anthropic.types.ContentBlockDeltaEvent) and event.delta.type == "text_delta":
                            stream_callback(event.delta.text)
                    response = stream.get_final_message()
            else:
                response = _anthropic_client.messages.create(
                    model=_model,
                    max_tokens=_max_tokens,
                    temperature=_temperature,
                    system=cast(Iterable[anthropic.types.TextBlockParam], system_blocks),
                    tools=cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()),
                    messages=cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)),
                )
            serialised_content = [_content_block_to_dict(b) for b in response.content]
            _anthropic_history.append({
                "role": "assistant",
                "content": serialised_content,
            })
            text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text]
            if text_blocks:
                all_text_parts.append("\n".join(text_blocks))
            tool_use_blocks = [
                {"id": getattr(b, "id"), "name": getattr(b, "name"), "input": getattr(b, "input")}
                for b in response.content
                if getattr(b, "type", None) == "tool_use"
            ]
            usage_dict: dict[str, Any] = {}
            if response.usage:
                usage_dict["input_tokens"] = response.usage.input_tokens
                usage_dict["output_tokens"] = response.usage.output_tokens
                cache_creation = getattr(response.usage, "cache_creation_input_tokens", None)
                cache_read = getattr(response.usage, "cache_read_input_tokens", None)
                if cache_creation is not None:
                    usage_dict["cache_creation_input_tokens"] = cache_creation
                if cache_read is not None:
                    usage_dict["cache_read_input_tokens"] = cache_read
            events.emit("response_received", payload={"provider": "anthropic", "model": _model, "usage": usage_dict, "round": round_idx})
            _append_comms("IN", "response", {
                "round": round_idx,
                "stop_reason": response.stop_reason,
                "text": "\n".join(text_blocks),
                "tool_calls": tool_use_blocks,
                "usage": usage_dict,
            })
            if response.stop_reason != "tool_use" or not tool_use_blocks:
                break
            if round_idx > MAX_TOOL_ROUNDS:
                break
            # Execute tools concurrently (patch_callback now forwarded on both
            # paths — previously it was omitted here and undefined below).
            try:
                loop = asyncio.get_running_loop()
                results = asyncio.run_coroutine_threadsafe(
                    _execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback),
                    loop
                ).result()
            except RuntimeError:
                results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback))
            tool_results: list[dict[str, Any]] = []
            for i, (name, call_id, out, _) in enumerate(results):
                truncated = _truncate_tool_output(out)
                _cumulative_tool_bytes += len(truncated)
                tool_results.append({
                    "type": "tool_result",
                    "tool_use_id": call_id,
                    "content": truncated,
                })
                _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                tool_results.append({
                    "type": "text",
                    "text": f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
                })
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
            # Re-read any registered files and surface their new contents so
            # the model does not waste tool calls re-reading them.
            if file_items:
                file_items, changed = _reread_file_items(file_items)
                refreshed_ctx = _build_file_diff_text(changed)
                if refreshed_ctx:
                    tool_results.append({
                        "type": "text",
                        "text": (
                            "[FILES UPDATED \u2014 current contents below. "
                            "Do NOT re-read these files with PowerShell.]\n\n"
                            + refreshed_ctx
                        ),
                    })
            if round_idx == MAX_TOOL_ROUNDS:
                tool_results.append({
                    "type": "text",
                    "text": "SYSTEM WARNING: MAX TOOL ROUNDS REACHED. YOU MUST PROVIDE YOUR FINAL ANSWER NOW WITHOUT CALLING ANY MORE TOOLS."
                })
            _anthropic_history.append({
                "role": "user",
                "content": tool_results,
            })
            _append_comms("OUT", "tool_result_send", {
                "results": [
                    {"tool_use_id": r["tool_use_id"], "content": r["content"]}
                    for r in tool_results if r.get("type") == "tool_result"
                ],
            })
        final_text = "\n\n".join(all_text_parts)
        return final_text if final_text.strip() else "(No text returned by the model)"
    except ProviderError:
        raise
    except Exception as exc:
        raise _classify_anthropic_error(exc) from exc
def _ensure_deepseek_client() -> None:
    """Warm the credential cache for DeepSeek on first use.

    DeepSeek is called through raw ``requests`` rather than an SDK client
    (see _send_deepseek), so there is no client object to build here —
    only the credentials are loaded. (The redundant trailing ``pass``
    statement was removed.)
    """
    global _deepseek_client
    if _deepseek_client is None:
        _load_credentials()
def _ensure_minimax_client() -> None:
    """Instantiate the shared MiniMax (OpenAI-compatible) client on first use."""
    global _minimax_client
    if _minimax_client is not None:
        return
    from openai import OpenAI
    api_key = _load_credentials().get("minimax", {}).get("api_key")
    if not api_key:
        raise ValueError("MiniMax API key not found in credentials.toml")
    _minimax_client = OpenAI(api_key=api_key, base_url="https://api.minimax.chat/v1")
def _send_deepseek(md_content: str, user_message: str, base_dir: str,
                   file_items: list[dict[str, Any]] | None = None,
                   discussion_history: str = "",
                   stream: bool = False,
                   pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                   qa_callback: Optional[Callable[[str], str]] = None,
                   stream_callback: Optional[Callable[[str], None]] = None,
                   patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Send one user turn to DeepSeek (raw HTTP) and run the tool-call loop.

    Maintains the lock-guarded module-level _deepseek_history. For reasoner
    models the system prompt is folded into the first user message; for
    chat models a regular system message is used. Supports SSE streaming
    (aggregating content, reasoning_content and tool-call deltas) and
    non-streaming responses alike.

    Bug fix vs. previous revision: ``patch_callback`` was referenced when
    dispatching tool calls but was not a parameter of this function, so any
    tool round raised NameError; it is now a backward-compatible keyword
    parameter, matching the other provider senders.

    Returns all rounds' text joined by blank lines; raises a classified
    provider error on failure.
    """
    try:
        mcp_client.configure(file_items or [], [base_dir])
        creds = _load_credentials()
        api_key = creds.get("deepseek", {}).get("api_key")
        if not api_key:
            raise ValueError("DeepSeek API key not found in credentials.toml")
        api_url = "https://api.deepseek.com/chat/completions"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        is_reasoner = _model in ("deepseek-reasoner", "deepseek-r1")
        # Update history following Anthropic pattern
        with _deepseek_history_lock:
            if discussion_history and not _deepseek_history:
                user_content = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"
            else:
                user_content = user_message
            _deepseek_history.append({"role": "user", "content": user_content})
        all_text_parts: list[str] = []
        _cumulative_tool_bytes = 0
        # +2 rounds of headroom so a final answer can follow the warning.
        for round_idx in range(MAX_TOOL_ROUNDS + 2):
            current_api_messages: list[dict[str, Any]] = []
            # DeepSeek R1 (Reasoner) can be extremely strict about the 'system' role.
            # For maximum compatibility, we'll only use 'system' for non-reasoner models.
            if not is_reasoner:
                sys_msg = {"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}
                current_api_messages.append(sys_msg)
            with _deepseek_history_lock:
                for i, msg in enumerate(_deepseek_history):
                    # Create a clean copy of the message for the API
                    role = msg.get("role")
                    api_msg = {"role": role}
                    content = msg.get("content")
                    if i == 0 and is_reasoner:
                        # Prepend system instructions to the first user message for R1
                        content = f"System Instructions:\n{_get_combined_system_prompt()}\n\nContext:\n{md_content}\n\n---\n\n{content}"
                    if role == "assistant":
                        # OpenAI/DeepSeek: content MUST be a string if tool_calls is absent
                        # If tool_calls is present, content can be null
                        if msg.get("tool_calls"):
                            api_msg["content"] = content or None
                            api_msg["tool_calls"] = msg["tool_calls"]
                        else:
                            api_msg["content"] = content or ""
                        if msg.get("reasoning_content"):
                            api_msg["reasoning_content"] = msg["reasoning_content"]
                    elif role == "tool":
                        api_msg["content"] = content or ""
                        api_msg["tool_call_id"] = msg.get("tool_call_id")
                    else:
                        api_msg["content"] = content or ""
                    current_api_messages.append(api_msg)
            request_payload: dict[str, Any] = {
                "model": _model,
                "messages": current_api_messages,
                "stream": stream,
            }
            if stream:
                request_payload["stream_options"] = {"include_usage": True}
            if not is_reasoner:
                request_payload["temperature"] = _temperature
                # DeepSeek max_tokens is for the output, clamp to 8192 which is their hard limit for V3/Chat
                request_payload["max_tokens"] = min(_max_tokens, 8192)
            tools = _get_deepseek_tools()
            if tools:
                request_payload["tools"] = tools
            events.emit("request_start", payload={"provider": "deepseek", "model": _model, "round": round_idx, "streaming": stream})
            try:
                response = requests.post(api_url, headers=headers, json=request_payload, timeout=120, stream=stream)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                raise _classify_deepseek_error(e) from e
            assistant_text = ""
            tool_calls_raw = []
            reasoning_content = ""
            finish_reason = "stop"
            usage = {}
            if stream:
                # SSE parse: aggregate text/reasoning deltas and reassemble
                # tool-call fragments by their stream index.
                aggregated_content = ""
                aggregated_tool_calls: list[dict[str, Any]] = []
                aggregated_reasoning = ""
                current_usage: dict[str, Any] = {}
                final_finish_reason = "stop"
                for line in response.iter_lines():
                    if not line:
                        continue
                    decoded = line.decode('utf-8')
                    if decoded.startswith('data: '):
                        chunk_str = decoded[len('data: '):]
                        if chunk_str.strip() == '[DONE]':
                            continue
                        try:
                            chunk = json.loads(chunk_str)
                            if not chunk.get("choices"):
                                if chunk.get("usage"):
                                    current_usage = cast(dict[str, Any], chunk["usage"])
                                continue
                            delta = cast(dict[str, Any], chunk.get("choices", [{}])[0].get("delta", {}))
                            if delta.get("content"):
                                content_chunk = cast(str, delta["content"])
                                aggregated_content += content_chunk
                                if stream_callback:
                                    stream_callback(content_chunk)
                            if delta.get("reasoning_content"):
                                aggregated_reasoning += cast(str, delta["reasoning_content"])
                            if delta.get("tool_calls"):
                                for tc_delta in cast(List[dict[str, Any]], delta["tool_calls"]):
                                    idx = cast(int, tc_delta.get("index", 0))
                                    while len(aggregated_tool_calls) <= idx:
                                        aggregated_tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
                                    target = aggregated_tool_calls[idx]
                                    if tc_delta.get("id"):
                                        target["id"] = cast(str, tc_delta["id"])
                                    if tc_delta.get("function", {}).get("name"):
                                        target["function"]["name"] += cast(str, tc_delta["function"]["name"])
                                    if tc_delta.get("function", {}).get("arguments"):
                                        target["function"]["arguments"] += cast(str, tc_delta["function"]["arguments"])
                            if chunk.get("choices", [{}])[0].get("finish_reason"):
                                final_finish_reason = cast(str, chunk["choices"][0]["finish_reason"])
                            if chunk.get("usage"):
                                current_usage = cast(dict[str, Any], chunk["usage"])
                        except json.JSONDecodeError:
                            continue
                assistant_text = aggregated_content
                tool_calls_raw = aggregated_tool_calls
                reasoning_content = aggregated_reasoning
                finish_reason = final_finish_reason
                usage = current_usage
            else:
                response_data = response.json()
                choices = response_data.get("choices", [])
                if not choices:
                    _append_comms("IN", "response", {"round": round_idx, "text": "(No choices returned)", "usage": response_data.get("usage", {})})
                    break
                choice = choices[0]
                message = choice.get("message", {})
                assistant_text = message.get("content", "")
                tool_calls_raw = message.get("tool_calls", [])
                reasoning_content = message.get("reasoning_content", "")
                finish_reason = choice.get("finish_reason", "stop")
                usage = response_data.get("usage", {})
            # Surface reasoner output wrapped in <thinking> tags for the caller.
            thinking_tags = ""
            if reasoning_content:
                thinking_tags = f"<thinking>\n{reasoning_content}\n</thinking>\n"
            full_assistant_text = thinking_tags + assistant_text
            with _deepseek_history_lock:
                # DeepSeek/OpenAI: If tool_calls are present, content can be null but should usually be present
                msg_to_store: dict[str, Any] = {"role": "assistant", "content": assistant_text or None}
                if reasoning_content:
                    msg_to_store["reasoning_content"] = reasoning_content
                if tool_calls_raw:
                    msg_to_store["tool_calls"] = tool_calls_raw
                _deepseek_history.append(msg_to_store)
            if full_assistant_text:
                all_text_parts.append(full_assistant_text)
            _append_comms("IN", "response", {
                "round": round_idx,
                "stop_reason": finish_reason,
                "text": full_assistant_text,
                "tool_calls": tool_calls_raw,
                "usage": usage,
                "streaming": stream
            })
            if finish_reason != "tool_calls" and not tool_calls_raw:
                break
            if round_idx > MAX_TOOL_ROUNDS:
                break
            # Execute tools concurrently (patch_callback is now an actual
            # parameter — previously this line raised NameError).
            try:
                loop = asyncio.get_running_loop()
                results = asyncio.run_coroutine_threadsafe(
                    _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback),
                    loop
                ).result()
            except RuntimeError:
                results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback))
            tool_results_for_history: list[dict[str, Any]] = []
            for i, (name, call_id, out, _) in enumerate(results):
                # Append the file-refresh context (and max-rounds warning) to
                # the final tool result only.
                if i == len(results) - 1:
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                    if round_idx == MAX_TOOL_ROUNDS:
                        out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                truncated = _truncate_tool_output(out)
                _cumulative_tool_bytes += len(truncated)
                tool_results_for_history.append({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": truncated,
                })
                _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                tool_results_for_history.append({
                    "role": "user",
                    "content": f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
                })
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
            with _deepseek_history_lock:
                for tr in tool_results_for_history:
                    _deepseek_history.append(tr)
        return "\n\n".join(all_text_parts) if all_text_parts else "(No text returned)"
    except Exception as e:
        raise _classify_deepseek_error(e) from e
def _send_minimax(md_content: str, user_message: str, base_dir: str,
                  file_items: list[dict[str, Any]] | None = None,
                  discussion_history: str = "",
                  stream: bool = False,
                  pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                  qa_callback: Optional[Callable[[str], str]] = None,
                  stream_callback: Optional[Callable[[str], None]] = None,
                  patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Send a chat turn through MiniMax's OpenAI-compatible endpoint.

    Runs the multi-round tool loop: each round rebuilds the API message list
    from the rolling _minimax_history, issues a (optionally streaming)
    completion, executes any requested tool calls, and appends the results
    to history until the model stops calling tools or MAX_TOOL_ROUNDS is
    exceeded.

    Note: `patch_callback` is now declared as a parameter. The body already
    forwarded it to _execute_tool_calls_concurrently, and send() passes it
    as the tenth positional argument, but the old signature omitted it —
    which made every call from send() raise TypeError.

    Returns the concatenation of all assistant text parts, or a placeholder
    when the model produced no text. Errors are re-raised after
    classification via _classify_minimax_error.
    """
    try:
        mcp_client.configure(file_items or [], [base_dir])
        creds = _load_credentials()
        api_key = creds.get("minimax", {}).get("api_key")
        if not api_key:
            raise ValueError("MiniMax API key not found in credentials.toml")
        from openai import OpenAI
        client = OpenAI(api_key=api_key, base_url="https://api.minimax.io/v1")
        with _minimax_history_lock:
            # Prior discussion is injected only on the very first turn.
            if discussion_history and not _minimax_history:
                user_content = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"
            else:
                user_content = user_message
            _minimax_history.append({"role": "user", "content": user_content})
        all_text_parts: list[str] = []
        _cumulative_tool_bytes = 0
        for round_idx in range(MAX_TOOL_ROUNDS + 2):
            # Rebuild the full API message list from history every round.
            current_api_messages: list[dict[str, Any]] = []
            sys_msg = {"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}
            current_api_messages.append(sys_msg)
            with _minimax_history_lock:
                for msg in _minimax_history:
                    role = msg.get("role")
                    api_msg = {"role": role}
                    content = msg.get("content")
                    if role == "assistant":
                        if msg.get("tool_calls"):
                            # Assistant messages that carry tool calls may have null content.
                            api_msg["content"] = content or None
                            api_msg["tool_calls"] = msg["tool_calls"]
                        else:
                            api_msg["content"] = content or ""
                    elif role == "tool":
                        api_msg["content"] = content or ""
                        api_msg["tool_call_id"] = msg.get("tool_call_id")
                    else:
                        api_msg["content"] = content or ""
                    current_api_messages.append(api_msg)
            request_payload: dict[str, Any] = {
                "model": _model,
                "messages": current_api_messages,
                "stream": stream,
                "extra_body": {"reasoning_split": True},
            }
            if stream:
                request_payload["stream_options"] = {"include_usage": True}
            # MiniMax: fixed temperature 1.0; output capped at 8192 tokens.
            request_payload["temperature"] = 1.0
            request_payload["max_tokens"] = min(_max_tokens, 8192)
            tools = _get_deepseek_tools()
            if tools:
                request_payload["tools"] = tools
            events.emit("request_start", payload={"provider": "minimax", "model": _model, "round": round_idx, "streaming": stream})
            try:
                response = client.chat.completions.create(**request_payload, timeout=120)
            except Exception as e:
                raise _classify_minimax_error(e) from e
            assistant_text = ""
            tool_calls_raw = []
            reasoning_content = ""
            finish_reason = "stop"
            usage = {}
            if stream:
                aggregated_content = ""
                aggregated_tool_calls: list[dict[str, Any]] = []
                aggregated_reasoning = ""
                current_usage: dict[str, Any] = {}
                final_finish_reason = "stop"
                for chunk in response:
                    if not chunk.choices:
                        # Usage-only chunk (emitted due to include_usage).
                        if chunk.usage:
                            current_usage = chunk.usage.model_dump()
                        continue
                    delta = chunk.choices[0].delta
                    if delta.content:
                        content_chunk = delta.content
                        aggregated_content += content_chunk
                        if stream_callback:
                            stream_callback(content_chunk)
                    if hasattr(delta, "reasoning_details") and delta.reasoning_details:
                        for detail in delta.reasoning_details:
                            if "text" in detail:
                                aggregated_reasoning += detail["text"]
                    if delta.tool_calls:
                        # Tool-call deltas arrive fragmented; merge by index.
                        for tc_delta in delta.tool_calls:
                            idx = tc_delta.index
                            while len(aggregated_tool_calls) <= idx:
                                aggregated_tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
                            target = aggregated_tool_calls[idx]
                            if tc_delta.id:
                                target["id"] = tc_delta.id
                            if tc_delta.function and tc_delta.function.name:
                                target["function"]["name"] += tc_delta.function.name
                            if tc_delta.function and tc_delta.function.arguments:
                                target["function"]["arguments"] += tc_delta.function.arguments
                    if chunk.choices[0].finish_reason:
                        final_finish_reason = chunk.choices[0].finish_reason
                    if chunk.usage:
                        current_usage = chunk.usage.model_dump()
                assistant_text = aggregated_content
                tool_calls_raw = aggregated_tool_calls
                reasoning_content = aggregated_reasoning
                finish_reason = final_finish_reason
                usage = current_usage
            else:
                choice = response.choices[0]
                message = choice.message
                assistant_text = message.content or ""
                tool_calls_raw = message.tool_calls or []
                if hasattr(message, "reasoning_details") and message.reasoning_details:
                    reasoning_content = message.reasoning_details[0].get("text", "") if message.reasoning_details else ""
                finish_reason = choice.finish_reason or "stop"
                usage = response.usage.model_dump() if response.usage else {}
            thinking_tags = ""
            if reasoning_content:
                thinking_tags = f"<thinking>\n{reasoning_content}\n</thinking>\n"
            full_assistant_text = thinking_tags + assistant_text
            with _minimax_history_lock:
                msg_to_store: dict[str, Any] = {"role": "assistant", "content": assistant_text or None}
                if reasoning_content:
                    msg_to_store["reasoning_content"] = reasoning_content
                if tool_calls_raw:
                    msg_to_store["tool_calls"] = tool_calls_raw
                _minimax_history.append(msg_to_store)
            if full_assistant_text:
                all_text_parts.append(full_assistant_text)
            _append_comms("IN", "response", {
                "round": round_idx,
                "stop_reason": finish_reason,
                "text": full_assistant_text,
                "tool_calls": tool_calls_raw,
                "usage": usage,
                "streaming": stream
            })
            if finish_reason != "tool_calls" and not tool_calls_raw:
                break
            if round_idx > MAX_TOOL_ROUNDS:
                break
            # Execute tools on the host event loop if one is running,
            # otherwise spin up a temporary loop.
            try:
                loop = asyncio.get_running_loop()
                results = asyncio.run_coroutine_threadsafe(
                    _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax", patch_callback),
                    loop
                ).result()
            except RuntimeError:
                results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax", patch_callback))
            tool_results_for_history: list[dict[str, Any]] = []
            for i, (name, call_id, out, _) in enumerate(results):
                if i == len(results) - 1:
                    # Attach on-disk file changes to the last tool result only.
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                    if round_idx == MAX_TOOL_ROUNDS:
                        out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                truncated = _truncate_tool_output(out)
                _cumulative_tool_bytes += len(truncated)
                tool_results_for_history.append({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": truncated,
                })
                _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                tool_results_for_history.append({
                    "role": "user",
                    "content": f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
                })
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
            with _minimax_history_lock:
                for tr in tool_results_for_history:
                    _minimax_history.append(tr)
        return "\n\n".join(all_text_parts) if all_text_parts else "(No text returned)"
    except Exception as e:
        raise _classify_minimax_error(e) from e
def run_tier4_analysis(stderr: str) -> str:
    """Summarize a PowerShell stderr dump via a small Gemini call.

    Returns "" for blank input or when no Gemini client can be created.
    Any API/runtime failure is folded into the returned string rather than
    raised, so callers may treat this as best-effort.
    """
    if not (stderr and stderr.strip()):
        return ""
    try:
        _ensure_gemini_client()
        if not _gemini_client:
            return ""
        qa_prompt = (
            "You are a Tier 4 QA Agent specializing in error analysis.\n"
            "Analyze the following stderr output from a PowerShell command:\n\n"
            f"```\n{stderr}\n```\n\n"
            "Provide a concise summary of the failure and suggest a fix in approximately 20 words."
        )
        result = _gemini_client.models.generate_content(
            model="gemini-2.5-flash-lite",
            contents=qa_prompt,
            config=types.GenerateContentConfig(
                temperature=0.0,
                max_output_tokens=150,
            ),
        )
        return result.text.strip() if result.text else ""
    except Exception as e:
        return f"[QA ANALYSIS FAILED] {e}"
def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]:
    """Best-effort: auto-generate a unified-diff patch for a failed command.

    Builds a small prompt context from (up to) the first five tracked files,
    then asks Tier 4 patch generation for a fix. Returns the patch text only
    when it looks like a real unified diff (contains ``---`` and ``+++``);
    returns None on any failure or implausible output.

    Fixes: drops the unused exception binding and the redundant local
    re-import of project_manager (already imported at module level).
    """
    try:
        file_items = project_manager.get_current_file_items()
        file_context = ""
        # Keep the prompt small: at most 5 files, 2000 chars each.
        for item in file_items[:5]:
            path = item.get("path", "")
            content = item.get("content", "")[:2000]
            file_context += f"\n\nFile: {path}\n```\n{content}\n```\n"
        patch = run_tier4_patch_generation(stderr, file_context)
        # Minimal sanity check that the model returned a unified diff.
        if patch and "---" in patch and "+++" in patch:
            return patch
        return None
    except Exception:
        # Strictly best-effort: never let QA tooling propagate errors.
        return None
def run_tier4_patch_generation(error: str, file_context: str) -> str:
    """Ask Gemini for a patch given an error message and file context.

    Returns "" for blank errors or when no Gemini client is available.
    Failures are reported inline in the returned string instead of raising.
    """
    if not (error and error.strip()):
        return ""
    try:
        _ensure_gemini_client()
        if not _gemini_client:
            return ""
        full_prompt = (
            f"{mma_prompts.TIER4_PATCH_PROMPT}\n\n"
            f"Error:\n```\n{error}\n```\n\n"
            f"File Context:\n```\n{file_context}\n```\n"
        )
        result = _gemini_client.models.generate_content(
            model="gemini-2.5-flash-lite",
            contents=full_prompt,
            config=types.GenerateContentConfig(
                temperature=0.0,
                max_output_tokens=2048,
            ),
        )
        return result.text.strip() if result.text else ""
    except Exception as e:
        return f"[PATCH GENERATION FAILED] {e}"
def get_token_stats(md_content: str) -> dict[str, Any]:
    """Estimate prompt-token usage of *md_content* for the active provider.

    Gemini-backed providers get an exact count from the count_tokens API;
    any failure (or a non-Gemini provider) falls back to the
    chars-per-token heuristic. Returns a stats dict enriched with the
    derived fields from _add_bleed_derived().

    Fixes: the "gemini" and "gemini_cli" branches were byte-for-byte
    duplicates — merged into one; the no-op ``global`` declaration
    (nothing was assigned) is removed.
    """
    total_tokens = 0
    if _provider in ("gemini", "gemini_cli"):
        try:
            _ensure_gemini_client()
            if _gemini_client:
                resp = _gemini_client.models.count_tokens(model=_model, contents=md_content)
                total_tokens = cast(int, resp.total_tokens)
        except Exception:
            # Best-effort: fall through to the heuristic below.
            pass
    if total_tokens == 0:
        # Heuristic fallback (also covers non-Gemini providers).
        total_tokens = max(1, int(len(md_content) / _CHARS_PER_TOKEN))
    limit = _GEMINI_MAX_INPUT_TOKENS if _provider in ["gemini", "gemini_cli"] else _ANTHROPIC_MAX_PROMPT_TOKENS
    if _provider == "deepseek":
        limit = 64000
    pct = (total_tokens / limit * 100) if limit > 0 else 0
    stats = {
        "total_tokens": total_tokens,
        "current": total_tokens,
        "limit": limit,
        "percentage": pct
    }
    return _add_bleed_derived(stats, sys_tok=total_tokens)
def send(
md_content: str,
user_message: str,
base_dir: str = ".",
file_items: list[dict[str, Any]] | None = None,
discussion_history: str = "",
stream: bool = False,
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
qa_callback: Optional[Callable[[str], str]] = None,
enable_tools: bool = True,
stream_callback: Optional[Callable[[str], None]] = None,
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None,
) -> str:
with _send_lock:
if _provider == "gemini":
return _send_gemini(
md_content, user_message, base_dir, file_items, discussion_history,
pre_tool_callback, qa_callback, enable_tools, stream_callback, patch_callback
)
elif _provider == "gemini_cli":
return _send_gemini_cli(
md_content, user_message, base_dir, file_items, discussion_history,
pre_tool_callback, qa_callback, stream_callback, patch_callback
)
elif _provider == "anthropic":
return _send_anthropic(
md_content, user_message, base_dir, file_items, discussion_history,
pre_tool_callback, qa_callback, stream_callback=stream_callback, patch_callback=patch_callback
)
elif _provider == "deepseek":
return _send_deepseek(
md_content, user_message, base_dir, file_items, discussion_history,
stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
)
elif _provider == "minimax":
return _send_minimax(
md_content, user_message, base_dir, file_items, discussion_history,
stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
)
else:
raise ValueError(f"Unknown provider: {_provider}")
def _add_bleed_derived(d: dict[str, Any], sys_tok: int = 0, tool_tok: int = 0) -> dict[str, Any]:
cur = d.get("current", 0)
lim = d.get("limit", 0)
d["estimated_prompt_tokens"] = cur
d["max_prompt_tokens"] = lim
d["utilization_pct"] = d.get("percentage", 0.0)
d["headroom"] = max(0, lim - cur)
d["would_trim"] = cur >= lim
d["sys_tokens"] = sys_tok
d["tool_tokens"] = tool_tok
d["history_tokens"] = max(0, cur - sys_tok - tool_tok)
return d
def _is_mutating_tool(name: str) -> bool:
    """True when *name* is a tool that modifies state (MCP mutating set, or the run tool)."""
    if name == TOOL_NAME:
        return True
    return name in mcp_client.MUTATING_TOOLS
def _confirm_and_run(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Optional[str]:
    """Route a shell script through the HITL approval hook when one exists.

    When no confirm_and_run_callback has been registered (headless mode),
    the script is executed directly via the PowerShell runner.
    """
    if not confirm_and_run_callback:
        # Headless default: execute without human approval.
        from src import shell_runner
        return shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback)
    return confirm_and_run_callback(script, base_dir, qa_callback, patch_callback)
def _estimate_openai_history_chars(history: list[dict[str, Any]]) -> int:
    """Sum visible character counts of an OpenAI-style message history.

    Counts plain string content, text blocks, and JSON-serialized tool
    inputs; used as the basis for a chars-per-token estimate. Extracted
    because the deepseek and minimax branches previously duplicated this
    loop verbatim.
    """
    import json as _json
    chars = 0
    for msg in history:
        content = msg.get("content", "")
        if isinstance(content, str):
            chars += len(content)
        elif isinstance(content, list):
            for block in content:
                if isinstance(block, dict):
                    text = block.get("text", "")
                    if isinstance(text, str):
                        chars += len(text)
                    inp = block.get("input")
                    if isinstance(inp, dict):
                        chars += len(_json.dumps(inp, ensure_ascii=False))
    return chars
def get_history_bleed_stats(md_content: Optional[str] = None) -> dict[str, Any]:
    """Report how much of the active provider's context window is consumed.

    Returns a dict with provider/limit/current/percentage plus the derived
    fields from _add_bleed_derived(). *md_content* (system context), when
    given, is included in the estimate. Gemini counts are exact via the
    count_tokens API; OpenAI-style providers use the chars-per-token
    heuristic. All API failures degrade to a zero-usage result instead of
    raising.
    """
    if _provider == "anthropic":
        with _anthropic_history_lock:
            history_snapshot = list(_anthropic_history)
        # System/context tokens estimated from character length.
        sys_tok = max(1, int(len(md_content) / _CHARS_PER_TOKEN)) if md_content else 0
        current_tokens = _estimate_prompt_tokens([], history_snapshot)
        if md_content:
            current_tokens += max(1, int(len(md_content) / _CHARS_PER_TOKEN))
        limit_tokens = _ANTHROPIC_MAX_PROMPT_TOKENS
        percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
        return _add_bleed_derived({
            "provider": "anthropic",
            "limit": limit_tokens,
            "current": current_tokens,
            "percentage": percentage,
        }, sys_tok=sys_tok, tool_tok=2500)  # tool_tok: fixed estimate for tool schemas
    elif _provider == "gemini":
        effective_limit = _history_trunc_limit if _history_trunc_limit > 0 else _GEMINI_MAX_INPUT_TOKENS
        if _gemini_chat:
            try:
                _ensure_gemini_client()
                if _gemini_client:
                    raw_history = list(_get_gemini_history_list(_gemini_chat))
                    history: list[types.Content] = []
                    for c in raw_history:
                        # Normalize roles: the SDK may report "assistant" or "model".
                        role = "model" if c.role in ["assistant", "model"] else "user"
                        history.append(types.Content(role=role, parts=c.parts))
                    if md_content:
                        # Prepend the system context so it is counted too.
                        history.insert(0, types.Content(role="user", parts=[types.Part(text=md_content)]))
                    if not history:
                        return _add_bleed_derived({
                            "provider": "gemini",
                            "limit": effective_limit,
                            "current": 0,
                            "percentage": 0,
                        })
                    resp = _gemini_client.models.count_tokens(
                        model=_model,
                        contents=history
                    )
                    current_tokens = cast(int, resp.total_tokens)
                    percentage = (current_tokens / effective_limit) * 100 if effective_limit > 0 else 0
                    return _add_bleed_derived({
                        "provider": "gemini",
                        "limit": effective_limit,
                        "current": current_tokens,
                        "percentage": percentage,
                    }, sys_tok=0, tool_tok=0)
            except Exception:
                # Fall through to the zero-usage fallback below.
                pass
        elif md_content:
            try:
                _ensure_gemini_client()
                if _gemini_client:
                    resp = _gemini_client.models.count_tokens(
                        model=_model,
                        contents=[types.Content(role="user", parts=[types.Part(text=md_content)])]
                    )
                    current_tokens = cast(int, resp.total_tokens)
                    percentage = (current_tokens / effective_limit) * 100 if effective_limit > 0 else 0
                    return _add_bleed_derived({
                        "provider": "gemini",
                        "limit": effective_limit,
                        "current": current_tokens,
                        "percentage": percentage,
                    })
            except Exception:
                pass
        return _add_bleed_derived({
            "provider": "gemini",
            "limit": effective_limit,
            "current": 0,
            "percentage": 0,
        })
    elif _provider == "gemini_cli":
        effective_limit = _history_trunc_limit if _history_trunc_limit > 0 else _GEMINI_MAX_INPUT_TOKENS
        limit_tokens = effective_limit
        current_tokens = 0
        # The CLI adapter tracks usage from its last response, if any.
        if _gemini_cli_adapter and _gemini_cli_adapter.last_usage:
            u = _gemini_cli_adapter.last_usage
            current_tokens = cast(int, u.get("input_tokens") or u.get("input", 0))
        percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
        return _add_bleed_derived({
            "provider": "gemini_cli",
            "limit": limit_tokens,
            "current": current_tokens,
            "percentage": percentage,
        })
    elif _provider == "deepseek":
        limit_tokens = 64000
        with _deepseek_history_lock:
            history_snapshot = list(_deepseek_history)
        current_chars = _estimate_openai_history_chars(history_snapshot)
        if md_content:
            current_chars += len(md_content)
        current_tokens = max(1, int(current_chars / _CHARS_PER_TOKEN))
        percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
        return _add_bleed_derived({
            "provider": "deepseek",
            "limit": limit_tokens,
            "current": current_tokens,
            "percentage": percentage,
        })
    elif _provider == "minimax":
        limit_tokens = 204800
        with _minimax_history_lock:
            history_snapshot = list(_minimax_history)
        current_chars = _estimate_openai_history_chars(history_snapshot)
        if md_content:
            current_chars += len(md_content)
        current_tokens = max(1, int(current_chars / _CHARS_PER_TOKEN))
        percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
        return _add_bleed_derived({
            "provider": "minimax",
            "limit": limit_tokens,
            "current": current_tokens,
            "percentage": percentage,
        })
    # Unknown provider: report an empty window.
    return _add_bleed_derived({
        "provider": _provider,
        "limit": 0,
        "current": 0,
        "percentage": 0,
    })