Files
manual_slop/ai_client.py
2026-02-22 11:32:54 -05:00

866 lines
36 KiB
Python

# ai_client.py
"""
Note(Gemini):
Acts as the unified interface for multiple LLM providers (Anthropic, Gemini).
Abstracts away the differences in how they handle tool schemas, history, and caching.
For Anthropic: aggressively manages the ~200k token limit by manually culling
stale [FILES UPDATED] entries and dropping the oldest message pairs.
For Gemini: injects the initial context directly into system_instruction
during chat creation to avoid massive history bloat.
"""
# ai_client.py
import tomllib
import json
import datetime
from pathlib import Path
import file_cache
import mcp_client
# Active provider and model; replaced together by set_provider().
_provider: str = "gemini"
_model: str = "gemini-2.5-flash"
# Generation settings; replaced together by set_model_params().
_temperature: float = 0.0
_max_tokens: int = 8192
# Character cap applied to tool outputs already stored in history before a
# new request is sent. The cap is only applied when the value is > 0.
_history_trunc_limit: int = 8000
def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000):
    """Replace the module-wide generation settings used by later requests.

    temp becomes the sampling temperature, max_tok the output-token cap and
    trunc_limit the character cap for tool outputs kept in history.
    """
    global _temperature, _max_tokens, _history_trunc_limit
    _temperature, _max_tokens, _history_trunc_limit = temp, max_tok, trunc_limit
# Lazily created provider handles and per-session conversation state.
# reset_session() sets all of these back to their initial values.
_gemini_client = None
_gemini_chat = None
_gemini_cache = None
_anthropic_client = None
# Anthropic conversation so far, as a list of {"role": ..., "content": ...} dicts.
_anthropic_history: list[dict] = []
# Injected by gui.py - called when AI wants to run a command.
# Signature: (script: str, base_dir: str) -> str | None
# A None return is reported back to the model as a user rejection (see _run_script).
confirm_and_run_callback = None
# Injected by gui.py - called whenever a comms entry is appended.
# Signature: (entry: dict) -> None
comms_log_callback = None
# Injected by gui.py - called whenever a tool call completes.
# Signature: (script: str, result: str) -> None
tool_log_callback = None
# Increased to allow thorough code exploration before forcing a summary.
# Both send loops iterate range(MAX_TOOL_ROUNDS + 2) and inject a
# "max rounds" notice when the round index equals this value.
MAX_TOOL_ROUNDS = 10
# Maximum characters per text chunk sent to Anthropic.
# Kept well under the ~200k token API limit.
_ANTHROPIC_CHUNK_SIZE = 120_000
_SYSTEM_PROMPT = (
"You are a helpful coding assistant with access to a PowerShell tool and MCP tools (file access: read_file, list_directory, search_files, get_file_summary, web access: web_search, fetch_url). "
"When asked to create or edit files, prefer targeted edits over full rewrites. "
"Always explain what you are doing before invoking the tool.\n\n"
"When writing or rewriting large files (especially those containing quotes, backticks, or special characters), "
"avoid python -c with inline strings. Instead: (1) write a .py helper script to disk using a PS here-string "
"(@'...'@ for literal content), (2) run it with `python <script>`, (3) delete the helper. "
"For small targeted edits, use PowerShell's (Get-Content) / .Replace() / Set-Content or Add-Content directly.\n\n"
"When making function calls using tools that accept array or object parameters "
"ensure those are structured using JSON. For example:\n"
"When you need to verify a change, rely on the exit code and stdout/stderr from the tool \u2014 "
"the user's context files are automatically refreshed after every tool call, so you do NOT "
"need to re-read files that are already provided in the <context> block."
)
_custom_system_prompt: str = ""
def set_custom_system_prompt(prompt: str):
    """Store the user-supplied prompt text.

    The stored text is appended to the built-in prompt by
    _get_combined_system_prompt() whenever it is not blank.
    """
    global _custom_system_prompt
    _custom_system_prompt = prompt
def _get_combined_system_prompt() -> str:
    """Return the built-in prompt, followed by the custom prompt if one is set.

    A custom prompt that is empty or whitespace-only is ignored.
    """
    if not _custom_system_prompt.strip():
        return _SYSTEM_PROMPT
    return _SYSTEM_PROMPT + "\n\n[USER SYSTEM PROMPT]\n" + _custom_system_prompt
# ------------------------------------------------------------------ comms log
# In-memory log of every entry recorded by _append_comms(); read through
# get_comms_log() and emptied by clear_comms_log().
_comms_log: list[dict] = []
# NOTE(review): not referenced in the visible part of this file; presumably a
# display clamp used by the GUI - confirm against callers.
COMMS_CLAMP_CHARS = 300
def _append_comms(direction: str, kind: str, payload: dict):
    """Record one communication event and forward it to the GUI listener.

    The entry carries a local HH:MM:SS timestamp, the given direction and
    kind, the currently selected provider and model, and the payload. It is
    appended to _comms_log and then passed to comms_log_callback when one
    has been registered.
    """
    stamp = datetime.datetime.now().strftime("%H:%M:%S")
    record = dict(
        ts=stamp,
        direction=direction,
        kind=kind,
        provider=_provider,
        model=_model,
        payload=payload,
    )
    _comms_log.append(record)
    if comms_log_callback is None:
        return
    comms_log_callback(record)
