# ai_client.py
import tomllib
import json
import datetime

_provider: str = "gemini"
_model: str = "gemini-2.0-flash"

_gemini_client = None
_gemini_chat = None

_anthropic_client = None
_anthropic_history: list[dict] = []

# Injected by gui.py - called when AI wants to run a command.
# Signature: (script: str, base_dir: str) -> tuple[str, str | None] | None
# Returns (output, script_path) if approved, None if rejected.
confirm_and_run_callback = None

# Injected by gui.py - called whenever a comms entry is appended.
# Signature: (entry: dict) -> None
comms_log_callback = None

# Injected by gui.py - called whenever a tool call completes.
# Signature: (script: str, result: str, script_path: str | None) -> None
tool_log_callback = None
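
# A minimal wiring sketch (hypothetical names; the real hook-up lives in
# gui.py, which this module never imports):
#
#   import ai_client
#
#   def confirm_and_run(script, base_dir):
#       if not prompt_user(script):            # hypothetical approval dialog
#           return None                        # rejected
#       return run_and_save(script, base_dir)  # hypothetical -> (output, path)
#
#   ai_client.confirm_and_run_callback = confirm_and_run
#   ai_client.comms_log_callback = comms_panel.append  # hypothetical widget
#   ai_client.tool_log_callback = on_tool_call         # hypothetical handler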

MAX_TOOL_ROUNDS = 5

# Anthropic system prompt - sent as a cache_control system block on every
# request so it counts toward the prompt-cache prefix.
_ANTHROPIC_SYSTEM = (
    "You are a helpful coding assistant with access to a PowerShell tool. "
    "When asked to create or edit files, prefer targeted edits over full rewrites. "
    "Always explain what you are doing before invoking the tool."
)


# ------------------------------------------------------------------ comms log

_comms_log: list[dict] = []

MAX_FIELD_CHARS = 400  # beyond this we show a truncated preview in the UI


def _clamp(value, max_chars: int = MAX_FIELD_CHARS) -> tuple[str, bool]:
    """Return (display_str, was_truncated)."""
    if isinstance(value, (dict, list)):
        s = json.dumps(value, ensure_ascii=False, indent=2)
    else:
        s = str(value)
    if len(s) > max_chars:
        return s[:max_chars], True
    return s, False


def _append_comms(direction: str, kind: str, payload: dict):
    """
    direction : "OUT" | "IN"
    kind      : "request" | "response" | "tool_call" | "tool_result"
                | "tool_result_send"
    payload   : raw dict describing the event
    """
    entry = {
        "ts": datetime.datetime.now().strftime("%H:%M:%S"),
        "direction": direction,
        "kind": kind,
        "provider": _provider,
        "model": _model,
        "payload": payload,
    }
    _comms_log.append(entry)
    if comms_log_callback is not None:
        comms_log_callback(entry)


def get_comms_log() -> list[dict]:
    return list(_comms_log)


def clear_comms_log():
    _comms_log.clear()


def _load_credentials() -> dict:
    with open("credentials.toml", "rb") as f:
        return tomllib.load(f)
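
# Expected credentials.toml shape, inferred from the key lookups in this
# module (values are placeholders):
#
#   [gemini]
#   api_key = "..."
#
#   [anthropic]
#   api_key = "..."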


# ------------------------------------------------------------------ provider errors

class ProviderError(Exception):
    """
    Raised when the upstream API returns a hard error we want to surface
    distinctly in the UI (quota, rate-limit, auth, balance, etc.).

    Attributes
    ----------
    kind : str
        One of: "quota", "rate_limit", "auth", "balance", "network", "unknown"
    provider : str
        "gemini" or "anthropic"
    original : Exception
        The underlying SDK exception.
    """

    def __init__(self, kind: str, provider: str, original: Exception):
        self.kind = kind
        self.provider = provider
        self.original = original
        super().__init__(str(original))

    # Human-readable banner shown in the Response panel
    def ui_message(self) -> str:
        labels = {
            "quota": "QUOTA EXHAUSTED",
            "rate_limit": "RATE LIMITED",
            "auth": "AUTH / API KEY ERROR",
            "balance": "BALANCE / BILLING ERROR",
            "network": "NETWORK / CONNECTION ERROR",
            "unknown": "API ERROR",
        }
        label = labels.get(self.kind, "API ERROR")
        return f"[{self.provider.upper()} {label}]\n\n{self.original}"


def _classify_anthropic_error(exc: Exception) -> ProviderError:
    """Map an anthropic SDK exception to a ProviderError."""
    try:
        import anthropic
        if isinstance(exc, anthropic.RateLimitError):
            return ProviderError("rate_limit", "anthropic", exc)
        if isinstance(exc, (anthropic.AuthenticationError, anthropic.PermissionDeniedError)):
            return ProviderError("auth", "anthropic", exc)
        if isinstance(exc, anthropic.APIConnectionError):
            return ProviderError("network", "anthropic", exc)
        if isinstance(exc, anthropic.APIStatusError):
            status = getattr(exc, "status_code", 0)
            body = str(exc).lower()
            if status == 429:
                return ProviderError("rate_limit", "anthropic", exc)
            if status in (401, 403):
                return ProviderError("auth", "anthropic", exc)
            if status == 402:
                return ProviderError("balance", "anthropic", exc)
            # Anthropic puts credit-balance errors in the body at 400
            if "credit" in body or "balance" in body or "billing" in body:
                return ProviderError("balance", "anthropic", exc)
            if "quota" in body or "limit" in body or "exceeded" in body:
                return ProviderError("quota", "anthropic", exc)
    except ImportError:
        pass
    return ProviderError("unknown", "anthropic", exc)


def _classify_gemini_error(exc: Exception) -> ProviderError:
    """Map a google-genai SDK exception to a ProviderError."""
    body = str(exc).lower()
    try:
        from google.api_core import exceptions as gac
        if isinstance(exc, gac.ResourceExhausted):
            return ProviderError("quota", "gemini", exc)
        if isinstance(exc, gac.TooManyRequests):
            return ProviderError("rate_limit", "gemini", exc)
        if isinstance(exc, (gac.Unauthenticated, gac.PermissionDenied)):
            return ProviderError("auth", "gemini", exc)
        if isinstance(exc, gac.ServiceUnavailable):
            return ProviderError("network", "gemini", exc)
    except ImportError:
        pass
    if "429" in body or "quota" in body or "resource exhausted" in body:
        return ProviderError("quota", "gemini", exc)
    if "rate" in body and "limit" in body:
        return ProviderError("rate_limit", "gemini", exc)
    if "401" in body or "403" in body or "api key" in body or "unauthenticated" in body:
        return ProviderError("auth", "gemini", exc)
    if "402" in body or "billing" in body or "balance" in body or "payment" in body:
        return ProviderError("balance", "gemini", exc)
    if "connection" in body or "timeout" in body or "unreachable" in body:
        return ProviderError("network", "gemini", exc)
    return ProviderError("unknown", "gemini", exc)


# ------------------------------------------------------------------ provider setup

def set_provider(provider: str, model: str):
    global _provider, _model
    _provider = provider
    _model = model


def reset_session():
    global _gemini_client, _gemini_chat
    global _anthropic_client, _anthropic_history
    _gemini_client = None
    _gemini_chat = None
    _anthropic_client = None
    _anthropic_history = []


# ------------------------------------------------------------------ model listing

def list_models(provider: str) -> list[str]:
    creds = _load_credentials()
    if provider == "gemini":
        return _list_gemini_models(creds["gemini"]["api_key"])
    elif provider == "anthropic":
        return _list_anthropic_models()
    return []


def _list_gemini_models(api_key: str) -> list[str]:
    from google import genai
    try:
        client = genai.Client(api_key=api_key)
        models = []
        for m in client.models.list():
            name = m.name.removeprefix("models/")
            if "gemini" in name.lower():
                models.append(name)
        return sorted(models)
    except Exception as exc:
        raise _classify_gemini_error(exc) from exc


def _list_anthropic_models() -> list[str]:
    import anthropic
    try:
        creds = _load_credentials()
        client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"])
        return sorted(m.id for m in client.models.list())
    except Exception as exc:
        raise _classify_anthropic_error(exc) from exc


# --------------------------------------------------------- tool definition

TOOL_NAME = "run_powershell"

_ANTHROPIC_TOOLS = [
    {
        "name": TOOL_NAME,
        "description": (
            "Run a PowerShell script within the project base_dir. "
            "Use this to create, edit, rename, or delete files and directories. "
            "The working directory is set to base_dir automatically. "
            "Always prefer targeted edits over full rewrites where possible. "
            "stdout and stderr are returned to you as the result."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "script": {
                    "type": "string",
                    "description": "The PowerShell script to execute."
                }
            },
            "required": ["script"]
        }
    }
]
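
# For reference, a tool_use block emitted by the model against this schema
# has the shape (id invented):
#
#   {"type": "tool_use", "id": "toolu_01...", "name": "run_powershell",
#    "input": {"script": "Get-ChildItem"}}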


def _gemini_tool_declaration():
    from google.genai import types
    return types.Tool(
        function_declarations=[
            types.FunctionDeclaration(
                name=TOOL_NAME,
                description=(
                    "Run a PowerShell script within the project base_dir. "
                    "Use this to create, edit, rename, or delete files and directories. "
                    "The working directory is set to base_dir automatically. "
                    "stdout and stderr are returned to you as the result."
                ),
                parameters=types.Schema(
                    type=types.Type.OBJECT,
                    properties={
                        "script": types.Schema(
                            type=types.Type.STRING,
                            description="The PowerShell script to execute."
                        )
                    },
                    required=["script"]
                )
            )
        ]
    )


def _run_script(script: str, base_dir: str) -> str:
    """
    Delegate to the GUI confirmation callback.
    Returns result string (stdout/stderr) or a rejection message.
    Also fires tool_log_callback if registered.
    """
    if confirm_and_run_callback is None:
        return "ERROR: no confirmation handler registered"
    # confirm_and_run_callback returns (result, script_path) or None
    outcome = confirm_and_run_callback(script, base_dir)
    if outcome is None:
        result = "USER REJECTED: command was not executed"
        if tool_log_callback is not None:
            tool_log_callback(script, result, None)
        return result
    result, script_path = outcome
    if tool_log_callback is not None:
        tool_log_callback(script, result, script_path)
    return result


# ------------------------------------------------------------------ gemini

def _ensure_gemini_client():
    global _gemini_client
    if _gemini_client is None:
        from google import genai
        creds = _load_credentials()
        _gemini_client = genai.Client(api_key=creds["gemini"]["api_key"])


def _send_gemini(md_content: str, user_message: str, base_dir: str) -> str:
    global _gemini_chat
    from google.genai import types

    try:
        _ensure_gemini_client()

        if _gemini_chat is None:
            _gemini_chat = _gemini_client.chats.create(
                model=_model,
                config=types.GenerateContentConfig(
                    tools=[_gemini_tool_declaration()]
                )
            )

        full_message = f"<context>\n{md_content}\n</context>\n\n{user_message}"

        _append_comms("OUT", "request", {
            "message": full_message,
        })

        response = _gemini_chat.send_message(full_message)

        for round_idx in range(MAX_TOOL_ROUNDS):
            text_parts_raw = [
                part.text
                for candidate in response.candidates
                for part in candidate.content.parts
                if hasattr(part, "text") and part.text
            ]
            tool_calls = [
                part.function_call
                for candidate in response.candidates
                for part in candidate.content.parts
                if part.function_call is not None
            ]

            _append_comms("IN", "response", {
                "round": round_idx,
                "text": "\n".join(text_parts_raw),
                "tool_calls": [{"name": fc.name, "args": dict(fc.args)} for fc in tool_calls],
            })

            if not tool_calls:
                break

            function_responses = []
            for fc in tool_calls:
                if fc.name == TOOL_NAME:
                    script = fc.args.get("script", "")
                    _append_comms("OUT", "tool_call", {
                        "name": TOOL_NAME,
                        "script": script,
                    })
                    output = _run_script(script, base_dir)
                    _append_comms("IN", "tool_result", {
                        "name": TOOL_NAME,
                        "output": output,
                    })
                    function_responses.append(
                        types.Part.from_function_response(
                            name=TOOL_NAME,
                            response={"output": output}
                        )
                    )

            if not function_responses:
                break

            response = _gemini_chat.send_message(function_responses)

        text_parts = [
            part.text
            for candidate in response.candidates
            for part in candidate.content.parts
            if hasattr(part, "text") and part.text
        ]
        return "\n".join(text_parts)

    except ProviderError:
        raise
    except Exception as exc:
        raise _classify_gemini_error(exc) from exc


# ------------------------------------------------------------------ anthropic
#
# Caching strategy (Anthropic prompt caching):
#
# The Anthropic API caches a prefix of the input tokens. To maximise hits:
#
# 1. A persistent system prompt is sent on every request with
#    cache_control={"type":"ephemeral"} so it is cached after the first call
#    and reused on subsequent calls within the 5-minute TTL window.
#
# 2. The context block (aggregated markdown) is sent as the first content
#    block of the user turn and also marked with cache_control. Because the
#    system prompt and the context are stable across tool-use rounds within a
#    single send() call, the cache hit rate is very high after round 0.
#
# 3. Tool definitions are passed with cache_control on the last tool so the
#    entire tools array is also cached.
#
# Token accounting: the response payload contains cache_creation_input_tokens
# and cache_read_input_tokens in addition to the regular input_tokens field.
# These are included in the comms log under "usage".
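
# Illustrative "usage" entry as logged below after a warm cache (field names
# match the SDK response; the numbers are invented):
#
#   {"input_tokens": 12, "output_tokens": 340,
#    "cache_creation_input_tokens": 0, "cache_read_input_tokens": 4821}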


def _anthropic_tools_with_cache() -> list[dict]:
    """Return the tools list with cache_control on the last entry."""
    import copy
    tools = copy.deepcopy(_ANTHROPIC_TOOLS)
    # Mark the last tool so the entire prefix (system + tools) gets cached
    tools[-1]["cache_control"] = {"type": "ephemeral"}
    return tools


def _ensure_anthropic_client():
    global _anthropic_client
    if _anthropic_client is None:
        import anthropic
        creds = _load_credentials()
        _anthropic_client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"])


def _send_anthropic(md_content: str, user_message: str, base_dir: str) -> str:
    global _anthropic_history

    try:
        _ensure_anthropic_client()

        # ----------------------------------------------------------------
        # Build the user turn.
        #
        # Structure the content as two blocks so the large context portion
        # can be cached independently of the user question:
        #
        #   [0] context block   <- cache_control applied here
        #   [1] user question   <- not cached (changes every turn)
        #
        # The Anthropic cache anchors at the LAST cache_control marker in
        # the prefix, so everything up to and including the context block
        # will be served from cache on subsequent rounds.
        # ----------------------------------------------------------------
        user_content = [
            {
                "type": "text",
                "text": f"<context>\n{md_content}\n</context>",
                "cache_control": {"type": "ephemeral"},
            },
            {
                "type": "text",
                "text": user_message,
            },
        ]

        _anthropic_history.append({"role": "user", "content": user_content})

        _append_comms("OUT", "request", {
            "message": f"<context>\n{md_content}\n</context>\n\n{user_message}",
        })

        for round_idx in range(MAX_TOOL_ROUNDS):
            response = _anthropic_client.messages.create(
                model=_model,
                max_tokens=8192,
                system=[
                    {
                        "type": "text",
                        "text": _ANTHROPIC_SYSTEM,
                        "cache_control": {"type": "ephemeral"},
                    }
                ],
                tools=_anthropic_tools_with_cache(),
                messages=_anthropic_history,
                # Prompt caching is GA: no beta flag is needed here, and the
                # cache token counts come back in response.usage whenever
                # cache_control blocks are present in the request.
            )

            _anthropic_history.append({
                "role": "assistant",
                "content": response.content
            })

            text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text]
            tool_use_blocks = [
                {"id": b.id, "name": b.name, "input": b.input}
                for b in response.content
                if b.type == "tool_use"
            ]

            # Usage includes cache fields whenever prompt caching was active
            usage_dict: dict = {}
            if response.usage:
                usage_dict = {
                    "input_tokens": response.usage.input_tokens,
                    "output_tokens": response.usage.output_tokens,
                }
                # cache fields are absent when no caching applied
                cache_creation = getattr(response.usage, "cache_creation_input_tokens", None)
                cache_read = getattr(response.usage, "cache_read_input_tokens", None)
                if cache_creation is not None:
                    usage_dict["cache_creation_input_tokens"] = cache_creation
                if cache_read is not None:
                    usage_dict["cache_read_input_tokens"] = cache_read

            _append_comms("IN", "response", {
                "round": round_idx,
                "stop_reason": response.stop_reason,
                "text": "\n".join(text_blocks),
                "tool_calls": tool_use_blocks,
                "usage": usage_dict,
            })

            if response.stop_reason != "tool_use":
                break

            tool_results = []
            for block in response.content:
                if block.type == "tool_use" and block.name == TOOL_NAME:
                    script = block.input.get("script", "")
                    _append_comms("OUT", "tool_call", {
                        "name": TOOL_NAME,
                        "id": block.id,
                        "script": script,
                    })
                    output = _run_script(script, base_dir)
                    _append_comms("IN", "tool_result", {
                        "name": TOOL_NAME,
                        "id": block.id,
                        "output": output,
                    })
                    tool_results.append({
                        "type": "tool_result",
                        "tool_use_id": block.id,
                        "content": output
                    })

            if not tool_results:
                break

            _anthropic_history.append({
                "role": "user",
                "content": tool_results
            })

            _append_comms("OUT", "tool_result_send", {
                "results": [{"tool_use_id": r["tool_use_id"], "content": r["content"]} for r in tool_results],
            })

        text_parts = [
            block.text
            for block in response.content
            if hasattr(block, "text") and block.text
        ]
        return "\n".join(text_parts)

    except ProviderError:
        raise
    except Exception as exc:
        raise _classify_anthropic_error(exc) from exc


# ------------------------------------------------------------------ unified send

def send(md_content: str, user_message: str, base_dir: str = ".") -> str:
    if _provider == "gemini":
        return _send_gemini(md_content, user_message, base_dir)
    elif _provider == "anthropic":
        return _send_anthropic(md_content, user_message, base_dir)
    raise ValueError(f"unknown provider: {_provider}")
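

if __name__ == "__main__":
    # Minimal smoke test - a sketch, not part of the GUI app. Assumes a
    # valid credentials.toml beside this file and network access; tool
    # calls are refused because no confirmation handler is registered.
    set_provider("gemini", "gemini-2.0-flash")
    print(list_models("gemini")[:5])
    print(send("demo context", "Reply with a one-sentence greeting."))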