Files
manual_slop/ai_client.py
2026-02-28 09:06:45 -05:00

1792 lines
72 KiB
Python

# ai_client.py
"""
Note(Gemini):
Acts as the unified interface for multiple LLM providers (Anthropic, Gemini, Gemini CLI, DeepSeek).
Abstracts away the differences in how they handle tool schemas, history, and caching.
For Anthropic: aggressively manages the ~200k token limit by manually culling
stale [FILES UPDATED] entries and dropping the oldest message pairs.
For Gemini: injects the initial context directly into system_instruction
during chat creation to avoid massive history bloat.
"""
# ai_client.py
import tomllib
import json
import sys
import time
import datetime
import hashlib
import difflib
import threading
import requests
from pathlib import Path
from typing import Optional, Callable, Any
import os
import project_manager
import file_cache
import mcp_client
import anthropic
from gemini_cli_adapter import GeminiCliAdapter
from google import genai
from google.genai import types
from events import EventEmitter
# Active provider key and model name; both are replaced at runtime by set_provider().
_provider: str = "gemini"
_model: str = "gemini-2.5-flash-lite"
# Generation parameters; replaced at runtime by set_model_params().
_temperature: float = 0.0
_max_tokens: int = 8192
# Character cap applied to tool output by _truncate_tool_output();
# a value <= 0 disables truncation.
_history_trunc_limit: int = 8000
# Global event emitter for API lifecycle events
events = EventEmitter()
def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000) -> None:
    """Store the generation parameters used by subsequent requests.

    Args:
        temp: Sampling temperature.
        max_tok: Maximum number of output tokens.
        trunc_limit: Character cap for tool output (see _truncate_tool_output).
    """
    global _temperature, _max_tokens, _history_trunc_limit
    _temperature, _max_tokens, _history_trunc_limit = temp, max_tok, trunc_limit
def get_history_trunc_limit() -> int:
    """Return the current tool-output truncation limit, in characters."""
    current_limit = _history_trunc_limit
    return current_limit
def set_history_trunc_limit(val: int) -> None:
    """Replace the tool-output truncation limit (characters) with ``val``."""
    global _history_trunc_limit
    new_limit = val
    _history_trunc_limit = new_limit
# --- Gemini session state (created lazily, cleared by reset_session()) ---
_gemini_client = None
_gemini_chat = None
_gemini_cache = None
# MD5 hexdigest of the md_content the current chat/cache was built from.
_gemini_cache_md_hash: str | None = None
# time.time() at which the current server-side cache was created.
_gemini_cache_created_at: float | None = None
# Gemini cache TTL in seconds. Caches are created with this TTL and
# proactively rebuilt at 90% of this value to avoid stale-reference errors.
_GEMINI_CACHE_TTL = 3600
# --- Anthropic session state ---
_anthropic_client = None
_anthropic_history: list[dict] = []
_anthropic_history_lock = threading.Lock()
# --- DeepSeek session state ---
_deepseek_client = None
_deepseek_history: list[dict] = []
_deepseek_history_lock = threading.Lock()
# NOTE(review): not used in the visible part of this file; presumably
# serialises top-level send() calls - confirm against its users.
_send_lock = threading.Lock()
# Lazily created GeminiCliAdapter instance (see _send_gemini_cli).
_gemini_cli_adapter = None
# Injected by gui.py - called when AI wants to run a command.
# Signature: (script: str, base_dir: str, qa_callback: Callable[[str], str] | None) -> str | None
# A None return is reported to the model as "USER REJECTED" (see _run_script).
confirm_and_run_callback = None
# Injected by gui.py - called whenever a comms entry is appended.
# Signature: (entry: dict) -> None
comms_log_callback = None
# Injected by gui.py - called whenever a tool call completes.
# Signature: (script: str, result: str) -> None
tool_log_callback = None
# Increased to allow thorough code exploration before forcing a summary
MAX_TOOL_ROUNDS = 10
# Maximum cumulative bytes of tool output allowed per send() call.
# Prevents unbounded memory growth during long tool-calling loops.
# NOTE: the Gemini loop accumulates len(str), i.e. characters, against this budget.
_MAX_TOOL_OUTPUT_BYTES = 500_000
# Maximum characters per text chunk sent to Anthropic.
# Kept well under the ~200k token API limit.
_ANTHROPIC_CHUNK_SIZE = 120_000
# Base system prompt sent to every provider. A user-supplied prompt, if any,
# is appended to it by _get_combined_system_prompt().
# NOTE(review): the "For example:\n" sentence below is not followed by any
# example - the next fragment starts an unrelated instruction. The example
# text appears to have been lost; confirm and restore or drop the sentence.
_SYSTEM_PROMPT = (
    "You are a helpful coding assistant with access to a PowerShell tool and MCP tools (file access: read_file, list_directory, search_files, get_file_summary, web access: web_search, fetch_url). "
    "When calling file/directory tools, always use the 'path' parameter for the target path. "
    "When asked to create or edit files, prefer targeted edits over full rewrites. "
    "Always explain what you are doing before invoking the tool.\n\n"
    "When writing or rewriting large files (especially those containing quotes, backticks, or special characters), "
    "avoid python -c with inline strings. Instead: (1) write a .py helper script to disk using a PS here-string "
    "(@'...'@ for literal content), (2) run it with `python <script>`, (3) delete the helper. "
    "For small targeted edits, use PowerShell's (Get-Content) / .Replace() / Set-Content or Add-Content directly.\n\n"
    "When making function calls using tools that accept array or object parameters "
    "ensure those are structured using JSON. For example:\n"
    "When you need to verify a change, rely on the exit code and stdout/stderr from the tool \u2014 "
    "the user's context files are automatically refreshed after every tool call, so you do NOT "
    "need to re-read files that are already provided in the <context> block."
)
# Optional user-supplied addition to the system prompt; empty means "none".
_custom_system_prompt: str = ""
def set_custom_system_prompt(prompt: str) -> None:
    """Remember ``prompt`` as the user-supplied system-prompt addition."""
    global _custom_system_prompt
    user_prompt = prompt
    _custom_system_prompt = user_prompt
def _get_combined_system_prompt() -> str:
    """Return the base system prompt, plus the user prompt when one is set.

    A custom prompt that is empty or whitespace-only is ignored.
    """
    if not _custom_system_prompt.strip():
        return _SYSTEM_PROMPT
    return f"{_SYSTEM_PROMPT}\n\n[USER SYSTEM PROMPT]\n{_custom_system_prompt}"
# ------------------------------------------------------------------ comms log
# In-memory list of every entry recorded by _append_comms().
_comms_log: list[dict] = []
# NOTE(review): not referenced in the visible part of this file; presumably a
# display clamp (characters) used by the GUI - confirm against its users.
COMMS_CLAMP_CHARS = 300
def _append_comms(direction: str, kind: str, payload: dict[str, Any]) -> None:
    """Record one comms-log entry and forward it to the GUI callback, if set.

    The entry carries a wall-clock ``HH:MM:SS`` timestamp, the given direction,
    kind and payload, and the provider/model active at the time of the call.
    """
    stamp = datetime.datetime.now().strftime("%H:%M:%S")
    entry = dict(
        ts=stamp,
        direction=direction,
        kind=kind,
        provider=_provider,
        model=_model,
        payload=payload,
    )
    _comms_log.append(entry)
    listener = comms_log_callback
    if listener is not None:
        listener(entry)
def get_comms_log() -> list[dict]:
    """Return a shallow copy of the comms log (entries are shared, the list is not)."""
    return _comms_log[:]
def clear_comms_log() -> None:
    """Empty the comms log in place, keeping the same list object."""
    del _comms_log[:]
def _load_credentials() -> dict:
cred_path = os.environ.get("SLOP_CREDENTIALS", "credentials.toml")
try:
with open(cred_path, "rb") as f:
return tomllib.load(f)
except FileNotFoundError:
raise FileNotFoundError(
f"Credentials file not found: {cred_path}\n"
f"Create a credentials.toml with:\n"
f" [gemini]\n api_key = \"your-key\"\n"
f" [anthropic]\n api_key = \"your-key\"\n"
f" [deepseek]\n api_key = \"your-key\"\n"
f"Or set SLOP_CREDENTIALS env var to a custom path."
)
# ------------------------------------------------------------------ provider errors
class ProviderError(Exception):
    """Provider failure tagged with a coarse category for UI display.

    Attributes:
        kind: Category key ("quota", "rate_limit", "auth", "balance",
            "network" or "unknown").
        provider: Provider name, e.g. "gemini".
        original: The underlying exception; its text is this error's message.
    """

    # Category key -> heading shown by ui_message(); other keys use "API ERROR".
    _UI_LABELS = {
        "quota": "QUOTA EXHAUSTED",
        "rate_limit": "RATE LIMITED",
        "auth": "AUTH / API KEY ERROR",
        "balance": "BALANCE / BILLING ERROR",
        "network": "NETWORK / CONNECTION ERROR",
        "unknown": "API ERROR",
    }

    def __init__(self, kind: str, provider: str, original: Exception) -> None:
        super().__init__(str(original))
        self.kind = kind
        self.provider = provider
        self.original = original

    def ui_message(self) -> str:
        """Return a bracketed "[PROVIDER LABEL]" heading followed by the original error text."""
        heading = self._UI_LABELS.get(self.kind, "API ERROR")
        return f"[{self.provider.upper()} {heading}]\n\n{self.original}"
def _classify_anthropic_error(exc: Exception) -> ProviderError:
    """Wrap an Anthropic SDK exception in a ProviderError with a category.

    Typed SDK exceptions are matched first; a generic ``APIStatusError`` is
    then classified by HTTP status and, failing that, by keywords in its
    text. Anything unmatched becomes kind "unknown".
    """
    def wrap(kind: str) -> ProviderError:
        return ProviderError(kind, "anthropic", exc)

    try:
        # Order matters: the first matching class wins.
        typed_kinds = (
            (anthropic.RateLimitError, "rate_limit"),
            (anthropic.AuthenticationError, "auth"),
            (anthropic.PermissionDeniedError, "auth"),
            (anthropic.APIConnectionError, "network"),
        )
        for exc_type, kind in typed_kinds:
            if isinstance(exc, exc_type):
                return wrap(kind)
        if isinstance(exc, anthropic.APIStatusError):
            status = getattr(exc, "status_code", 0)
            body = str(exc).lower()
            if status == 429:
                return wrap("rate_limit")
            if status in (401, 403):
                return wrap("auth")
            if status == 402 or any(word in body for word in ("credit", "balance", "billing")):
                return wrap("balance")
            if any(word in body for word in ("quota", "limit", "exceeded")):
                return wrap("quota")
    except ImportError:
        pass
    return wrap("unknown")