def get_comms_log() -> list[dict]:
    """Return a shallow copy of the comms log, so callers cannot resize it."""
    return _comms_log[:]
def clear_comms_log():
    """Remove every entry from the comms log in place."""
    del _comms_log[:]
def _load_credentials() -> dict:
    """Parse credentials.toml and return its contents as a dict.

    The file name is relative, so it is resolved against the current
    working directory. Errors from open() or the TOML parser propagate.
    """
    with open("credentials.toml", "rb") as handle:
        parsed = tomllib.load(handle)
    return parsed
# ------------------------------------------------------------------ provider errors
class ProviderError(Exception):
    """Provider failure tagged with a coarse category for display.

    Attributes:
        kind: category key ("quota", "rate_limit", "auth", "balance",
            "network" or "unknown" are the ones with a dedicated label).
        provider: provider name; shown upper-cased in ui_message().
        original: the wrapped exception; its str() is this error's message.
    """

    def __init__(self, kind: str, provider: str, original: Exception):
        super().__init__(str(original))
        self.kind = kind
        self.provider = provider
        self.original = original

    def ui_message(self) -> str:
        """Return a bracketed headline followed by the original error text."""
        headings = dict(
            quota="QUOTA EXHAUSTED",
            rate_limit="RATE LIMITED",
            auth="AUTH / API KEY ERROR",
            balance="BALANCE / BILLING ERROR",
            network="NETWORK / CONNECTION ERROR",
            unknown="API ERROR",
        )
        # Unrecognised kinds fall back to the generic heading.
        heading = headings[self.kind] if self.kind in headings else "API ERROR"
        return f"[{self.provider.upper()} {heading}]\n\n{self.original}"
def _classify_anthropic_error(exc: Exception) -> ProviderError:
    """Wrap exc in a ProviderError tagged with the best-matching category.

    Typed anthropic SDK exceptions are checked first. For a generic
    APIStatusError the HTTP status code and then keywords in the message
    decide the category. If the anthropic package cannot be imported, or
    nothing matches, the kind is "unknown".
    """
    def tagged(kind: str) -> ProviderError:
        return ProviderError(kind, "anthropic", exc)

    try:
        import anthropic

        # Checked in order; attributes are looked up lazily, one per test.
        typed_kinds = (
            ("RateLimitError", "rate_limit"),
            ("AuthenticationError", "auth"),
            ("PermissionDeniedError", "auth"),
            ("APIConnectionError", "network"),
        )
        for class_name, kind in typed_kinds:
            if isinstance(exc, getattr(anthropic, class_name)):
                return tagged(kind)

        if isinstance(exc, anthropic.APIStatusError):
            status = getattr(exc, "status_code", 0)
            message = str(exc).lower()
            if status == 429:
                return tagged("rate_limit")
            if status in (401, 403):
                return tagged("auth")
            if status == 402 or any(word in message for word in ("credit", "balance", "billing")):
                return tagged("balance")
            if any(word in message for word in ("quota", "limit", "exceeded")):
                return tagged("quota")
    except ImportError:
        pass
    return tagged("unknown")
def _classify_gemini_error(exc: Exception) -> ProviderError:
    """Wrap exc in a ProviderError tagged with the best-matching category.

    Typed google.api_core exceptions are checked first (skipped when that
    package cannot be imported). Otherwise the lower-cased message is
    searched for status codes and keywords. The order of the checks decides
    ties. The kind is "unknown" when nothing matches.
    """
    message = str(exc).lower()

    def tagged(kind: str) -> ProviderError:
        return ProviderError(kind, "gemini", exc)

    def mentions(*needles: str) -> bool:
        return any(needle in message for needle in needles)

    try:
        from google.api_core import exceptions as gac

        typed_kinds = (
            (("ResourceExhausted",), "quota"),
            (("TooManyRequests",), "rate_limit"),
            (("Unauthenticated", "PermissionDenied"), "auth"),
            (("ServiceUnavailable",), "network"),
        )
        for class_names, kind in typed_kinds:
            classes = tuple(getattr(gac, class_name) for class_name in class_names)
            if isinstance(exc, classes):
                return tagged(kind)
    except ImportError:
        pass

    if mentions("429", "quota", "resource exhausted"):
        return tagged("quota")
    if "rate" in message and "limit" in message:
        return tagged("rate_limit")
    if mentions("401", "403", "api key", "unauthenticated"):
        return tagged("auth")
    if mentions("402", "billing", "balance", "payment"):
        return tagged("balance")
    if mentions("connection", "timeout", "unreachable"):
        return tagged("network")
    return tagged("unknown")
# ------------------------------------------------------------------ provider setup
def set_provider(provider: str, model: str):
    """Select the provider and model used by later send() calls.

    Neither value is validated here; send() raises ValueError for a
    provider it does not know.
    """
    global _provider, _model
    _provider, _model = provider, model
def cleanup():
    """Called on application exit to prevent orphaned caches from billing.

    Deletes the remote Gemini cache if both a client and a cache exist.
    Any error from the delete call is swallowed. The module state itself
    is left untouched.
    """
    global _gemini_client, _gemini_cache
    if not (_gemini_client and _gemini_cache):
        return
    try:
        _gemini_client.caches.delete(name=_gemini_cache.name)
    except Exception:
        # Best effort: a failed delete must not raise out of cleanup().
        pass
def reset_session():
    """Forget all provider clients, chats, caches and conversation history.

    The remote Gemini cache (if any) is deleted on a best-effort basis
    first. Then every session global is reset: the Gemini client/chat/cache,
    the Anthropic client and history, and the memoised Anthropic tool list.
    Finally file_cache.reset_client() is called.
    """
    global _gemini_client, _gemini_chat, _gemini_cache
    global _anthropic_client, _anthropic_history
    global _CACHED_ANTHROPIC_TOOLS
    # cleanup() performs exactly the guarded, error-swallowing cache delete
    # this function used to duplicate inline.
    cleanup()
    _gemini_client = None
    _gemini_chat = None
    _gemini_cache = None
    _anthropic_client = None
    _anthropic_history = []
    _CACHED_ANTHROPIC_TOOLS = None
    file_cache.reset_client()
# ------------------------------------------------------------------ model listing
def list_models(provider: str) -> list[str]:
    """Return the sorted model names available for provider.

    provider is "gemini" or "anthropic"; any other value yields [].
    Credentials are read only on the Gemini path. The Anthropic helper
    loads its own, and an unknown provider needs none, so it no longer
    fails when credentials.toml is missing.

    Raises whatever _load_credentials() raises on the Gemini path, and the
    ProviderError raised by the provider-specific helper.
    """
    if provider == "gemini":
        creds = _load_credentials()
        return _list_gemini_models(creds["gemini"]["api_key"])
    if provider == "anthropic":
        return _list_anthropic_models()
    return []
def _list_gemini_models(api_key: str) -> list[str]:
    """Return the sorted Gemini model names visible to api_key.

    A leading "models/" is removed from each name, and only names that
    contain "gemini" (case-insensitive) are kept. Any failure is re-raised
    as the ProviderError built by _classify_gemini_error.
    """
    from google import genai
    try:
        client = genai.Client(api_key=api_key)
        short_names = (entry.name.removeprefix("models/") for entry in client.models.list())
        return sorted(name for name in short_names if "gemini" in name.lower())
    except Exception as exc:
        raise _classify_gemini_error(exc) from exc
def _list_anthropic_models() -> list[str]:
    """Return the sorted ids of all models the Anthropic API lists.

    Credentials are loaded inside the try block, so a credentials failure
    is also re-raised as the ProviderError built by
    _classify_anthropic_error.
    """
    import anthropic
    try:
        api_key = _load_credentials()["anthropic"]["api_key"]
        client = anthropic.Anthropic(api_key=api_key)
        return sorted(entry.id for entry in client.models.list())
    except Exception as exc:
        raise _classify_anthropic_error(exc) from exc
# ------------------------------------------------------------------ tool definition
# Name under which the PowerShell execution tool is declared to both
# providers and matched when dispatching tool calls.
TOOL_NAME = "run_powershell"
def _build_anthropic_tools() -> list[dict]:
    """Build the Anthropic tool list: every MCP tool, then run_powershell.

    Each MCP spec is converted by renaming its "parameters" schema to
    "input_schema". Only the last entry, the PowerShell tool, carries a
    cache_control marker.
    """
    tools = [
        {
            "name": spec["name"],
            "description": spec["description"],
            "input_schema": spec["parameters"],
        }
        for spec in mcp_client.MCP_TOOL_SPECS
    ]
    script_param = {
        "type": "string",
        "description": "The PowerShell script to execute."
    }
    tools.append({
        "name": TOOL_NAME,
        "description": (
            "Run a PowerShell script within the project base_dir. "
            "Use this to create, edit, rename, or delete files and directories. "
            "The working directory is set to base_dir automatically. "
            "Always prefer targeted edits over full rewrites where possible. "
            "stdout and stderr are returned to you as the result."
        ),
        "input_schema": {
            "type": "object",
            "properties": {"script": script_param},
            "required": ["script"]
        },
        "cache_control": {"type": "ephemeral"},
    })
    return tools
# Tool list built once, eagerly, when the module is imported.
# NOTE(review): nothing in the visible part of this file reads this name;
# requests use _get_anthropic_tools(). Confirm external use before removing.
_ANTHROPIC_TOOLS = _build_anthropic_tools()
# Memo for _get_anthropic_tools(); set back to None by reset_session().
_CACHED_ANTHROPIC_TOOLS = None
def _get_anthropic_tools() -> list[dict]:
    """Return the memoised Anthropic tool list, building it on first use.

    The memo is rebuilt after reset_session() has cleared it.
    """
    global _CACHED_ANTHROPIC_TOOLS
    tools = _CACHED_ANTHROPIC_TOOLS
    if tools is None:
        tools = _CACHED_ANTHROPIC_TOOLS = _build_anthropic_tools()
    return tools