def _classify_gemini_error(exc: Exception) -> ProviderError:
    """Wrap a Gemini failure in a ProviderError with a category.

    If ``google.api_core`` is importable its typed exceptions are matched
    first; otherwise (or when none match) the lower-cased error text is
    searched for status codes and keywords, in a fixed priority order.
    """
    def wrap(kind: str) -> ProviderError:
        return ProviderError(kind, "gemini", exc)

    body = str(exc).lower()

    def mentions(*needles: str) -> bool:
        return any(needle in body for needle in needles)

    try:
        from google.api_core import exceptions as gac
        typed_kinds = (
            (gac.ResourceExhausted, "quota"),
            (gac.TooManyRequests, "rate_limit"),
            ((gac.Unauthenticated, gac.PermissionDenied), "auth"),
            (gac.ServiceUnavailable, "network"),
        )
        for exc_types, kind in typed_kinds:
            if isinstance(exc, exc_types):
                return wrap(kind)
    except ImportError:
        pass
    if mentions("429", "quota", "resource exhausted"):
        return wrap("quota")
    if "rate" in body and "limit" in body:
        return wrap("rate_limit")
    if mentions("401", "403", "api key", "unauthenticated"):
        return wrap("auth")
    if mentions("402", "billing", "balance", "payment"):
        return wrap("balance")
    if mentions("connection", "timeout", "unreachable"):
        return wrap("network")
    return wrap("unknown")
def _classify_deepseek_error(exc: Exception) -> ProviderError:
body = str(exc).lower()
if "429" in body or "rate" in body:
return ProviderError("rate_limit", "deepseek", exc)
if "401" in body or "403" in body or "auth" in body or "api key" in body:
return ProviderError("auth", "deepseek", exc)
if "402" in body or "balance" in body or "billing" in body:
return ProviderError("balance", "deepseek", exc)
if "quota" in body or "limit exceeded" in body:
return ProviderError("quota", "deepseek", exc)
if "connection" in body or "timeout" in body or "network" in body:
return ProviderError("network", "deepseek", exc)
return ProviderError("unknown", "deepseek", exc)
# ------------------------------------------------------------------ provider setup
def set_provider(provider: str, model: str) -> None:
    """Select the active provider and model.

    For the "gemini_cli" provider the model must be one of the CLI's known
    models and must not be a DeepSeek model; otherwise the default
    "gemini-3-flash-preview" is used. Other providers accept any model name.
    """
    global _provider, _model
    _provider = provider
    chosen = model
    if provider == "gemini_cli":
        valid_models = _list_gemini_cli_models()
        # If model is invalid or belongs to another provider (like deepseek), force default
        if model not in valid_models or model.startswith("deepseek"):
            chosen = "gemini-3-flash-preview"
    _model = chosen
def cleanup() -> None:
    """Called on application exit to prevent orphaned caches from billing.

    Deletes the current Gemini cache when both a client and a cache exist.
    """
    if not (_gemini_client and _gemini_cache):
        return
    try:
        _gemini_client.caches.delete(name=_gemini_cache.name)
    except Exception:
        # Best-effort: failure to delete must not block shutdown.
        pass
def reset_session() -> None:
    """Drop all per-provider session state so the next send starts fresh.

    Deletes the server-side Gemini cache (best-effort), then clears the
    Gemini client/chat/cache bookkeeping, the Gemini CLI adapter, the
    Anthropic and DeepSeek clients and histories, the cached Anthropic tool
    list, and finally resets the file_cache client.
    """
    global _gemini_client, _gemini_chat, _gemini_cache
    global _gemini_cache_md_hash, _gemini_cache_created_at
    global _anthropic_client, _anthropic_history
    # Without this declaration the DeepSeek assignments below would only bind
    # function locals and the module-level client/history would survive a reset.
    global _deepseek_client, _deepseek_history
    global _CACHED_ANTHROPIC_TOOLS
    global _gemini_cli_adapter
    if _gemini_client and _gemini_cache:
        try:
            _gemini_client.caches.delete(name=_gemini_cache.name)
        except Exception:
            # Best-effort: a failed delete must not prevent the local reset.
            pass
    _gemini_client = None
    _gemini_chat = None
    _gemini_cache = None
    _gemini_cache_md_hash = None
    _gemini_cache_created_at = None
    if _gemini_cli_adapter:
        # Clear the session id on the instance as well, in case another
        # reference to the adapter is still held elsewhere.
        _gemini_cli_adapter.session_id = None
    _gemini_cli_adapter = None
    _anthropic_client = None
    with _anthropic_history_lock:
        _anthropic_history = []
    _deepseek_client = None
    with _deepseek_history_lock:
        _deepseek_history = []
    _CACHED_ANTHROPIC_TOOLS = None
    file_cache.reset_client()
def get_gemini_cache_stats() -> dict[str, Any]:
    """Return the number of Gemini caches and the sum of their ``size_bytes``.

    Ensures a Gemini client exists, then walks the caches it lists.
    """
    _ensure_gemini_client()
    cache_count = 0
    total_size_bytes = 0
    for cache in _gemini_client.caches.list():
        cache_count += 1
        total_size_bytes += cache.size_bytes
    return {"cache_count": cache_count, "total_size_bytes": total_size_bytes}
# ------------------------------------------------------------------ model listing
def list_models(provider: str) -> list[str]:
    """Return the model names available for ``provider``.

    Credentials are only loaded for providers that need an API key
    ("gemini", "anthropic", "deepseek"). "gemini_cli" uses a built-in list
    and unknown providers yield an empty list, so neither requires a
    credentials file to exist.

    Raises:
        FileNotFoundError: A key-based provider was requested and the
            credentials file is missing.
        ProviderError: The provider's model-listing call failed.
    """
    if provider == "gemini_cli":
        return _list_gemini_cli_models()
    if provider not in ("gemini", "anthropic", "deepseek"):
        return []
    creds = _load_credentials()
    if provider == "gemini":
        return _list_gemini_models(creds["gemini"]["api_key"])
    if provider == "anthropic":
        return _list_anthropic_models()
    return _list_deepseek_models(creds["deepseek"]["api_key"])
def _list_gemini_cli_models() -> list[str]:
    """Return the curated list of models offered for the Gemini CLI.

    The list is hard-coded here rather than queried from the CLI; each call
    returns a new list object.
    """
    curated = (
        "gemini-3-flash-preview",
        "gemini-3.1-pro-preview",
        "gemini-2.5-pro",
        "gemini-2.5-flash",
        "gemini-2.5-flash-lite",
    )
    return list(curated)
def _list_gemini_models(api_key: str) -> list[str]:
    """Return the sorted names of models whose name contains "gemini".

    A leading "models/" prefix is stripped from each name. Any failure is
    re-raised as a ProviderError via _classify_gemini_error.
    """
    try:
        client = genai.Client(api_key=api_key)
        short_names = (m.name.removeprefix("models/") for m in client.models.list())
        return sorted(name for name in short_names if "gemini" in name.lower())
    except Exception as exc:
        raise _classify_gemini_error(exc) from exc
def _list_anthropic_models() -> list[str]:
    """Return the sorted ids of all models listed by the Anthropic API.

    Any failure, including a credentials problem, is re-raised as a
    ProviderError via _classify_anthropic_error.
    """
    try:
        api_key = _load_credentials()["anthropic"]["api_key"]
        client = anthropic.Anthropic(api_key=api_key)
        return sorted(m.id for m in client.models.list())
    except Exception as exc:
        raise _classify_anthropic_error(exc) from exc
def _list_deepseek_models(api_key: str) -> list[str]:
    """Return the fixed list of DeepSeek model names.

    ``api_key`` is accepted for signature parity with the other listers but
    is not used: no request is made.
    """
    # For now, return the models specified in the requirements
    known = ("deepseek-chat", "deepseek-reasoner", "deepseek-v3", "deepseek-r1")
    return list(known)
# ------------------------------------------------------------------ tool definition
# Name under which the PowerShell execution tool is advertised to the model.
TOOL_NAME = "run_powershell"
# Per-tool enable flags keyed by tool name. A tool with no entry counts as
# enabled (lookups use .get(name, True)).
_agent_tools: dict = {}
def set_agent_tools(tools: dict[str, bool]) -> None:
    """Replace the per-tool enable flags and invalidate the cached Anthropic tool list."""
    global _agent_tools, _CACHED_ANTHROPIC_TOOLS
    # Dropping the cache forces the next _get_anthropic_tools() call to rebuild.
    _agent_tools, _CACHED_ANTHROPIC_TOOLS = tools, None
def _build_anthropic_tools() -> list[dict]:
    """Build the Anthropic tool list: enabled MCP tools, then run_powershell.

    Whichever tool ends up last in the list carries the ephemeral
    ``cache_control`` marker. Tools disabled in ``_agent_tools`` are omitted.
    """
    tools_list = [
        {
            "name": spec["name"],
            "description": spec["description"],
            "input_schema": spec["parameters"],
        }
        for spec in mcp_client.MCP_TOOL_SPECS
        if _agent_tools.get(spec["name"], True)
    ]
    if _agent_tools.get(TOOL_NAME, True):
        tools_list.append({
            "name": TOOL_NAME,
            "description": (
                "Run a PowerShell script within the project base_dir. "
                "Use this to create, edit, rename, or delete files and directories. "
                "The working directory is set to base_dir automatically. "
                "Always prefer targeted edits over full rewrites where possible. "
                "stdout and stderr are returned to you as the result."
            ),
            "input_schema": {
                "type": "object",
                "properties": {
                    "script": {
                        "type": "string",
                        "description": "The PowerShell script to execute."
                    }
                },
                "required": ["script"]
            },
            "cache_control": {"type": "ephemeral"},
        })
    elif tools_list:
        # Anthropic requires the LAST tool to have cache_control for the prefix caching to work optimally on tools
        tools_list[-1]["cache_control"] = {"type": "ephemeral"}
    return tools_list
# Tool list built once at import time.
# NOTE(review): not referenced again in the visible part of this file -
# _get_anthropic_tools() rebuilds its own list; confirm whether this is still needed.
_ANTHROPIC_TOOLS = _build_anthropic_tools()
# Lazily built tool list; None means "rebuild on the next _get_anthropic_tools() call".
_CACHED_ANTHROPIC_TOOLS = None
def _get_anthropic_tools() -> list[dict]:
    """Return the cached Anthropic tool list, building it first if the cache is empty."""
    global _CACHED_ANTHROPIC_TOOLS
    cached = _CACHED_ANTHROPIC_TOOLS
    if cached is None:
        cached = _build_anthropic_tools()
        _CACHED_ANTHROPIC_TOOLS = cached
    return cached