def _gemini_tool_declaration():
    """Build the single Gemini Tool holding all function declarations.

    MCP tools come first, followed by the PowerShell tool. Every MCP
    parameter is declared as a STRING, whatever its spec says.
    NOTE(review): confirm no MCP tool needs a non-string parameter type.
    """
    from google.genai import types

    def string_schema(description: str):
        return types.Schema(type=types.Type.STRING, description=description)

    def object_schema(properties: dict, required: list):
        return types.Schema(type=types.Type.OBJECT, properties=properties, required=required)

    declarations = []
    # MCP file tools
    for spec in mcp_client.MCP_TOOL_SPECS:
        params = spec["parameters"]
        properties = {
            param_name: string_schema(param_def.get("description", ""))
            for param_name, param_def in params.get("properties", {}).items()
        }
        declarations.append(types.FunctionDeclaration(
            name=spec["name"],
            description=spec["description"],
            parameters=object_schema(properties, params.get("required", [])),
        ))

    # PowerShell tool
    declarations.append(types.FunctionDeclaration(
        name=TOOL_NAME,
        description=(
            "Run a PowerShell script within the project base_dir. "
            "Use this to create, edit, rename, or delete files and directories. "
            "The working directory is set to base_dir automatically. "
            "stdout and stderr are returned to you as the result."
        ),
        parameters=object_schema(
            {"script": string_schema("The PowerShell script to execute.")},
            ["script"],
        ),
    ))
    return types.Tool(function_declarations=declarations)
def _run_script(script: str, base_dir: str) -> str:
    """Run script through the GUI confirmation callback and return its output.

    Without a registered callback an error string is returned and nothing
    is logged. A None result from the callback is reported as a user
    rejection. The final output is also passed to tool_log_callback when
    one is registered.
    """
    if confirm_and_run_callback is None:
        return "ERROR: no confirmation handler registered"
    outcome = confirm_and_run_callback(script, base_dir)
    output = "USER REJECTED: command was not executed" if outcome is None else outcome
    if tool_log_callback is not None:
        tool_log_callback(script, output)
    return output
# ------------------------------------------------------------------ dynamic file context refresh
def _reread_file_items(file_items: list[dict]) -> list[dict]:
"""
Re-read every file in file_items from disk, returning a fresh list.
This is called after tool calls so the AI sees updated file contents.
"""
refreshed = []
for item in file_items:
path = item.get("path")
if path is None:
refreshed.append(item)
continue
from pathlib import Path as _P
p = _P(path) if not isinstance(path, _P) else path
try:
content = p.read_text(encoding="utf-8")
refreshed.append({**item, "content": content, "error": False})
except Exception as e:
refreshed.append({**item, "content": f"ERROR re-reading {p}: {e}", "error": True})
return refreshed
def _build_file_context_text(file_items: list[dict]) -> str:
"""
Build a compact text summary of all files from file_items, suitable for
injecting into a tool_result message so the AI sees current file contents.
"""
if not file_items:
return ""
parts = []
for item in file_items:
path = item.get("path") or item.get("entry", "unknown")
suffix = str(path).rsplit(".", 1)[-1] if "." in str(path) else "text"
content = item.get("content", "")
parts.append(f"### `{path}`\n\n```{suffix}\n{content}\n```")
return "\n\n---\n\n".join(parts)
# ------------------------------------------------------------------ content block serialisation
def _content_block_to_dict(block) -> dict:
"""
Convert an Anthropic SDK content block object to a plain dict.
This ensures history entries are always JSON-serialisable dicts,
not opaque SDK objects that may fail on re-serialisation.
"""
if isinstance(block, dict):
return block
if hasattr(block, "model_dump"):
return block.model_dump()
if hasattr(block, "to_dict"):
return block.to_dict()
# Fallback: manually construct based on type
block_type = getattr(block, "type", None)
if block_type == "text":
return {"type": "text", "text": block.text}
if block_type == "tool_use":
return {"type": "tool_use", "id": block.id, "name": block.name, "input": block.input}
return {"type": "text", "text": str(block)}
# ------------------------------------------------------------------ gemini
def _ensure_gemini_client():
    """Create the shared Gemini client on first use; later calls do nothing.

    The API key is read from the "gemini" table of the credentials file.
    """
    global _gemini_client
    if _gemini_client is not None:
        return
    from google import genai
    api_key = _load_credentials()["gemini"]["api_key"]
    _gemini_client = genai.Client(api_key=api_key)
def _send_gemini(static_md: str, dynamic_md: str, user_message: str, base_dir: str, file_items: list[dict] | None = None) -> str:
"""
Send one user turn to Gemini and run the tool-call loop.

static_md is embedded in the system instruction inside <context> tags.
dynamic_md, when non-empty, is prepended to the user message inside
<discussion> tags. base_dir is handed to mcp_client.configure() and to
_run_script(). file_items, when given, are re-read from disk and their
contents appended to the last tool result of a round.

Returns the text of all rounds joined by blank lines, or
"(No text returned)". Any exception is re-raised as the ProviderError
produced by _classify_gemini_error().
"""
global _gemini_chat, _gemini_cache
from google.genai import types
try:
_ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir])
sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{static_md}\n</context>"
tools_decl = [_gemini_tool_declaration()]
# The chat is tagged with the hash of the static context it was built from.
# When that hash no longer matches, the old history is kept aside, the remote
# cache is deleted and the chat is dropped so it gets rebuilt below.
current_md_hash = hash(static_md)
old_history = None
if _gemini_chat and getattr(_gemini_chat, "_last_md_hash", None) != current_md_hash:
old_history = list(_gemini_chat.history) if _gemini_chat.history else []
if _gemini_cache:
try: _gemini_client.caches.delete(name=_gemini_cache.name)
except: pass
_gemini_chat, _gemini_cache = None, None
_append_comms("OUT", "request", {"message": "[STATIC CONTEXT CHANGED] Rebuilding cache and chat session..."})
# (Re)create the chat. An uncached config is prepared first; if creating a
# context cache succeeds the config is replaced by one that references the
# cache, otherwise _gemini_cache stays None and the uncached config is used.
if not _gemini_chat:
chat_config = types.GenerateContentConfig(
system_instruction=sys_instr, tools=tools_decl, temperature=_temperature, max_output_tokens=_max_tokens,
safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
)
try:
_gemini_cache = _gemini_client.caches.create(model=_model, config=types.CreateCachedContentConfig(system_instruction=sys_instr, tools=tools_decl, ttl="3600s"))
chat_config = types.GenerateContentConfig(
cached_content=_gemini_cache.name, temperature=_temperature, max_output_tokens=_max_tokens,
safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
)
_append_comms("OUT", "request", {"message": f"[CACHE CREATED] {_gemini_cache.name}"})
except Exception: _gemini_cache = None
kwargs = {"model": _model, "config": chat_config}
if old_history: kwargs["history"] = old_history
_gemini_chat = _gemini_client.chats.create(**kwargs)
_gemini_chat._last_md_hash = current_md_hash
# Scrub earlier user turns in the chat history: remove old <discussion>
# blocks from text parts, and for tool outputs cut everything from the
# "[SYSTEM: FILES UPDATED]" marker onwards and clamp to _history_trunc_limit.
import re
if _gemini_chat and _gemini_chat.history:
for msg in _gemini_chat.history:
if msg.role == "user" and hasattr(msg, "parts"):
for p in msg.parts:
if hasattr(p, "text") and p.text and "<discussion>" in p.text:
p.text = re.sub(r"<discussion>.*?</discussion>\n\n", "", p.text, flags=re.DOTALL)
if hasattr(p, "function_response") and p.function_response and hasattr(p.function_response, "response"):
r = p.function_response.response
r_dict = r if isinstance(r, dict) else getattr(r, "__dict__", {})
val = r_dict.get("output") if isinstance(r_dict, dict) else getattr(r, "output", None)
if isinstance(val, str):
if "[SYSTEM: FILES UPDATED]" in val: val = val.split("[SYSTEM: FILES UPDATED]")[0].strip()
if _history_trunc_limit > 0 and len(val) > _history_trunc_limit:
val = val[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS.]"
if isinstance(r, dict): r["output"] = val
else: setattr(r, "output", val)
full_user_msg = f"<discussion>\n{dynamic_md}\n</discussion>\n\n{user_message}" if dynamic_md else user_message
_append_comms("OUT", "request", {"message": f"[ctx {len(static_md)} static + {len(dynamic_md)} dynamic + msg {len(user_message)}]"})
# Tool loop. The first payload is the user message; later payloads are the
# function responses of the previous round.
payload, all_text = full_user_msg, []
for r_idx in range(MAX_TOOL_ROUNDS + 2):
resp = _gemini_chat.send_message(payload)
txt = "\n".join(p.text for c in resp.candidates if getattr(c, "content", None) for p in c.content.parts if hasattr(p, "text") and p.text)
if txt: all_text.append(txt)
calls = [p.function_call for c in resp.candidates if getattr(c, "content", None) for p in c.content.parts if hasattr(p, "function_call") and p.function_call]
usage = {"input_tokens": getattr(resp.usage_metadata, "prompt_token_count", 0), "output_tokens": getattr(resp.usage_metadata, "candidates_token_count", 0)}
cached_tokens = getattr(resp.usage_metadata, "cached_content_token_count", None)
if cached_tokens: usage["cache_read_input_tokens"] = cached_tokens
reason = resp.candidates[0].finish_reason.name if resp.candidates and hasattr(resp.candidates[0], "finish_reason") else "STOP"
_append_comms("IN", "response", {"round": r_idx, "stop_reason": reason, "text": txt, "tool_calls": [{"name": c.name, "args": dict(c.args)} for c in calls], "usage": usage})
# Token-budget guard: once the reported input tokens exceed
# _GEMINI_MAX_INPUT_TOKENS, drop the oldest history entries until the running
# estimate falls to 70% of the limit or only 4 entries remain. The saving per
# entry is estimated at len(text) // 4, with a floor of 100 per dropped entry.
total_in = usage.get("input_tokens", 0)
if total_in > _GEMINI_MAX_INPUT_TOKENS and _gemini_chat and _gemini_chat.history:
hist = list(_gemini_chat.history)
dropped = 0
while len(hist) > 4 and total_in > _GEMINI_MAX_INPUT_TOKENS * 0.7:
saved = sum(len(p.text)//4 for p in hist[0].parts if hasattr(p, "text") and p.text)
for p in hist[0].parts:
if hasattr(p, "function_response") and p.function_response:
r = getattr(p.function_response, "response", {})
val = r.get("output", "") if isinstance(r, dict) else getattr(r, "output", "")
saved += len(str(val)) // 4
hist.pop(0)
total_in -= max(saved, 100)
dropped += 1
if dropped > 0:
_gemini_chat.history = hist
_append_comms("OUT", "request", {"message": f"[GEMINI HISTORY TRIMMED: dropped {dropped} old entries to stay within token budget]"})
# Stop when the model requested no tools, or when the round index has gone
# past MAX_TOOL_ROUNDS.
if not calls or r_idx > MAX_TOOL_ROUNDS: break
f_resps, log = [], []
for i, fc in enumerate(calls):
name, args = fc.name, dict(fc.args)
if name in mcp_client.TOOL_NAMES:
_append_comms("OUT", "tool_call", {"name": name, "args": args})
out = mcp_client.dispatch(name, args)
elif name == TOOL_NAME:
scr = args.get("script", "")
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr})
out = _run_script(scr, base_dir)
else: out = f"ERROR: unknown tool '{name}'"
# Extras are attached to the output of the final call of the round: the
# refreshed file contents, and a max-rounds notice when r_idx equals
# MAX_TOOL_ROUNDS.
if i == len(calls) - 1:
if file_items:
file_items = _reread_file_items(file_items)
ctx = _build_file_context_text(file_items)
if ctx: out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
if r_idx == MAX_TOOL_ROUNDS: out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
f_resps.append(types.Part.from_function_response(name=name, response={"output": out}))
# The log entry uses the tool name as its "tool_use_id".
log.append({"tool_use_id": name, "content": out})
_append_comms("OUT", "tool_result_send", {"results": log})
payload = f_resps
return "\n\n".join(all_text) if all_text else "(No text returned)"
except Exception as e: raise _classify_gemini_error(e) from e
# ------------------------------------------------------------------ anthropic history management
# Rough chars-per-token ratio. Anthropic tokeniser averages ~3.5-4 chars/token.
# We use 3.5 to be conservative (overestimate token count = safer).
_CHARS_PER_TOKEN = 3.5
# Maximum token budget for the entire prompt (system + tools + messages).
# Anthropic's limit is 200k. We leave headroom for the response + tool schemas.
# _trim_anthropic_history() drops old messages while the estimate exceeds this.
_ANTHROPIC_MAX_PROMPT_TOKENS = 180_000
# Gemini models have a 1M context window but we cap well below to leave headroom.
# If the model reports input tokens exceeding this, we trim old history.
_GEMINI_MAX_INPUT_TOKENS = 900_000
# Marker prefix used to identify stale file-refresh injections in history.
# NOTE(review): _send_anthropic writes "[" immediately followed by this marker,
# so the injected text begins with "[[FILES UPDATED"; any matcher has to
# account for that doubled bracket.
_FILE_REFRESH_MARKER = "[FILES UPDATED"
def _estimate_message_tokens(msg: dict) -> int:
    """Return a rough token estimate (at least 1) for one Anthropic message.

    Characters are counted from string content, or from each block of list
    content: a dict block contributes its "text" (or, when that is empty,
    its "content") if it is a string, plus the JSON length of a dict
    "input". A bare string block contributes its length. The character
    total is divided by _CHARS_PER_TOKEN.
    """
    def block_chars(block) -> int:
        if isinstance(block, str):
            return len(block)
        if not isinstance(block, dict):
            return 0
        chars = 0
        text = block.get("text", "") or block.get("content", "")
        if isinstance(text, str):
            chars += len(text)
        # tool_use input
        tool_input = block.get("input")
        if isinstance(tool_input, dict):
            chars += len(json.dumps(tool_input, ensure_ascii=False))
        return chars

    content = msg.get("content", "")
    if isinstance(content, str):
        total_chars = len(content)
    elif isinstance(content, list):
        total_chars = sum(block_chars(block) for block in content)
    else:
        total_chars = 0
    return max(1, int(total_chars / _CHARS_PER_TOKEN))