def _gemini_tool_declaration() -> types.Tool | None:
    """Build the Gemini ``Tool`` holding every enabled function declaration.

    MCP tool specs are converted first, followed by the PowerShell tool.
    Returns None when every tool is disabled.
    """
    def to_schema(pdef: dict) -> types.Schema:
        # Map the JSON-schema type name onto the SDK enum, defaulting to STRING.
        type_name = pdef.get("type", "string").upper()
        return types.Schema(
            type=getattr(types.Type, type_name, types.Type.STRING),
            description=pdef.get("description", ""),
        )

    declarations = []
    # MCP file tools
    for spec in mcp_client.MCP_TOOL_SPECS:
        if not _agent_tools.get(spec["name"], True):
            continue
        params = spec["parameters"]
        props = {
            pname: to_schema(pdef)
            for pname, pdef in params.get("properties", {}).items()
        }
        declarations.append(types.FunctionDeclaration(
            name=spec["name"],
            description=spec["description"],
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties=props,
                required=params.get("required", []),
            ),
        ))
    # PowerShell tool
    if _agent_tools.get(TOOL_NAME, True):
        script_schema = types.Schema(
            type=types.Type.STRING,
            description="The PowerShell script to execute."
        )
        declarations.append(types.FunctionDeclaration(
            name=TOOL_NAME,
            description=(
                "Run a PowerShell script within the project base_dir. "
                "Use this to create, edit, rename, or delete files and directories. "
                "The working directory is set to base_dir automatically. "
                "stdout and stderr are returned to you as the result."
            ),
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={"script": script_schema},
                required=["script"]
            ),
        ))
    if not declarations:
        return None
    return types.Tool(function_declarations=declarations)
def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str:
    """Run ``script`` through the GUI confirmation callback and return its output.

    Returns an error string when no confirmation callback is registered, and
    a "USER REJECTED" string when the callback returns None. The tool-log
    callback, if set, is given the script and the final output.
    """
    runner = confirm_and_run_callback
    if runner is None:
        return "ERROR: no confirmation handler registered"
    result = runner(script, base_dir, qa_callback)
    output = "USER REJECTED: command was not executed" if result is None else result
    logger = tool_log_callback
    if logger is not None:
        logger(script, output)
    return output
def _truncate_tool_output(output: str) -> str:
    """Clip ``output`` to ``_history_trunc_limit`` characters and append a notice.

    Output is returned unchanged when the limit is <= 0 or not exceeded.
    """
    limit = _history_trunc_limit
    if limit <= 0 or len(output) <= limit:
        return output
    return output[:limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS.]"
# ------------------------------------------------------------------ dynamic file context refresh
def _reread_file_items(file_items: list[dict]) -> tuple[list[dict], list[dict]]:
    """Refresh file items from disk, re-reading only files whose mtime differs.

    Returns:
        ``(all_items, changed_items)``. ``all_items`` mirrors the input
        order; unchanged items and items without a "path" are passed through
        as-is. ``changed_items`` holds only the re-read items (which gain an
        "old_content" key) and the items that failed to read (flagged with
        ``error=True``), and is used to build a minimal [FILES UPDATED] block.
    """
    refreshed: list[dict] = []
    changed: list[dict] = []
    for item in file_items:
        raw_path = item.get("path")
        if raw_path is None:
            refreshed.append(item)
            continue
        p = raw_path if isinstance(raw_path, Path) else Path(raw_path)
        try:
            current_mtime = p.stat().st_mtime
            if current_mtime == item.get("mtime", 0.0):
                # Unchanged on disk - keep the existing item, skip the read.
                refreshed.append(item)
                continue
            content = p.read_text(encoding="utf-8")
            updated = {
                **item,
                "old_content": item.get("content", ""),
                "content": content,
                "error": False,
                "mtime": current_mtime,
            }
        except Exception as e:
            # mtime 0.0 ensures the file is retried on the next refresh.
            updated = {**item, "content": f"ERROR re-reading {p}: {e}", "error": True, "mtime": 0.0}
        refreshed.append(updated)
        changed.append(updated)
    return refreshed, changed
def _build_file_context_text(file_items: list[dict]) -> str:
    """Render every file item as a markdown section with a fenced code block.

    The fence language is the text after the last "." of the path, or "text"
    when the path has no dot. Sections are separated by a horizontal rule.
    An empty input yields an empty string.
    """
    if not file_items:
        return ""

    def render(item: dict) -> str:
        path = item.get("path") or item.get("entry", "unknown")
        label = str(path)
        suffix = label.rsplit(".", 1)[-1] if "." in label else "text"
        content = item.get("content", "")
        return f"### `{path}`\n\n```{suffix}\n{content}\n```"

    return "\n\n---\n\n".join(render(item) for item in file_items)
# Files with more lines than this are reported as a unified diff rather than in full.
_DIFF_LINE_THRESHOLD = 200


def _build_file_diff_text(changed_items: list[dict]) -> str:
    """Render changed files as markdown sections for a [FILES UPDATED] block.

    Small files (<= _DIFF_LINE_THRESHOLD lines), and files with no previous
    content, are included in full. Larger files get a unified diff of
    "old_content" against "content". Sections are separated by a horizontal
    rule; an empty input yields an empty string.
    """
    if not changed_items:
        return ""
    parts = []
    for item in changed_items:
        path = item.get("path") or item.get("entry", "unknown")
        content = item.get("content", "")
        old_content = item.get("old_content", "")
        # Split WITHOUT line endings: unified_diff is called with lineterm=""
        # and the result is joined with "\n". Keeping the endings would put a
        # blank line after every diff body line.
        new_lines = content.splitlines()
        if len(new_lines) <= _DIFF_LINE_THRESHOLD or not old_content:
            suffix = str(path).rsplit(".", 1)[-1] if "." in str(path) else "text"
            parts.append(f"### `{path}` (full)\n\n```{suffix}\n{content}\n```")
            continue
        old_lines = old_content.splitlines()
        diff = difflib.unified_diff(old_lines, new_lines, fromfile=str(path), tofile=str(path), lineterm="")
        diff_text = "\n".join(diff)
        if diff_text:
            parts.append(f"### `{path}` (diff)\n\n```diff\n{diff_text}\n```")
        else:
            # The file was re-read (mtime changed) but its lines are identical.
            parts.append(f"### `{path}` (no changes detected)")
    return "\n\n---\n\n".join(parts)
# ------------------------------------------------------------------ content block serialisation
def _content_block_to_dict(block: Any) -> dict[str, Any]:
    """Convert an Anthropic SDK content block into a plain dict.

    Dicts are returned unchanged. Objects exposing ``model_dump`` or
    ``to_dict`` (checked in that order) are converted through it. Otherwise
    "text" and "tool_use" blocks are rebuilt field by field, and anything
    else becomes a text block holding ``str(block)``.
    """
    if isinstance(block, dict):
        return block
    for converter_name in ("model_dump", "to_dict"):
        if hasattr(block, converter_name):
            return getattr(block, converter_name)()
    # Fallback: manually construct based on type
    kind = getattr(block, "type", None)
    if kind == "text":
        return {"type": "text", "text": block.text}
    if kind == "tool_use":
        return {"type": "tool_use", "id": block.id, "name": block.name, "input": block.input}
    return {"type": "text", "text": str(block)}
# ------------------------------------------------------------------ gemini
def _ensure_gemini_client() -> None:
    """Create the module-level Gemini client on first use; no-op afterwards."""
    global _gemini_client
    if _gemini_client is not None:
        return
    api_key = _load_credentials()["gemini"]["api_key"]
    _gemini_client = genai.Client(api_key=api_key)