def _estimate_prompt_tokens(system_blocks: list[dict], history: list[dict]) -> int:
    """Estimate the whole prompt: system blocks + tool schemas + history.

    Each system block counts for at least one token. The tool definitions
    are a fixed 2500-token allowance rather than a measurement.
    """
    system_tokens = sum(
        max(1, int(len(block.get("text", "")) / _CHARS_PER_TOKEN))
        for block in system_blocks
    )
    # Tool definitions (rough fixed estimate)
    tool_tokens = 2500
    history_tokens = sum(_estimate_message_tokens(message) for message in history)
    return system_tokens + tool_tokens + history_tokens
def _strip_stale_file_refreshes(history: list[dict]):
if len(history) < 2:
return
last_user_idx = next((i for i in range(len(history)-1, -1, -1) if history[i].get("role") == "user"), -1)
for i, msg in enumerate(history):
if msg.get("role") != "user" or i == last_user_idx:
continue
content = msg.get("content")
if not isinstance(content, list):
continue
cleaned = [b for b in content if not (isinstance(b, dict) and b.get("type") == "text" and b.get("text", "").startswith(_FILE_REFRESH_MARKER))]
if len(cleaned) < len(content):
msg["content"] = cleaned
def _trim_anthropic_history(system_blocks: list[dict], history: list[dict]) -> int:
"""
Shrink history in place until the estimated prompt fits the token budget.

Stale file-refresh blocks are stripped first. If the estimate from
_estimate_prompt_tokens() is still above _ANTHROPIC_MAX_PROMPT_TOKENS,
messages are removed starting at index 1; index 0 is never removed.
The outer loop only runs while more than 3 messages remain.

Returns the number of messages removed (0 when nothing was dropped).
"""
_strip_stale_file_refreshes(history)
est = _estimate_prompt_tokens(system_blocks, history)
if est <= _ANTHROPIC_MAX_PROMPT_TOKENS:
return 0
dropped = 0
while len(history) > 3 and est > _ANTHROPIC_MAX_PROMPT_TOKENS:
# Preferred case: indexes 1 and 2 form an assistant/user pair, which is
# removed together.
if history[1].get("role") == "assistant" and len(history) > 2 and history[2].get("role") == "user":
est -= _estimate_message_tokens(history.pop(1))
est -= _estimate_message_tokens(history.pop(1))
dropped += 2
# Keep removing following pairs whose user message starts with a tool_result
# block, regardless of the budget - presumably so that no tool_result is left
# behind without the tool_use it answers (TODO confirm).
while len(history) > 2 and history[1].get("role") == "assistant" and history[2].get("role") == "user":
c = history[2].get("content", [])
if isinstance(c, list) and c and isinstance(c[0], dict) and c[0].get("type") == "tool_result":
est -= _estimate_message_tokens(history.pop(1))
est -= _estimate_message_tokens(history.pop(1))
dropped += 2
else: break
# Otherwise remove just the single message at index 1.
else:
est -= _estimate_message_tokens(history.pop(1))
dropped += 1
return dropped
# ------------------------------------------------------------------ anthropic
def _ensure_anthropic_client():
    """Create the shared Anthropic client on first use; later calls do nothing.

    The API key is read from the "anthropic" table of the credentials file.
    """
    global _anthropic_client
    if _anthropic_client is not None:
        return
    import anthropic
    api_key = _load_credentials()["anthropic"]["api_key"]
    _anthropic_client = anthropic.Anthropic(api_key=api_key)
def _chunk_text(text: str, chunk_size: int) -> list[str]:
return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
def _build_chunked_context_blocks(md_content: str) -> list[dict]:
    """Turn md_content into Anthropic text blocks of bounded size.

    The text is cut into pieces of at most _ANTHROPIC_CHUNK_SIZE characters.
    Only the final block gets cache_control:ephemeral, so the whole prefix
    is cached as one unit. Empty input yields an empty list.
    """
    blocks: list[dict] = [
        {"type": "text", "text": piece}
        for piece in _chunk_text(md_content, _ANTHROPIC_CHUNK_SIZE)
    ]
    if blocks:
        blocks[-1]["cache_control"] = {"type": "ephemeral"}
    return blocks
def _strip_cache_controls(history: list[dict]):
"""
Remove cache_control from all content blocks in message history.
Anthropic allows max 4 cache_control blocks total across system + tools +
messages. We reserve those slots for the stable system/tools prefix and
the current turn's context block, so all older history entries must be clean.
"""
for msg in history:
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
block.pop("cache_control", None)
def _repair_anthropic_history(history: list[dict]):
"""
If history ends with an assistant message that contains tool_use blocks
without a following user tool_result message, append a synthetic tool_result
message so the history is valid before the next request.
"""
if not history:
return
last = history[-1]
if last.get("role") != "assistant":
return
content = last.get("content", [])
tool_use_ids = []
for block in content:
if isinstance(block, dict):
if block.get("type") == "tool_use":
tool_use_ids.append(block["id"])
if not tool_use_ids:
return
history.append({
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": tid,
"content": "Tool call was not completed (session interrupted).",
}
for tid in tool_use_ids
],
})
def _send_anthropic(static_md: str, dynamic_md: str, user_message: str, base_dir: str, file_items: list[dict] | None = None) -> str:
"""
Send one user turn to Anthropic and run the tool-call loop.

static_md is appended to the system prompt inside <context> tags and
split into cacheable chunks. dynamic_md, when non-empty, is added as an
extra system block inside <discussion> tags. base_dir is handed to
mcp_client.configure() and to _run_script(). file_items, when given, are
re-read from disk after each round of tool calls and sent back as an
extra text block.

The conversation is kept in the module-level _anthropic_history, which
this function mutates.

Returns the text of all rounds joined by blank lines, or
"(No text returned by the model)". A ProviderError passes through
unchanged; any other exception is re-raised as the ProviderError
produced by _classify_anthropic_error().
"""
try:
_ensure_anthropic_client()
mcp_client.configure(file_items or [], [base_dir])
system_text = _get_combined_system_prompt() + f"\n\n<context>\n{static_md}\n</context>"
system_blocks = _build_chunked_context_blocks(system_text)
if dynamic_md:
system_blocks.append({"type": "text", "text": f"<discussion>\n{dynamic_md}\n</discussion>"})
user_content = [{"type": "text", "text": user_message}]
# Clamp oversized tool_result strings already stored in history to
# _history_trunc_limit characters; the stored blocks are rewritten in place.
for msg in _anthropic_history:
if msg.get("role") == "user" and isinstance(msg.get("content"), list):
for block in msg["content"]:
if isinstance(block, dict) and block.get("type") == "tool_result":
t_content = block.get("content", "")
if _history_trunc_limit > 0 and isinstance(t_content, str) and len(t_content) > _history_trunc_limit:
block["content"] = t_content[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS. Original output was too large.]"
# Old cache markers are removed and a dangling tool_use is answered before
# the new user turn, which carries this request's cache marker, is appended.
_strip_cache_controls(_anthropic_history)
_repair_anthropic_history(_anthropic_history)
user_content[-1]["cache_control"] = {"type": "ephemeral"}
_anthropic_history.append({"role": "user", "content": user_content})
n_chunks = len(system_blocks)
_append_comms("OUT", "request", {
"message": (f"[system {n_chunks} chunk(s), {len(static_md)} static + {len(dynamic_md)} dynamic chars context] "
f"{user_message[:200]}{'...' if len(user_message) > 200 else ''}"),
})
# Tool loop. History is trimmed to the token budget before every request.
all_text_parts = []
for round_idx in range(MAX_TOOL_ROUNDS + 2):
dropped = _trim_anthropic_history(system_blocks, _anthropic_history)
if dropped > 0:
est_tokens = _estimate_prompt_tokens(system_blocks, _anthropic_history)
_append_comms("OUT", "request", {"message": f"[HISTORY TRIMMED: dropped {dropped} old messages to fit token budget. Estimated {est_tokens} tokens remaining.]"})
response = _anthropic_client.messages.create(
model=_model, max_tokens=_max_tokens, temperature=_temperature,
system=system_blocks, tools=_get_anthropic_tools(), messages=_anthropic_history,
)
# The reply is stored as plain dicts so history stays serialisable.
serialised_content = [_content_block_to_dict(b) for b in response.content]
_anthropic_history.append({"role": "assistant", "content": serialised_content})
text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text]
if text_blocks: all_text_parts.append("\n".join(text_blocks))
tool_use_blocks = [{"id": b.id, "name": b.name, "input": b.input} for b in response.content if getattr(b, "type", None) == "tool_use"]
usage_dict = {}
if response.usage:
usage_dict.update({"input_tokens": response.usage.input_tokens, "output_tokens": response.usage.output_tokens})
if getattr(response.usage, "cache_creation_input_tokens", None) is not None:
usage_dict["cache_creation_input_tokens"] = response.usage.cache_creation_input_tokens
if getattr(response.usage, "cache_read_input_tokens", None) is not None:
usage_dict["cache_read_input_tokens"] = response.usage.cache_read_input_tokens
_append_comms("IN", "response", {"round": round_idx, "stop_reason": response.stop_reason, "text": "\n".join(text_blocks), "tool_calls": tool_use_blocks, "usage": usage_dict})
# Stop when the model did not ask for tools, or when the round index has
# gone past MAX_TOOL_ROUNDS.
if response.stop_reason != "tool_use" or not tool_use_blocks: break
if round_idx > MAX_TOOL_ROUNDS: break
tool_results = []
for block in response.content:
if getattr(block, "type", None) != "tool_use": continue
b_name, b_id, b_input = getattr(block, "name", None), getattr(block, "id", ""), getattr(block, "input", {})
if b_name in mcp_client.TOOL_NAMES:
_append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input})
out = mcp_client.dispatch(b_name, b_input)
elif b_name == TOOL_NAME:
scr = b_input.get("script", "")
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": b_id, "script": scr})
out = _run_script(scr, base_dir)
else: out = f"ERROR: unknown tool '{b_name}'"
_append_comms("IN", "tool_result", {"name": b_name, "id": b_id, "output": out})
tool_results.append({"type": "tool_result", "tool_use_id": b_id, "content": out})
# Refreshed file contents are sent as an extra text block next to the results.
# NOTE(review): the text starts with "[" followed by _FILE_REFRESH_MARKER,
# which itself starts with "[", so it begins with a doubled bracket; whatever
# matches the marker later has to account for that.
if file_items:
file_items = _reread_file_items(file_items)
refreshed_ctx = _build_file_context_text(file_items)
if refreshed_ctx:
tool_results.append({"type": "text", "text": f"[{_FILE_REFRESH_MARKER} — current contents below. Do NOT re-read these files with PowerShell.]\n\n{refreshed_ctx}"})
if round_idx == MAX_TOOL_ROUNDS:
tool_results.append({"type": "text", "text": "SYSTEM WARNING: MAX TOOL ROUNDS REACHED. YOU MUST PROVIDE YOUR FINAL ANSWER NOW WITHOUT CALLING ANY MORE TOOLS."})
_anthropic_history.append({"role": "user", "content": tool_results})
# Only real tool_result blocks are logged; the extra text blocks are skipped.
_append_comms("OUT", "tool_result_send", {"results": [{"tool_use_id": r["tool_use_id"], "content": r["content"]} for r in tool_results if r.get("type") == "tool_result"]})
final_text = "\n\n".join(all_text_parts)
return final_text if final_text.strip() else "(No text returned by the model)"
except ProviderError: raise
except Exception as exc: raise _classify_anthropic_error(exc) from exc
# ------------------------------------------------------------------ unified send
def send(
    static_md: str,
    dynamic_md: str,
    user_message: str,
    base_dir: str = ".",
    file_items: list[dict] | None = None,
) -> str:
    """Send a message to the active provider and return its text reply.

    The provider chosen with set_provider() decides which backend handles
    the call; all arguments are forwarded unchanged.

    Raises:
        ValueError: the active provider is neither "gemini" nor "anthropic".
    """
    if _provider == "gemini":
        handler = _send_gemini
    elif _provider == "anthropic":
        handler = _send_anthropic
    else:
        raise ValueError(f"unknown provider: {_provider}")
    return handler(static_md, dynamic_md, user_message, base_dir, file_items)