def _get_gemini_history_list(chat: Any | None) -> list[Any]:
if not chat: return []
# google-genai SDK stores the mutable list in _history
if hasattr(chat, "_history"):
return chat._history
if hasattr(chat, "history"):
return chat.history
if hasattr(chat, "get_history"):
return chat.get_history()
return []
def _send_gemini(md_content: str, user_message: str, base_dir: str,
                 file_items: list[dict[str, Any]] | None = None,
                 discussion_history: str = "",
                 pre_tool_callback: Optional[Callable[[str], bool]] = None,
                 qa_callback: Optional[Callable[[str], str]] = None) -> str:
    """Send one user message to Gemini and run the tool-calling loop.

    The chat session is created on first use and rebuilt, carrying the
    existing history over, when ``md_content`` changes or the server-side
    cache nears its TTL. Each round sends the pending payload, collects text
    and function calls, executes the calls and feeds their results back,
    until the model stops calling tools or the round budget is exhausted.

    Args:
        md_content: Context placed in the system instruction.
        user_message: The user's message for this turn.
        base_dir: Working directory passed to PowerShell scripts and MCP config.
        file_items: Tracked files, re-read after the last tool call of a round
            so changed content can be appended to that tool's output.
        discussion_history: Prior discussion text, injected once when a
            brand-new chat (no restored history) is created.
        pre_tool_callback: Optional approval hook; it receives a JSON string
            describing the tool call, and a falsy return rejects the call.
        qa_callback: Passed through to the script confirmation callback.

    Returns:
        All response text joined by blank lines, or "(No text returned)".

    Raises:
        ProviderError: Any failure, classified by _classify_gemini_error.
    """
    global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at
    try:
        _ensure_gemini_client()
        mcp_client.configure(file_items or [], [base_dir])
        # Only stable content (files + screenshots) goes in the cached system instruction.
        # Discussion history is sent as conversation messages so the cache isn't invalidated every turn.
        sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
        td = _gemini_tool_declaration()
        tools_decl = [td] if td else None
        # DYNAMIC CONTEXT: Check if files/context changed mid-session
        current_md_hash = hashlib.md5(md_content.encode()).hexdigest()
        old_history = None
        if _gemini_chat and _gemini_cache_md_hash != current_md_hash:
            old_history = list(_get_gemini_history_list(_gemini_chat))
            if _gemini_cache:
                try:
                    _gemini_client.caches.delete(name=_gemini_cache.name)
                except Exception as e:
                    _append_comms("OUT", "request", {"message": f"[CACHE DELETE WARN] {e}"})
            _gemini_chat = None
            _gemini_cache = None
            _gemini_cache_created_at = None
            _append_comms("OUT", "request", {"message": "[CONTEXT CHANGED] Rebuilding cache and chat session..."})
        # CACHE TTL: Proactively rebuild before the cache expires server-side.
        # If we don't, send_message() will reference a deleted cache and fail.
        if _gemini_chat and _gemini_cache and _gemini_cache_created_at:
            elapsed = time.time() - _gemini_cache_created_at
            if elapsed > _GEMINI_CACHE_TTL * 0.9:
                # Snapshot the history from the chat itself. Applying the
                # helper to its own result (a list) always yields [], which
                # would silently drop the whole conversation on every rebuild.
                old_history = list(_get_gemini_history_list(_gemini_chat))
                try:
                    _gemini_client.caches.delete(name=_gemini_cache.name)
                except Exception as e:
                    _append_comms("OUT", "request", {"message": f"[CACHE DELETE WARN] {e}"})
                _gemini_chat = None
                _gemini_cache = None
                _gemini_cache_created_at = None
                _append_comms("OUT", "request", {"message": f"[CACHE TTL] Rebuilding cache (expired after {int(elapsed)}s)..."})
        if not _gemini_chat:
            chat_config = types.GenerateContentConfig(
                system_instruction=sys_instr,
                tools=tools_decl,
                temperature=_temperature,
                max_output_tokens=_max_tokens,
                safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
            )
            # Check if context is large enough to warrant caching (min 2048 tokens usually)
            should_cache = False
            try:
                count_resp = _gemini_client.models.count_tokens(model=_model, contents=[sys_instr])
                # We use a 2048 threshold to be safe across models
                if count_resp.total_tokens >= 2048:
                    should_cache = True
                else:
                    _append_comms("OUT", "request", {"message": f"[CACHING SKIPPED] Context too small ({count_resp.total_tokens} tokens < 2048)"})
            except Exception as e:
                _append_comms("OUT", "request", {"message": f"[COUNT FAILED] {e}"})
            if should_cache:
                try:
                    # Gemini requires 1024 (Flash) or 4096 (Pro) tokens to cache.
                    _gemini_cache = _gemini_client.caches.create(
                        model=_model,
                        config=types.CreateCachedContentConfig(
                            system_instruction=sys_instr,
                            tools=tools_decl,
                            ttl=f"{_GEMINI_CACHE_TTL}s",
                        )
                    )
                    _gemini_cache_created_at = time.time()
                    # With a cache, the system instruction and tools live in
                    # the cached content, so the config only references it.
                    chat_config = types.GenerateContentConfig(
                        cached_content=_gemini_cache.name,
                        temperature=_temperature,
                        max_output_tokens=_max_tokens,
                        safety_settings=[types.SafetySetting(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_ONLY_HIGH")]
                    )
                    _append_comms("OUT", "request", {"message": f"[CACHE CREATED] {_gemini_cache.name}"})
                except Exception as e:
                    _gemini_cache = None
                    _gemini_cache_created_at = None
                    _append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} \u2014 falling back to inline system_instruction"})
            kwargs = {"model": _model, "config": chat_config}
            if old_history:
                kwargs["history"] = old_history
            _gemini_chat = _gemini_client.chats.create(**kwargs)
            _gemini_cache_md_hash = current_md_hash
            # Inject discussion history as a user message on first chat creation
            # (only when there's no old_history being restored, i.e., fresh session)
            if discussion_history and not old_history:
                _gemini_chat.send_message(f"[DISCUSSION HISTORY]\n\n{discussion_history}")
                _append_comms("OUT", "request", {"message": f"[HISTORY INJECTED] {len(discussion_history)} chars"})
        _append_comms("OUT", "request", {"message": f"[ctx {len(md_content)} + msg {len(user_message)}]"})
        payload: str | list[types.Part] = user_message
        all_text: list[str] = []
        _cumulative_tool_bytes = 0
        # Strip stale file refreshes and truncate old tool outputs ONCE before
        # entering the tool loop (not per-round - history entries don't change).
        if _gemini_chat and _get_gemini_history_list(_gemini_chat):
            for msg in _get_gemini_history_list(_gemini_chat):
                if msg.role == "user" and hasattr(msg, "parts"):
                    for p in msg.parts:
                        if hasattr(p, "function_response") and p.function_response and hasattr(p.function_response, "response"):
                            r = p.function_response.response
                            if isinstance(r, dict) and "output" in r:
                                val = r["output"]
                                if isinstance(val, str):
                                    if "[SYSTEM: FILES UPDATED]" in val:
                                        val = val.split("[SYSTEM: FILES UPDATED]")[0].strip()
                                    if _history_trunc_limit > 0 and len(val) > _history_trunc_limit:
                                        val = val[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS.]"
                                    r["output"] = val
        # Two extra iterations beyond MAX_TOOL_ROUNDS: one carries the
        # "MAX ROUNDS" notice to the model, the last collects its final answer.
        for r_idx in range(MAX_TOOL_ROUNDS + 2):
            events.emit("request_start", payload={"provider": "gemini", "model": _model, "round": r_idx})
            resp = _gemini_chat.send_message(payload)
            txt = "\n".join(p.text for c in resp.candidates if getattr(c, "content", None) for p in c.content.parts if hasattr(p, "text") and p.text)
            if txt:
                all_text.append(txt)
            calls = [p.function_call for c in resp.candidates if getattr(c, "content", None) for p in c.content.parts if hasattr(p, "function_call") and p.function_call]
            usage = {"input_tokens": getattr(resp.usage_metadata, "prompt_token_count", 0), "output_tokens": getattr(resp.usage_metadata, "candidates_token_count", 0)}
            cached_tokens = getattr(resp.usage_metadata, "cached_content_token_count", None)
            if cached_tokens:
                usage["cache_read_input_tokens"] = cached_tokens
            events.emit("response_received", payload={"provider": "gemini", "model": _model, "usage": usage, "round": r_idx})
            reason = resp.candidates[0].finish_reason.name if resp.candidates and hasattr(resp.candidates[0], "finish_reason") else "STOP"
            _append_comms("IN", "response", {"round": r_idx, "stop_reason": reason, "text": txt, "tool_calls": [{"name": c.name, "args": dict(c.args)} for c in calls], "usage": usage})
            # Guard: proactively trim history when input tokens exceed 40% of limit.
            # _GEMINI_MAX_INPUT_TOKENS and _CHARS_PER_TOKEN are not defined in this
            # function; they are expected to be module-level constants.
            total_in = usage.get("input_tokens", 0)
            if total_in > _GEMINI_MAX_INPUT_TOKENS * 0.4 and _gemini_chat and _get_gemini_history_list(_gemini_chat):
                hist = _get_gemini_history_list(_gemini_chat)
                dropped = 0
                # Drop oldest pairs (user+model) but keep at least the last 2 entries
                while len(hist) > 4 and total_in > _GEMINI_MAX_INPUT_TOKENS * 0.3:
                    # Drop in pairs (user + model) to maintain alternating roles required by Gemini
                    saved = 0
                    for _ in range(2):
                        if not hist:
                            break
                        for p in hist[0].parts:
                            if hasattr(p, "text") and p.text:
                                saved += int(len(p.text) / _CHARS_PER_TOKEN)
                            elif hasattr(p, "function_response") and p.function_response:
                                r = getattr(p.function_response, "response", {})
                                if isinstance(r, dict):
                                    saved += int(len(str(r.get("output", ""))) / _CHARS_PER_TOKEN)
                        hist.pop(0)
                        dropped += 1
                    # Assume at least 200 tokens saved per pair so the loop always progresses.
                    total_in -= max(saved, 200)
                if dropped > 0:
                    _append_comms("OUT", "request", {"message": f"[GEMINI HISTORY TRIMMED: dropped {dropped} old entries to stay within token budget]"})
            if not calls or r_idx > MAX_TOOL_ROUNDS:
                break
            f_resps: list[types.Part] = []
            log: list[dict[str, Any]] = []
            for i, fc in enumerate(calls):
                name, args = fc.name, dict(fc.args)
                # Check for tool confirmation if callback is provided
                if pre_tool_callback:
                    payload_str = json.dumps({"tool": name, "args": args})
                    if not pre_tool_callback(payload_str):
                        out = "USER REJECTED: tool execution cancelled"
                        f_resps.append(types.Part.from_function_response(name=name, response={"output": out}))
                        log.append({"tool_use_id": name, "content": out})
                        continue
                events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
                if name in mcp_client.TOOL_NAMES:
                    _append_comms("OUT", "tool_call", {"name": name, "args": args})
                    out = mcp_client.dispatch(name, args)
                elif name == TOOL_NAME:
                    scr = args.get("script", "")
                    _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "script": scr})
                    out = _run_script(scr, base_dir, qa_callback)
                else:
                    out = f"ERROR: unknown tool '{name}'"
                if i == len(calls) - 1:
                    # Only the last result of the round carries the refreshed
                    # file context and the round-limit notice.
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                    if r_idx == MAX_TOOL_ROUNDS:
                        out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                out = _truncate_tool_output(out)
                _cumulative_tool_bytes += len(out)
                f_resps.append(types.Part.from_function_response(name=name, response={"output": out}))
                log.append({"tool_use_id": name, "content": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                # Pass `text` by keyword: this works whether or not the
                # installed SDK declares the parameter keyword-only.
                f_resps.append(types.Part.from_text(
                    text=f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
                ))
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
            _append_comms("OUT", "tool_result_send", {"results": log})
            payload = f_resps
        return "\n\n".join(all_text) if all_text else "(No text returned)"
    except Exception as e:
        raise _classify_gemini_error(e) from e
def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
                     file_items: list[dict[str, Any]] | None = None,
                     discussion_history: str = "",
                     pre_tool_callback: Optional[Callable[[str], bool]] = None,
                     qa_callback: Optional[Callable[[str], str]] = None) -> str:
    """
    Send one user turn through the Gemini CLI adapter and drive the tool loop.

    md_content         : context text embedded in the system instruction
    user_message       : the user's message for this turn
    base_dir           : base directory handed to mcp_client and _run_script
    file_items         : file dicts re-read after the last tool call of a round
                         so changed files can be appended to its output
    discussion_history : prepended to the first message only when the adapter
                         has no session yet
    pre_tool_callback  : if given, called with a JSON description of each tool
                         call; a falsy return rejects the call
    qa_callback        : forwarded to _run_script

    Returns the text of the last round only; text from earlier rounds that
    also contained tool calls is pushed through comms_log_callback instead.
    Raises ProviderError on any failure.
    """
    global _gemini_cli_adapter
    try:
        if _gemini_cli_adapter is None:
            _gemini_cli_adapter = GeminiCliAdapter(binary_path="gemini")
        adapter = _gemini_cli_adapter
        mcp_client.configure(file_items or [], [base_dir])
        # Construct the system instruction, combining the base system prompt and the current context.
        sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
        safety_settings = [{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_ONLY_HIGH'}]
        # Initial payload for the first message
        payload: str | list[dict[str, Any]] = user_message
        if adapter.session_id is None:
            if discussion_history:
                payload = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"
        all_text: list[str] = []
        _cumulative_tool_bytes = 0
        # MAX_TOOL_ROUNDS tool rounds plus extra iterations for the final answer.
        for r_idx in range(MAX_TOOL_ROUNDS + 2):
            events.emit("request_start", payload={"provider": "gemini_cli", "model": _model, "round": r_idx})
            _append_comms("OUT", "request", {"message": f"[CLI] [round {r_idx}] [msg {len(payload)}]"})
            resp_data = adapter.send(payload, safety_settings=safety_settings, system_instruction=sys_instr, model=_model)
            # Log any stderr from the CLI for transparency
            cli_stderr = resp_data.get("stderr", "")
            if cli_stderr:
                sys.stderr.write(f"\n--- Gemini CLI stderr ---\n{cli_stderr}\n-------------------------\n")
                sys.stderr.flush()
            txt = resp_data.get("text", "")
            if txt:
                all_text.append(txt)
            calls = resp_data.get("tool_calls", [])
            usage = adapter.last_usage or {}
            latency = adapter.last_latency
            events.emit("response_received", payload={"provider": "gemini_cli", "model": _model, "usage": usage, "latency": latency, "round": r_idx})
            # Clean up the tool calls format to match comms log expectation
            log_calls = [
                {"name": c.get("name"), "args": c.get("args"), "id": c.get("id")}
                for c in calls
            ]
            _append_comms("IN", "response", {
                "round": r_idx,
                "stop_reason": "TOOL_USE" if calls else "STOP",
                "text": txt,
                "tool_calls": log_calls,
                "usage": usage
            })
            # If there's text and we're not done, push it to the history immediately
            # so it appears as a separate entry in the GUI.
            if txt and calls and comms_log_callback:
                # Use kind='history_add' to push a new entry into the disc_entries list
                comms_log_callback({
                    "ts": project_manager.now_ts(),
                    "direction": "IN",
                    "kind": "history_add",
                    "payload": {
                        "role": "AI",
                        "content": txt
                    }
                })
            if not calls or r_idx > MAX_TOOL_ROUNDS:
                break
            tool_results_for_cli: list[dict[str, Any]] = []
            for i, fc in enumerate(calls):
                name = fc.get("name")
                args = fc.get("args", {})
                call_id = fc.get("id")
                # Check for tool confirmation if callback is provided
                if pre_tool_callback:
                    payload_str = json.dumps({"tool": name, "args": args})
                    if not pre_tool_callback(payload_str):
                        out = "USER REJECTED: tool execution cancelled"
                        tool_results_for_cli.append({
                            "role": "tool",
                            "tool_call_id": call_id,
                            "name": name,
                            "content": out
                        })
                        _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
                        continue
                events.emit("tool_execution", payload={"status": "started", "tool": name, "args": args, "round": r_idx})
                if name in mcp_client.TOOL_NAMES:
                    _append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
                    out = mcp_client.dispatch(name, args)
                elif name == TOOL_NAME:
                    scr = args.get("script", "")
                    _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
                    out = _run_script(scr, base_dir, qa_callback)
                else:
                    out = f"ERROR: unknown tool '{name}'"
                # Only the last tool output of a round carries the file refresh
                # and the max-rounds notice.
                if i == len(calls) - 1:
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            out += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                    if r_idx == MAX_TOOL_ROUNDS:
                        out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                out = _truncate_tool_output(out)
                _cumulative_tool_bytes += len(out)
                tool_results_for_cli.append({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "name": name,
                    "content": out
                })
                _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
            # CRITICAL: Update payload for the next round
            payload = tool_results_for_cli
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
                # Deliver the warning to the model as well: append it to the last
                # tool result so it travels in the existing payload shape.
                if tool_results_for_cli:
                    tool_results_for_cli[-1]["content"] += (
                        f"\n\nSYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
                    )
        # Return only the text from the last round, because intermediate
        # text chunks were already pushed to history via comms_log_callback.
        final_text = all_text[-1] if all_text else "(No text returned)"
        return final_text
    except ProviderError:
        # Already classified; do not re-wrap as "unknown".
        raise
    except Exception as e:
        # Basic error classification for CLI
        raise ProviderError("unknown", "gemini_cli", e) from e
# ------------------------------------------------------------------ anthropic history management
# Rough chars-per-token ratio. Anthropic tokeniser averages ~3.5-4 chars/token.
# We use 3.5 to be conservative (overestimate token count = safer).
# Used as the divisor when converting character counts into token estimates.
_CHARS_PER_TOKEN = 3.5
# Maximum token budget for the entire prompt (system + tools + messages).
# Anthropic's limit is 200k. We leave headroom for the response + tool schemas.
_ANTHROPIC_MAX_PROMPT_TOKENS = 180_000
# Gemini models have a 1M context window but we cap well below to leave headroom.
# If the model reports input tokens exceeding this, we trim old history.
_GEMINI_MAX_INPUT_TOKENS = 900_000
# Marker prefix used to identify stale file-refresh injections in history.
# Deliberately has no closing bracket: it is compared with str.startswith().
_FILE_REFRESH_MARKER = "[FILES UPDATED"
def _estimate_message_tokens(msg: dict) -> int:
    """
    Approximate the token count of one Anthropic message dict.

    The result is memoised on the message under '_est_tokens', so repeated
    passes over unchanged messages are cheap. After editing a message's
    content, call _invalidate_token_estimate() to force a recount.
    """
    memo = msg.get("_est_tokens")
    if memo is not None:
        return memo
    body = msg.get("content", "")
    char_count = 0
    if isinstance(body, str):
        char_count = len(body)
    elif isinstance(body, list):
        for part in body:
            if isinstance(part, str):
                char_count += len(part)
                continue
            if not isinstance(part, dict):
                continue
            # Text blocks keep their payload under "text", tool results under "content".
            payload = part.get("text", "") or part.get("content", "")
            if isinstance(payload, str):
                char_count += len(payload)
            # tool_use blocks: count the serialised input arguments.
            tool_input = part.get("input")
            if isinstance(tool_input, dict):
                char_count += len(json.dumps(tool_input, ensure_ascii=False))
    tokens = max(1, int(char_count / _CHARS_PER_TOKEN))
    msg["_est_tokens"] = tokens
    return tokens
def _invalidate_token_estimate(msg: dict[str, Any]) -> None:
    """Drop the memoised '_est_tokens' entry, if present, so it is recomputed on next use."""
    if "_est_tokens" in msg:
        del msg["_est_tokens"]
def _estimate_prompt_tokens(system_blocks: list[dict], history: list[dict]) -> int:
    """
    Estimate the size of a whole prompt in tokens.

    Sums three parts: every system block's text (at least 1 token each), a
    fixed 2500-token allowance for the tool definitions, and the per-message
    estimates for the history (memoised, so unchanged messages are cheap).
    """
    system_tokens = sum(
        max(1, int(len(blk.get("text", "")) / _CHARS_PER_TOKEN))
        for blk in system_blocks
    )
    # Tool definitions (rough fixed estimate — they're ~2k tokens for our set)
    tool_schema_tokens = 2500
    history_tokens = sum(_estimate_message_tokens(entry) for entry in history)
    return system_tokens + tool_schema_tokens + history_tokens
def _strip_stale_file_refreshes(history: list[dict[str, Any]]) -> None:
    """
    Drop old file-refresh text blocks from the history, in place.

    Every user message except the most recent one has its text blocks that
    start with _FILE_REFRESH_MARKER removed. Those blocks are snapshots from
    earlier tool rounds and only inflate the context. Messages whose content
    is not a list are left alone.
    """
    if len(history) < 2:
        return
    user_positions = [pos for pos, entry in enumerate(history) if entry.get("role") == "user"]
    # The newest user message keeps its refresh block; -1 means "no user message".
    newest_user = user_positions[-1] if user_positions else -1
    for pos in user_positions:
        if pos == newest_user:
            continue
        entry = history[pos]
        blocks = entry.get("content")
        if not isinstance(blocks, list):
            continue
        kept = [
            blk for blk in blocks
            if not (
                isinstance(blk, dict)
                and blk.get("type") == "text"
                and blk.get("text", "").startswith(_FILE_REFRESH_MARKER)
            )
        ]
        if len(kept) < len(blocks):
            entry["content"] = kept
            _invalidate_token_estimate(entry)
def _trim_anthropic_history(system_blocks: list[dict[str, Any]], history: list[dict[str, Any]]) -> int:
    """
    Trim the Anthropic history in place to fit within the token budget.

    Strategy:
    1. Strip stale file-refresh injections from old turns.
    2. If still over _ANTHROPIC_MAX_PROMPT_TOKENS, drop the oldest messages
       after history[0], preferably as (assistant, user) pairs.

    system_blocks : system prompt blocks, counted towards the estimate
    history       : message list, mutated in place

    Returns the number of messages dropped (0 if phase 1 was enough).
    """
    # Phase 1: strip stale file refreshes
    _strip_stale_file_refreshes(history)
    est = _estimate_prompt_tokens(system_blocks, history)
    if est <= _ANTHROPIC_MAX_PROMPT_TOKENS:
        return 0
    # Phase 2: drop oldest turn pairs until within budget
    dropped = 0
    while len(history) > 3 and est > _ANTHROPIC_MAX_PROMPT_TOKENS:
        # Protect history[0] (original user prompt). Drop from history[1] (assistant) and history[2] (user)
        if history[1].get("role") == "assistant" and len(history) > 2 and history[2].get("role") == "user":
            # pop(1) twice: after the first pop the user message has shifted into index 1.
            removed_asst = history.pop(1)
            removed_user = history.pop(1)
            dropped += 2
            # Keep the running estimate in sync instead of re-scanning the whole history.
            est -= _estimate_message_tokens(removed_asst)
            est -= _estimate_message_tokens(removed_user)
            # Also drop the following (assistant, user) pairs while that user message
            # starts with a tool_result block. This inner loop does not check the budget.
            while len(history) > 2 and history[1].get("role") == "assistant" and history[2].get("role") == "user":
                content = history[2].get("content", [])
                if isinstance(content, list) and content and isinstance(content[0], dict) and content[0].get("type") == "tool_result":
                    r_a = history.pop(1)
                    r_u = history.pop(1)
                    dropped += 2
                    est -= _estimate_message_tokens(r_a)
                    est -= _estimate_message_tokens(r_u)
                else:
                    break
        else:
            # Edge case fallback: drop index 1 (protecting index 0)
            removed = history.pop(1)
            dropped += 1
            est -= _estimate_message_tokens(removed)
    return dropped
# ------------------------------------------------------------------ anthropic
def _ensure_anthropic_client() -> None:
    """Create the module-wide Anthropic client on first use; later calls are no-ops."""
    global _anthropic_client
    if _anthropic_client is not None:
        return
    api_key = _load_credentials()["anthropic"]["api_key"]
    # Enable prompt caching beta
    beta_headers = {"anthropic-beta": "prompt-caching-2024-07-31"}
    _anthropic_client = anthropic.Anthropic(api_key=api_key, default_headers=beta_headers)
def _chunk_text(text: str, chunk_size: int) -> list[str]:
    """Split text into consecutive pieces of at most chunk_size characters."""
    pieces: list[str] = []
    for start in range(0, len(text), chunk_size):
        pieces.append(text[start:start + chunk_size])
    return pieces
def _build_chunked_context_blocks(md_content: str) -> list[dict]:
    """
    Turn md_content into a list of text blocks of at most
    _ANTHROPIC_CHUNK_SIZE characters each.

    Only the final block gets cache_control:ephemeral, so the whole prefix is
    cached as a single unit. Returns an empty list for empty input.
    """
    blocks: list[dict] = [
        {"type": "text", "text": piece}
        for piece in _chunk_text(md_content, _ANTHROPIC_CHUNK_SIZE)
    ]
    if blocks:
        blocks[-1]["cache_control"] = {"type": "ephemeral"}
    return blocks
def _strip_cache_controls(history: list[dict[str, Any]]) -> None:
    """
    Delete the 'cache_control' key from every dict content block, in place.

    Anthropic permits at most 4 cache_control blocks across system + tools +
    messages. Those slots are kept for the stable system/tools prefix and the
    current turn's context block, so older history entries must carry none.
    Messages whose content is not a list are skipped.
    """
    for entry in history:
        blocks = entry.get("content")
        if not isinstance(blocks, list):
            continue
        for blk in blocks:
            if isinstance(blk, dict) and "cache_control" in blk:
                del blk["cache_control"]
def _add_history_cache_breakpoint(history: list[dict[str, Any]]) -> None:
    """
    Mark the second-to-last user message as a cache breakpoint, in place.

    cache_control:ephemeral goes on that message's last content block, which
    spends one of Anthropic's 4 allowed breakpoints on the conversation
    prefix so it is not reprocessed on every request. String content is
    converted into a single text block carrying the marker. Does nothing when
    there are fewer than two user messages.
    """
    user_msgs = [entry for entry in history if entry.get("role") == "user"]
    if len(user_msgs) < 2:
        # Only the current turn exists: there is no stable prefix to cache yet.
        return
    target = user_msgs[-2]
    body = target.get("content")
    if isinstance(body, str):
        target["content"] = [
            {"type": "text", "text": body, "cache_control": {"type": "ephemeral"}}
        ]
        return
    if isinstance(body, list) and body:
        tail = body[-1]
        if isinstance(tail, dict):
            tail["cache_control"] = {"type": "ephemeral"}
def _repair_anthropic_history(history: list[dict[str, Any]]) -> None:
    """
    Close any unanswered tool calls at the end of the history, in place.

    When the last message is an assistant message containing tool_use blocks,
    a synthetic user message is appended with one tool_result per tool_use id,
    so the history is valid before the next request. Otherwise nothing changes.
    """
    if not history or history[-1].get("role") != "assistant":
        return
    pending_ids = [
        blk["id"]
        for blk in history[-1].get("content", [])
        if isinstance(blk, dict) and blk.get("type") == "tool_use"
    ]
    if not pending_ids:
        return
    placeholder_results = []
    for pending_id in pending_ids:
        placeholder_results.append({
            "type": "tool_result",
            "tool_use_id": pending_id,
            "content": "Tool call was not completed (session interrupted).",
        })
    history.append({"role": "user", "content": placeholder_results})
def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_items: list[dict[str, Any]] | None = None, discussion_history: str = "", pre_tool_callback: Optional[Callable[[str], bool]] = None, qa_callback: Optional[Callable[[str], str]] = None) -> str:
    """
    Send one user turn to Anthropic and drive the tool loop to completion.

    md_content         : context text placed in the system blocks
    user_message       : the user's message for this turn
    base_dir           : base directory handed to mcp_client and _run_script
    file_items         : file dicts re-read after each tool round so changed
                         files can be injected as a [FILES UPDATED ...] block
    discussion_history : prepended to the user message only when the module
                         history is empty (fresh session)
    pre_tool_callback  : if given, called with a JSON description of each tool
                         call; a falsy return rejects the call
    qa_callback        : forwarded to _run_script

    Returns the text of all rounds joined by blank lines, or a placeholder
    when the model produced no text. ProviderError passes through unchanged;
    any other exception is converted with _classify_anthropic_error.
    """
    def _strip_private_keys(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # Keys starting with "_" (e.g. the cached token estimate) are local
        # bookkeeping and must not be sent to the API.
        return [{k: v for k, v in m.items() if not k.startswith("_")} for m in history]

    try:
        _ensure_anthropic_client()
        mcp_client.configure(file_items or [], [base_dir])
        # Split system into two cache breakpoints:
        # 1. Stable system prompt (never changes — always a cache hit)
        # 2. Dynamic file context (invalidated only when files change)
        stable_prompt = _get_combined_system_prompt()
        stable_blocks = [{"type": "text", "text": stable_prompt, "cache_control": {"type": "ephemeral"}}]
        context_text = f"\n\n<context>\n{md_content}\n</context>"
        context_blocks = _build_chunked_context_blocks(context_text)
        system_blocks = stable_blocks + context_blocks
        # Prepend discussion history to the first user message if this is a fresh session
        if discussion_history and not _anthropic_history:
            user_content: list[dict[str, Any]] = [{"type": "text", "text": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"}]
        else:
            user_content = [{"type": "text", "text": user_message}]
        # COMPRESS HISTORY: Truncate massive tool outputs from previous turns
        for msg in _anthropic_history:
            if msg.get("role") == "user" and isinstance(msg.get("content"), list):
                modified = False
                for block in msg["content"]:
                    if isinstance(block, dict) and block.get("type") == "tool_result":
                        t_content = block.get("content", "")
                        if _history_trunc_limit > 0 and isinstance(t_content, str) and len(t_content) > _history_trunc_limit:
                            block["content"] = t_content[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS. Original output was too large.]"
                            modified = True
                if modified:
                    _invalidate_token_estimate(msg)
        _strip_cache_controls(_anthropic_history)
        _repair_anthropic_history(_anthropic_history)
        _anthropic_history.append({"role": "user", "content": user_content})
        # Use the 4th cache breakpoint to cache the conversation history prefix.
        # This is placed on the second-to-last user message (the last stable one).
        _add_history_cache_breakpoint(_anthropic_history)
        n_chunks = len(system_blocks)
        _append_comms("OUT", "request", {
            "message": (
                f"[system {n_chunks} chunk(s), {len(md_content)} chars context] "
                f"{user_message[:200]}{'...' if len(user_message) > 200 else ''}"
            ),
        })
        all_text_parts: list[str] = []
        _cumulative_tool_bytes = 0
        # We allow MAX_TOOL_ROUNDS, plus 1 final loop to get the text synthesis
        for round_idx in range(MAX_TOOL_ROUNDS + 2):
            # Trim history to fit within token budget before each API call
            dropped = _trim_anthropic_history(system_blocks, _anthropic_history)
            if dropped > 0:
                est_tokens = _estimate_prompt_tokens(system_blocks, _anthropic_history)
                _append_comms("OUT", "request", {
                    "message": (
                        f"[HISTORY TRIMMED: dropped {dropped} old messages to fit token budget. "
                        f"Estimated {est_tokens} tokens remaining. {len(_anthropic_history)} messages in history.]"
                    ),
                })
            events.emit("request_start", payload={"provider": "anthropic", "model": _model, "round": round_idx})
            response = _anthropic_client.messages.create(
                model=_model,
                max_tokens=_max_tokens,
                temperature=_temperature,
                system=system_blocks,
                tools=_get_anthropic_tools(),
                messages=_strip_private_keys(_anthropic_history),
            )
            # Convert SDK content block objects to plain dicts before storing in history
            serialised_content = [_content_block_to_dict(b) for b in response.content]
            _anthropic_history.append({
                "role": "assistant",
                "content": serialised_content,
            })
            text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text]
            if text_blocks:
                all_text_parts.append("\n".join(text_blocks))
            tool_use_blocks = [
                {"id": b.id, "name": b.name, "input": b.input}
                for b in response.content
                if getattr(b, "type", None) == "tool_use"
            ]
            usage_dict: dict[str, Any] = {}
            if response.usage:
                usage_dict["input_tokens"] = response.usage.input_tokens
                usage_dict["output_tokens"] = response.usage.output_tokens
                cache_creation = getattr(response.usage, "cache_creation_input_tokens", None)
                cache_read = getattr(response.usage, "cache_read_input_tokens", None)
                if cache_creation is not None:
                    usage_dict["cache_creation_input_tokens"] = cache_creation
                if cache_read is not None:
                    usage_dict["cache_read_input_tokens"] = cache_read
            events.emit("response_received", payload={"provider": "anthropic", "model": _model, "usage": usage_dict, "round": round_idx})
            _append_comms("IN", "response", {
                "round": round_idx,
                "stop_reason": response.stop_reason,
                "text": "\n".join(text_blocks),
                "tool_calls": tool_use_blocks,
                "usage": usage_dict,
            })
            if response.stop_reason != "tool_use" or not tool_use_blocks:
                break
            if round_idx > MAX_TOOL_ROUNDS:
                # The model ignored the MAX ROUNDS warning and kept calling tools.
                # Force abort to prevent infinite loop.
                break
            tool_results: list[dict[str, Any]] = []
            for block in response.content:
                if getattr(block, "type", None) != "tool_use":
                    continue
                b_name = getattr(block, "name", None)
                b_id = getattr(block, "id", "")
                b_input = getattr(block, "input", {})
                # Check for tool confirmation if callback is provided
                if pre_tool_callback:
                    payload_str = json.dumps({"tool": b_name, "args": b_input})
                    if not pre_tool_callback(payload_str):
                        output = "USER REJECTED: tool execution cancelled"
                        tool_results.append({
                            "type": "tool_result",
                            "tool_use_id": b_id,
                            "content": output,
                        })
                        continue
                events.emit("tool_execution", payload={"status": "started", "tool": b_name, "args": b_input, "round": round_idx})
                if b_name in mcp_client.TOOL_NAMES:
                    _append_comms("OUT", "tool_call", {"name": b_name, "id": b_id, "args": b_input})
                    output = mcp_client.dispatch(b_name, b_input)
                elif b_name == TOOL_NAME:
                    script = b_input.get("script", "")
                    _append_comms("OUT", "tool_call", {
                        "name": TOOL_NAME,
                        "id": b_id,
                        "script": script,
                    })
                    output = _run_script(script, base_dir, qa_callback)
                else:
                    # Every tool_use needs a matching tool_result in the next user
                    # message, so an unknown tool still gets an (error) result.
                    output = f"ERROR: unknown tool '{b_name}'"
                _append_comms("IN", "tool_result", {"name": b_name, "id": b_id, "output": output})
                truncated = _truncate_tool_output(output)
                _cumulative_tool_bytes += len(truncated)
                tool_results.append({
                    "type": "tool_result",
                    "tool_use_id": b_id,
                    "content": truncated,
                })
                events.emit("tool_execution", payload={"status": "completed", "tool": b_name, "result": output, "round": round_idx})
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                tool_results.append({
                    "type": "text",
                    "text": f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
                })
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
            # Refresh file context after tool calls — only inject CHANGED files
            if file_items:
                file_items, changed = _reread_file_items(file_items)
                refreshed_ctx = _build_file_diff_text(changed)
                if refreshed_ctx:
                    tool_results.append({
                        "type": "text",
                        "text": (
                            "[FILES UPDATED — current contents below. "
                            "Do NOT re-read these files with PowerShell.]\n\n"
                            + refreshed_ctx
                        ),
                    })
            if round_idx == MAX_TOOL_ROUNDS:
                tool_results.append({
                    "type": "text",
                    "text": "SYSTEM WARNING: MAX TOOL ROUNDS REACHED. YOU MUST PROVIDE YOUR FINAL ANSWER NOW WITHOUT CALLING ANY MORE TOOLS."
                })
            _anthropic_history.append({
                "role": "user",
                "content": tool_results,
            })
            _append_comms("OUT", "tool_result_send", {
                "results": [
                    {"tool_use_id": r["tool_use_id"], "content": r["content"]}
                    for r in tool_results if r.get("type") == "tool_result"
                ],
            })
        final_text = "\n\n".join(all_text_parts)
        return final_text if final_text.strip() else "(No text returned by the model)"
    except ProviderError:
        raise
    except Exception as exc:
        raise _classify_anthropic_error(exc) from exc
# ------------------------------------------------------------------ deepseek
def _ensure_deepseek_client() -> None:
    """
    Placeholder for lazy DeepSeek SDK client creation.

    Currently only loads credentials when no client exists; nothing is
    assigned to _deepseek_client yet.
    """
    global _deepseek_client
    if _deepseek_client is not None:
        return
    _load_credentials()
    # Placeholder for Dedicated DeepSeek SDK instantiation
    # import deepseek
    # _deepseek_client = deepseek.DeepSeek(api_key=creds["deepseek"]["api_key"])
def _send_deepseek(md_content: str, user_message: str, base_dir: str,
                   file_items: list[dict[str, Any]] | None = None,
                   discussion_history: str = "",
                   stream: bool = False,
                   pre_tool_callback: Optional[Callable[[str], bool]] = None,
                   qa_callback: Optional[Callable[[str], str]] = None) -> str:
    """
    Send one user turn to the DeepSeek chat-completions API and drive the tool loop.

    md_content         : context text embedded in the system message
    user_message       : the user's message for this turn
    base_dir           : base directory handed to mcp_client and _run_script
    file_items         : file dicts re-read after the last tool call of a round
                         so changed files can be appended to its output
    discussion_history : prepended to the user message only when the module
                         history is empty (fresh session)
    stream             : request a streamed response and aggregate its deltas
    pre_tool_callback  : if given, called with a JSON description of each tool
                         call; a falsy return rejects the call
    qa_callback        : forwarded to _run_script

    The user message, assistant replies and tool results are all recorded in
    _deepseek_history, and each request is rebuilt from that history with the
    system message in front.

    Returns the text of all rounds joined by blank lines (reasoning wrapped in
    <thinking> tags), or a placeholder when no text was produced. Failures are
    raised as whatever _classify_deepseek_error returns; an existing
    ProviderError passes through unchanged.
    """
    try:
        mcp_client.configure(file_items or [], [base_dir])
        creds = _load_credentials()
        api_key = creds.get("deepseek", {}).get("api_key")
        if not api_key:
            raise ValueError("DeepSeek API key not found in credentials.toml")
        # DeepSeek API details
        api_url = "https://api.deepseek.com/chat/completions"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        # The system prompt is never stored in history; it is put in front of every request.
        sys_msg = {"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}
        with _deepseek_history_lock:
            # Discussion history is injected once per session; later turns
            # already have the earlier messages in _deepseek_history.
            if discussion_history and not _deepseek_history:
                initial_user_message_content = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"
            else:
                initial_user_message_content = user_message
            # Record the user turn so it is still present in follow-up tool
            # rounds and in later turns.
            _deepseek_history.append({"role": "user", "content": initial_user_message_content})
            current_api_messages: list[dict[str, Any]] = [sys_msg, *_deepseek_history]
        # Construct the full request payload
        request_payload: dict[str, Any] = {
            "model": _model,
            "messages": current_api_messages,
            "temperature": _temperature,
            "max_tokens": _max_tokens,
            "stream": stream,
        }
        all_text_parts: list[str] = []
        _cumulative_tool_bytes = 0
        # MAX_TOOL_ROUNDS tool rounds plus extra iterations for the final answer.
        for round_idx in range(MAX_TOOL_ROUNDS + 2):
            events.emit("request_start", payload={"provider": "deepseek", "model": _model, "round": round_idx, "streaming": stream})
            try:
                response = requests.post(api_url, headers=headers, json=request_payload, timeout=60, stream=stream)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                raise _classify_deepseek_error(e) from e
            # Process response
            if stream:
                aggregated_content = ""
                aggregated_tool_calls: list[dict[str, Any]] = []
                aggregated_reasoning = ""
                current_usage: dict[str, Any] = {}
                final_finish_reason = "stop"
                for line in response.iter_lines():
                    if not line:
                        continue
                    decoded = line.decode('utf-8')
                    if not decoded.startswith('data: '):
                        continue
                    chunk_str = decoded[len('data: '):]
                    if chunk_str.strip() == '[DONE]':
                        continue
                    try:
                        chunk = json.loads(chunk_str)
                    except json.JSONDecodeError:
                        continue
                    # A chunk may carry an empty or missing "choices" list; fall
                    # back to an empty choice instead of raising IndexError.
                    first_choice = (chunk.get("choices") or [{}])[0]
                    delta = first_choice.get("delta") or {}
                    if delta.get("content"):
                        aggregated_content += delta["content"]
                    if delta.get("reasoning_content"):
                        aggregated_reasoning += delta["reasoning_content"]
                    if delta.get("tool_calls"):
                        # Simple aggregation of tool call deltas
                        for tc_delta in delta["tool_calls"]:
                            idx = tc_delta.get("index", 0)
                            while len(aggregated_tool_calls) <= idx:
                                aggregated_tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
                            target = aggregated_tool_calls[idx]
                            if tc_delta.get("id"):
                                target["id"] = tc_delta["id"]
                            if tc_delta.get("function", {}).get("name"):
                                target["function"]["name"] += tc_delta["function"]["name"]
                            if tc_delta.get("function", {}).get("arguments"):
                                target["function"]["arguments"] += tc_delta["function"]["arguments"]
                    if first_choice.get("finish_reason"):
                        final_finish_reason = first_choice["finish_reason"]
                    if chunk.get("usage"):
                        current_usage = chunk["usage"]
                assistant_text = aggregated_content
                tool_calls_raw = aggregated_tool_calls
                reasoning_content = aggregated_reasoning
                finish_reason = final_finish_reason
                usage = current_usage
            else:
                response_data = response.json()
                choices = response_data.get("choices", [])
                if not choices:
                    _append_comms("IN", "response", {"round": round_idx, "text": "(No choices returned)", "usage": response_data.get("usage", {})})
                    break
                choice = choices[0]
                message = choice.get("message", {})
                # Keys may be present with a null value, so normalise with
                # `or` rather than relying on the .get() default.
                assistant_text = message.get("content") or ""
                tool_calls_raw = message.get("tool_calls") or []
                reasoning_content = message.get("reasoning_content") or ""
                finish_reason = choice.get("finish_reason", "stop")
                usage = response_data.get("usage", {})
            # Format reasoning content if it exists
            thinking_tags = ""
            if reasoning_content:
                thinking_tags = f"<thinking>\n{reasoning_content}\n</thinking>\n"
            full_assistant_text = thinking_tags + assistant_text
            # Update history
            with _deepseek_history_lock:
                msg_to_store: dict[str, Any] = {"role": "assistant", "content": assistant_text}
                if reasoning_content:
                    msg_to_store["reasoning_content"] = reasoning_content
                if tool_calls_raw:
                    msg_to_store["tool_calls"] = tool_calls_raw
                _deepseek_history.append(msg_to_store)
            if full_assistant_text:
                all_text_parts.append(full_assistant_text)
            _append_comms("IN", "response", {
                "round": round_idx,
                "stop_reason": finish_reason,
                "text": full_assistant_text,
                "tool_calls": tool_calls_raw,
                "usage": usage,
                "streaming": stream
            })
            if finish_reason != "tool_calls" and not tool_calls_raw:
                break
            if round_idx > MAX_TOOL_ROUNDS:
                break
            tool_results_for_history: list[dict[str, Any]] = []
            for i, tc_raw in enumerate(tool_calls_raw):
                tool_info = tc_raw.get("function", {})
                tool_name = tool_info.get("name")
                tool_args_str = tool_info.get("arguments", "{}")
                tool_id = tc_raw.get("id")
                try:
                    tool_args = json.loads(tool_args_str)
                except (json.JSONDecodeError, TypeError):
                    # Malformed or non-string arguments: run the tool with no arguments.
                    tool_args = {}
                if not isinstance(tool_args, dict):
                    # Valid JSON but not an object (e.g. "null"); .get() below needs a dict.
                    tool_args = {}
                # Check for tool confirmation if callback is provided
                if pre_tool_callback:
                    payload_str = json.dumps({"tool": tool_name, "args": tool_args})
                    if not pre_tool_callback(payload_str):
                        tool_output = "USER REJECTED: tool execution cancelled"
                        tool_results_for_history.append({
                            "role": "tool",
                            "tool_call_id": tool_id,
                            "content": tool_output,
                        })
                        _append_comms("IN", "tool_result", {"name": tool_name, "id": tool_id, "output": tool_output})
                        continue
                events.emit("tool_execution", payload={"status": "started", "tool": tool_name, "args": tool_args, "round": round_idx})
                if tool_name in mcp_client.TOOL_NAMES:
                    _append_comms("OUT", "tool_call", {"name": tool_name, "id": tool_id, "args": tool_args})
                    tool_output = mcp_client.dispatch(tool_name, tool_args)
                elif tool_name == TOOL_NAME:
                    script = tool_args.get("script", "")
                    _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": tool_id, "script": script})
                    tool_output = _run_script(script, base_dir, qa_callback)
                else:
                    tool_output = f"ERROR: unknown tool '{tool_name}'"
                # Only the last tool output of a round carries the file refresh
                # and the max-rounds notice.
                if i == len(tool_calls_raw) - 1:
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            tool_output += f"\n\n[SYSTEM: FILES UPDATED]\n\n{ctx}"
                    if round_idx == MAX_TOOL_ROUNDS:
                        tool_output += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                tool_output = _truncate_tool_output(tool_output)
                _cumulative_tool_bytes += len(tool_output)
                tool_results_for_history.append({
                    "role": "tool",
                    "tool_call_id": tool_id,
                    "content": tool_output,
                })
                _append_comms("IN", "tool_result", {"name": tool_name, "id": tool_id, "output": tool_output})
                events.emit("tool_execution", payload={"status": "completed", "tool": tool_name, "result": tool_output, "round": round_idx})
            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                tool_results_for_history.append({
                    "role": "user",
                    "content": f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
                })
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
            with _deepseek_history_lock:
                _deepseek_history.extend(tool_results_for_history)
                # Update for next round: system message followed by the full history.
                request_payload["messages"] = [sys_msg, *_deepseek_history]
        return "\n\n".join(all_text_parts) if all_text_parts else "(No text returned)"
    except ProviderError:
        # Already classified (e.g. by the request handler above); do not classify twice.
        raise
    except Exception as e:
        raise _classify_deepseek_error(e) from e
def run_tier4_analysis(stderr: str) -> str:
    """
    Stateless Tier 4 QA analysis of an error message.

    Uses gemini-2.5-flash-lite to summarize the error and suggest a fix.

    stderr : stderr output to analyse.

    Returns "" for empty or whitespace-only input, the model's stripped text
    on success, or a string starting with "[QA ANALYSIS FAILED]" on any
    failure. Never raises, so a failed analysis cannot crash the caller.
    """
    if not stderr or not stderr.strip():
        return ""
    try:
        _ensure_gemini_client()
        prompt = (
            "You are a Tier 4 QA Agent specializing in error analysis.\n"
            "Analyze the following stderr output from a PowerShell command:\n\n"
            f"```\n{stderr}\n```\n\n"
            "Provide a concise summary of the failure and suggest a fix in approximately 20 words."
        )
        # Use flash-lite for cost-effective stateless analysis
        model_name = "gemini-2.5-flash-lite"
        # We don't use the chat session here to keep it stateless
        resp = _gemini_client.models.generate_content(
            model=model_name,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=0.0,
                max_output_tokens=150,
            )
        )
        text = resp.text
        if text is None:
            # No text part in the response: report that explicitly instead of
            # surfacing an opaque AttributeError from None.strip().
            return "[QA ANALYSIS FAILED] model returned no text"
        return text.strip()
    except Exception as e:
        # We don't want to crash the main loop if QA analysis fails
        return f"[QA ANALYSIS FAILED] {e}"
# ------------------------------------------------------------------ unified send
def send(
    md_content: str,
    user_message: str,
    base_dir: str = ".",
    file_items: list[dict[str, Any]] | None = None,
    discussion_history: str = "",
    stream: bool = False,
    pre_tool_callback: Optional[Callable[[str], bool]] = None,
    qa_callback: Optional[Callable[[str], str]] = None,
) -> str:
    """
    Dispatch a message to whichever provider is currently active.

    md_content         : aggregated markdown string (for Gemini: stable content only,
                         for Anthropic: full content including history)
    user_message       : the user question / instruction
    base_dir           : project base directory (for PowerShell tool calls)
    file_items         : list of file dicts from aggregate.build_file_items() for
                         dynamic context refresh after tool calls
    discussion_history : discussion history text (used by Gemini to inject as
                         conversation message instead of caching it)
    stream             : whether to use streaming (only forwarded to DeepSeek)
    pre_tool_callback  : optional callback (payload: str) -> bool called before tool execution
    qa_callback        : optional callback (stderr: str) -> str called for Tier 4 error analysis

    The whole call runs under _send_lock. Raises ValueError when the active
    provider is not one of gemini, gemini_cli, anthropic or deepseek.
    """
    # Positional arguments shared by every provider backend.
    common = (md_content, user_message, base_dir, file_items, discussion_history)
    with _send_lock:
        if _provider == "gemini":
            return _send_gemini(*common, pre_tool_callback, qa_callback)
        if _provider == "gemini_cli":
            return _send_gemini_cli(*common, pre_tool_callback, qa_callback)
        if _provider == "anthropic":
            return _send_anthropic(*common, pre_tool_callback, qa_callback)
        if _provider == "deepseek":
            return _send_deepseek(
                *common,
                stream=stream,
                pre_tool_callback=pre_tool_callback,
                qa_callback=qa_callback,
            )
        raise ValueError(f"unknown provider: {_provider}")
def get_history_bleed_stats(md_content: str | None = None) -> dict[str, Any]:
"""
Calculates how close the current conversation history is to the token limit.
If md_content is provided and no chat session exists, it estimates based on md_content.
"""
if _provider == "anthropic":
# For Anthropic, we have a robust estimator
with _anthropic_history_lock:
history_snapshot = list(_anthropic_history)
current_tokens = _estimate_prompt_tokens([], history_snapshot)
if md_content:
current_tokens += max(1, int(len(md_content) / _CHARS_PER_TOKEN))
limit_tokens = _ANTHROPIC_MAX_PROMPT_TOKENS
percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
return {
"provider": "anthropic",
"limit": limit_tokens,
"current": current_tokens,
"percentage": percentage,
}
elif _provider == "gemini":
effective_limit = _history_trunc_limit if _history_trunc_limit > 0 else _GEMINI_MAX_INPUT_TOKENS
if _gemini_chat:
try:
_ensure_gemini_client()
raw_history = list(_get_gemini_history_list(_gemini_chat))
# Copy and correct roles for counting
history = []
for c in raw_history:
# Gemini roles MUST be 'user' or 'model'
role = "model" if c.role in ["assistant", "model"] else "user"
history.append(types.Content(role=role, parts=c.parts))
if md_content:
# Prepend context as a user part for counting
history.insert(0, types.Content(role="user", parts=[types.Part.from_text(text=md_content)]))
if not history:
print("[DEBUG] Gemini count_tokens skipped: no history or md_content")
return {
"provider": "gemini",
"limit": effective_limit,
"current": 0,
"percentage": 0,
}
print(f"[DEBUG] Gemini count_tokens on {len(history)} messages using model {_model}")
resp = _gemini_client.models.count_tokens(
model=_model,
contents=history
)
current_tokens = resp.total_tokens
percentage = (current_tokens / effective_limit) * 100 if effective_limit > 0 else 0
print(f"[DEBUG] Gemini current_tokens={current_tokens}, percentage={percentage:.4f}%")
return {
"provider": "gemini",
"limit": effective_limit,
"current": current_tokens,
"percentage": percentage,
}
except Exception as e:
print(f"[DEBUG] Gemini count_tokens error: {e}")
pass
elif md_content:
try:
_ensure_gemini_client()
print(f"[DEBUG] Gemini count_tokens (MD ONLY) using model {_model}")
resp = _gemini_client.models.count_tokens(
model=_model,
contents=[types.Content(role="user", parts=[types.Part.from_text(text=md_content)])]
)
current_tokens = resp.total_tokens
percentage = (current_tokens / effective_limit) * 100 if effective_limit > 0 else 0
print(f"[DEBUG] Gemini (MD ONLY) current_tokens={current_tokens}, percentage={percentage:.4f}%")
return {
"provider": "gemini",
"limit": effective_limit,
"current": current_tokens,
"percentage": percentage,
}
except Exception as e:
print(f"[DEBUG] Gemini count_tokens (MD ONLY) error: {e}")
pass
return {
"provider": "gemini",
"limit": effective_limit,
"current": 0,
"percentage": 0,
}
elif _provider == "gemini_cli":
effective_limit = _history_trunc_limit if _history_trunc_limit > 0 else _GEMINI_MAX_INPUT_TOKENS
# For Gemini CLI, we don't have direct count_tokens access without making a call,
# so we report the limit and current usage from the last run if available.
limit_tokens = effective_limit
current_tokens = 0
if _gemini_cli_adapter and _gemini_cli_adapter.last_usage:
# Stats from CLI use 'input_tokens' or 'input'
u = _gemini_cli_adapter.last_usage
current_tokens = u.get("input_tokens") or u.get("input", 0)
percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
return {
"provider": "gemini_cli",
"limit": limit_tokens,
"current": current_tokens,
"percentage": percentage,
}
elif _provider == "deepseek":
limit_tokens = 64000
current_tokens = 0
with _deepseek_history_lock:
for msg in _deepseek_history:
content = msg.get("content", "")
if isinstance(content, str):
current_tokens += len(content)
elif isinstance(content, list):
for block in content:
if isinstance(block, dict):
text = block.get("text", "")
if isinstance(text, str):
current_tokens += len(text)
inp = block.get("input")
if isinstance(inp, dict):
import json as _json
current_tokens += len(_json.dumps(inp, ensure_ascii=False))
if md_content:
current_tokens += len(md_content)
if user_message:
current_tokens += len(user_message)
current_tokens = max(1, int(current_tokens / _CHARS_PER_TOKEN))
percentage = (current_tokens / limit_tokens) * 100 if limit_tokens > 0 else 0
return {
"provider": "deepseek",
"limit": limit_tokens,
"current": current_tokens,
"percentage": percentage,
}
# Default empty state
return {
"provider": _provider,
"limit": 0,
"current": 0,
"percentage": 0,
}