# ai_client.py
import tomllib
import json
import datetime
from pathlib import Path
import file_cache
_provider: str = "gemini"
_model: str = "gemini-2.0-flash"
_gemini_client = None
_gemini_chat = None
_anthropic_client = None
_anthropic_history: list[dict] = []
# Injected by gui.py - called when the AI wants to run a command.
# Signature: (script: str, base_dir: str) -> str | None
confirm_and_run_callback = None
# Injected by gui.py - called whenever a comms entry is appended.
# Signature: (entry: dict) -> None
comms_log_callback = None
# Injected by gui.py - called whenever a tool call completes.
# Signature: (script: str, result: str) -> None
tool_log_callback = None
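# A minimal wiring sketch (hypothetical; the real gui.py is not part of this
# file, and the handler names below are illustrative):
#
#     import ai_client
#     ai_client.confirm_and_run_callback = my_confirm_and_run
#     ai_client.comms_log_callback = comms_pane.append_entry
#     ai_client.tool_log_callback = tool_pane.append_result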
MAX_TOOL_ROUNDS = 5
# Maximum characters per text chunk sent to Anthropic. At roughly 4 characters
# per token, 180k characters is on the order of 45k tokens, keeping each chunk
# comfortably under the ~200k-token context window.
_ANTHROPIC_CHUNK_SIZE = 180_000
_ANTHROPIC_SYSTEM = (
"You are a helpful coding assistant with access to a PowerShell tool. "
"When asked to create or edit files, prefer targeted edits over full rewrites. "
"Always explain what you are doing before invoking the tool."
)
# ------------------------------------------------------------------ comms log
_comms_log: list[dict] = []
COMMS_CLAMP_CHARS = 300
def _append_comms(direction: str, kind: str, payload: dict):
entry = {
"ts": datetime.datetime.now().strftime("%H:%M:%S"),
"direction": direction,
"kind": kind,
"provider": _provider,
"model": _model,
"payload": payload,
}
_comms_log.append(entry)
if comms_log_callback is not None:
comms_log_callback(entry)
def get_comms_log() -> list[dict]:
return list(_comms_log)
def clear_comms_log():
_comms_log.clear()
def _load_credentials() -> dict:
with open("credentials.toml", "rb") as f:
return tomllib.load(f)
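# Expected credentials.toml layout, inferred from the keys this module reads
# (values are placeholders):
#
#     [gemini]
#     api_key = "YOUR-GEMINI-KEY"
#
#     [anthropic]
#     api_key = "YOUR-ANTHROPIC-KEY"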
# ------------------------------------------------------------------ provider errors
class ProviderError(Exception):
def __init__(self, kind: str, provider: str, original: Exception):
self.kind = kind
self.provider = provider
self.original = original
super().__init__(str(original))
def ui_message(self) -> str:
labels = {
"quota": "QUOTA EXHAUSTED",
"rate_limit": "RATE LIMITED",
"auth": "AUTH / API KEY ERROR",
"balance": "BALANCE / BILLING ERROR",
"network": "NETWORK / CONNECTION ERROR",
"unknown": "API ERROR",
}
label = labels.get(self.kind, "API ERROR")
return f"[{self.provider.upper()} {label}]\n\n{self.original}"
def _classify_anthropic_error(exc: Exception) -> ProviderError:
try:
import anthropic
if isinstance(exc, anthropic.RateLimitError):
return ProviderError("rate_limit", "anthropic", exc)
if isinstance(exc, anthropic.AuthenticationError):
return ProviderError("auth", "anthropic", exc)
if isinstance(exc, anthropic.PermissionDeniedError):
return ProviderError("auth", "anthropic", exc)
if isinstance(exc, anthropic.APIConnectionError):
return ProviderError("network", "anthropic", exc)
if isinstance(exc, anthropic.APIStatusError):
status = getattr(exc, "status_code", 0)
body = str(exc).lower()
if status == 429:
return ProviderError("rate_limit", "anthropic", exc)
if status in (401, 403):
return ProviderError("auth", "anthropic", exc)
if status == 402:
return ProviderError("balance", "anthropic", exc)
if "credit" in body or "balance" in body or "billing" in body:
return ProviderError("balance", "anthropic", exc)
if "quota" in body or "limit" in body or "exceeded" in body:
return ProviderError("quota", "anthropic", exc)
except ImportError:
pass
return ProviderError("unknown", "anthropic", exc)
def _classify_gemini_error(exc: Exception) -> ProviderError:
body = str(exc).lower()
try:
from google.api_core import exceptions as gac
if isinstance(exc, gac.ResourceExhausted):
return ProviderError("quota", "gemini", exc)
if isinstance(exc, gac.TooManyRequests):
return ProviderError("rate_limit", "gemini", exc)
if isinstance(exc, (gac.Unauthenticated, gac.PermissionDenied)):
return ProviderError("auth", "gemini", exc)
if isinstance(exc, gac.ServiceUnavailable):
return ProviderError("network", "gemini", exc)
except ImportError:
pass
if "429" in body or "quota" in body or "resource exhausted" in body:
return ProviderError("quota", "gemini", exc)
if "rate" in body and "limit" in body:
return ProviderError("rate_limit", "gemini", exc)
if "401" in body or "403" in body or "api key" in body or "unauthenticated" in body:
return ProviderError("auth", "gemini", exc)
if "402" in body or "billing" in body or "balance" in body or "payment" in body:
return ProviderError("balance", "gemini", exc)
if "connection" in body or "timeout" in body or "unreachable" in body:
return ProviderError("network", "gemini", exc)
return ProviderError("unknown", "gemini", exc)
# ------------------------------------------------------------------ provider setup
def set_provider(provider: str, model: str):
global _provider, _model
_provider = provider
_model = model
def reset_session():
global _gemini_client, _gemini_chat
global _anthropic_client, _anthropic_history
_gemini_client = None
_gemini_chat = None
_anthropic_client = None
_anthropic_history = []
file_cache.reset_client()
# ------------------------------------------------------------------ model listing
def list_models(provider: str) -> list[str]:
creds = _load_credentials()
if provider == "gemini":
return _list_gemini_models(creds["gemini"]["api_key"])
elif provider == "anthropic":
return _list_anthropic_models()
return []
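# Usage example (model names illustrative; the live list comes from each API):
#
#     >>> list_models("gemini")
#     ['gemini-1.5-flash', 'gemini-2.0-flash', ...]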
def _list_gemini_models(api_key: str) -> list[str]:
from google import genai
try:
client = genai.Client(api_key=api_key)
models = []
for m in client.models.list():
name = m.name
if name.startswith("models/"):
name = name[len("models/"):]
if "gemini" in name.lower():
models.append(name)
return sorted(models)
except Exception as exc:
raise _classify_gemini_error(exc) from exc
def _list_anthropic_models() -> list[str]:
import anthropic
try:
creds = _load_credentials()
client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"])
models = []
for m in client.models.list():
models.append(m.id)
return sorted(models)
except Exception as exc:
raise _classify_anthropic_error(exc) from exc
# ------------------------------------------------------------------ tool definition
TOOL_NAME = "run_powershell"
_ANTHROPIC_TOOLS = [
{
"name": TOOL_NAME,
"description": (
"Run a PowerShell script within the project base_dir. "
"Use this to create, edit, rename, or delete files and directories. "
"The working directory is set to base_dir automatically. "
"Always prefer targeted edits over full rewrites where possible. "
"stdout and stderr are returned to you as the result."
),
"input_schema": {
"type": "object",
"properties": {
"script": {
"type": "string",
"description": "The PowerShell script to execute."
}
},
"required": ["script"]
},
"cache_control": {"type": "ephemeral"},
}
]
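# For reference, a tool_use content block returned by the Anthropic API has
# this shape (id and script values illustrative):
#
#     {"type": "tool_use", "id": "toolu_01...", "name": "run_powershell",
#      "input": {"script": "Get-ChildItem -Recurse"}}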
def _gemini_tool_declaration():
from google.genai import types
return types.Tool(
function_declarations=[
types.FunctionDeclaration(
name=TOOL_NAME,
description=(
"Run a PowerShell script within the project base_dir. "
"Use this to create, edit, rename, or delete files and directories. "
"The working directory is set to base_dir automatically. "
"stdout and stderr are returned to you as the result."
),
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
"script": types.Schema(
type=types.Type.STRING,
description="The PowerShell script to execute."
)
},
required=["script"]
)
)
]
)
def _run_script(script: str, base_dir: str) -> str:
if confirm_and_run_callback is None:
return "ERROR: no confirmation handler registered"
result = confirm_and_run_callback(script, base_dir)
if result is None:
output = "USER REJECTED: command was not executed"
else:
output = result
if tool_log_callback is not None:
tool_log_callback(script, output)
return output
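# A minimal sketch of a confirmation handler (hypothetical; the real handler
# lives in gui.py and must prompt the user before running anything). It
# documents the expected contract: return the captured output on success, or
# None to signal that the user rejected the command.
def _example_confirm_and_run(script: str, base_dir: str) -> str | None:
    import subprocess
    proc = subprocess.run(
        ["powershell", "-NoProfile", "-Command", script],
        cwd=base_dir,
        capture_output=True,
        text=True,
    )
    # Return both streams, matching the tool description above.
    return proc.stdout + proc.stderr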
# ------------------------------------------------------------------ gemini
def _ensure_gemini_client():
global _gemini_client
if _gemini_client is None:
from google import genai
creds = _load_credentials()
_gemini_client = genai.Client(api_key=creds["gemini"]["api_key"])
def _send_gemini(md_content: str, user_message: str, base_dir: str) -> str:
global _gemini_chat
from google import genai
from google.genai import types
try:
_ensure_gemini_client()
if _gemini_chat is None:
_gemini_chat = _gemini_client.chats.create(
model=_model,
config=types.GenerateContentConfig(
tools=[_gemini_tool_declaration()]
)
)
full_message = f"<context>\n{md_content}\n</context>\n\n{user_message}"
_append_comms("OUT", "request", {
"message": f"[context {len(md_content)} chars + user message {len(user_message)} chars]",
})
response = _gemini_chat.send_message(full_message)
for round_idx in range(MAX_TOOL_ROUNDS):
text_parts_raw = [
part.text
for candidate in response.candidates
for part in candidate.content.parts
if hasattr(part, "text") and part.text
]
tool_calls = [
part.function_call
for candidate in response.candidates
for part in candidate.content.parts
if part.function_call is not None
]
_append_comms("IN", "response", {
"round": round_idx,
"text": "\n".join(text_parts_raw),
"tool_calls": [{"name": fc.name, "args": dict(fc.args)} for fc in tool_calls],
})
if not tool_calls:
break
function_responses = []
for fc in tool_calls:
if fc.name == TOOL_NAME:
script = fc.args.get("script", "")
_append_comms("OUT", "tool_call", {
"name": TOOL_NAME,
"script": script,
})
output = _run_script(script, base_dir)
_append_comms("IN", "tool_result", {
"name": TOOL_NAME,
"output": output,
})
function_responses.append(
types.Part.from_function_response(
name=TOOL_NAME,
response={"output": output}
)
)
if not function_responses:
break
response = _gemini_chat.send_message(function_responses)
text_parts = [
part.text
for candidate in response.candidates
for part in candidate.content.parts
if hasattr(part, "text") and part.text
]
return "\n".join(text_parts)
except ProviderError:
raise
except Exception as exc:
raise _classify_gemini_error(exc) from exc
# ------------------------------------------------------------------ anthropic
def _ensure_anthropic_client():
global _anthropic_client
if _anthropic_client is None:
import anthropic
creds = _load_credentials()
_anthropic_client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"])
def _chunk_text(text: str, chunk_size: int) -> list[str]:
return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
def _build_chunked_context_blocks(md_content: str) -> list[dict]:
"""
Split md_content into <=_ANTHROPIC_CHUNK_SIZE char chunks.
cache_control:ephemeral is placed only on the LAST block so the whole
prefix is cached as one unit.
"""
chunks = _chunk_text(md_content, _ANTHROPIC_CHUNK_SIZE)
blocks = []
for i, chunk in enumerate(chunks):
block: dict = {"type": "text", "text": chunk}
if i == len(chunks) - 1:
block["cache_control"] = {"type": "ephemeral"}
blocks.append(block)
return blocks
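# Worked example: a 400_000-char context yields three blocks of 180k, 180k,
# and 40k characters, with cache_control only on the last block, so the whole
# context prefix is cached as a single unit.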
def _strip_cache_controls(history: list[dict]):
"""
Remove cache_control from all content blocks in message history.
Anthropic allows max 4 cache_control blocks total across system + tools +
messages. We reserve those slots for the stable system/tools prefix and
the current turn's context block, so all older history entries must be clean.
"""
for msg in history:
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
block.pop("cache_control", None)
def _repair_anthropic_history(history: list[dict]):
"""
If history ends with an assistant message that contains tool_use blocks
without a following user tool_result message, append a synthetic tool_result
message so the history is valid before the next request.
"""
if not history:
return
last = history[-1]
if last.get("role") != "assistant":
return
content = last.get("content", [])
tool_use_ids = []
for block in content:
if isinstance(block, dict):
if block.get("type") == "tool_use":
tool_use_ids.append(block["id"])
else:
if getattr(block, "type", None) == "tool_use":
tool_use_ids.append(block.id)
if not tool_use_ids:
return
history.append({
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": tid,
"content": "Tool call was not completed (session interrupted).",
}
for tid in tool_use_ids
],
})
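# Example of the repair (shapes abbreviated): a history ending with
#     {"role": "assistant", "content": [{"type": "tool_use", "id": "toolu_X", ...}]}
# gains a synthetic
#     {"role": "user", "content": [{"type": "tool_result",
#                                   "tool_use_id": "toolu_X", "content": "..."}]}
# so the next messages.create() call passes validation.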
def _send_anthropic(md_content: str, user_message: str, base_dir: str) -> str:
try:
_ensure_anthropic_client()
context_blocks = _build_chunked_context_blocks(md_content)
user_content = context_blocks + [
{"type": "text", "text": user_message}
]
_strip_cache_controls(_anthropic_history)
_repair_anthropic_history(_anthropic_history)
_anthropic_history.append({"role": "user", "content": user_content})
n_chunks = len(context_blocks)
_append_comms("OUT", "request", {
"message": (
f"[{n_chunks} chunk(s), {len(md_content)} chars context] "
f"{user_message[:200]}{'...' if len(user_message) > 200 else ''}"
),
})
for round_idx in range(MAX_TOOL_ROUNDS):
response = _anthropic_client.messages.create(
model=_model,
                max_tokens=8192,
system=[
{
"type": "text",
"text": _ANTHROPIC_SYSTEM,
"cache_control": {"type": "ephemeral"},
}
],
tools=_ANTHROPIC_TOOLS,
messages=_anthropic_history,
)
_anthropic_history.append({
"role": "assistant",
"content": response.content,
})
text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text]
tool_use_blocks = [
{"id": b.id, "name": b.name, "input": b.input}
for b in response.content
if b.type == "tool_use"
]
usage_dict: dict = {}
if response.usage:
usage_dict["input_tokens"] = response.usage.input_tokens
usage_dict["output_tokens"] = response.usage.output_tokens
cache_creation = getattr(response.usage, "cache_creation_input_tokens", None)
cache_read = getattr(response.usage, "cache_read_input_tokens", None)
if cache_creation is not None:
usage_dict["cache_creation_input_tokens"] = cache_creation
if cache_read is not None:
usage_dict["cache_read_input_tokens"] = cache_read
_append_comms("IN", "response", {
"round": round_idx,
"stop_reason": response.stop_reason,
"text": "\n".join(text_blocks),
"tool_calls": tool_use_blocks,
"usage": usage_dict,
})
if response.stop_reason != "tool_use":
break
tool_results = []
for block in response.content:
if block.type == "tool_use" and block.name == TOOL_NAME:
script = block.input.get("script", "")
_append_comms("OUT", "tool_call", {
"name": TOOL_NAME,
"id": block.id,
"script": script,
})
output = _run_script(script, base_dir)
_append_comms("IN", "tool_result", {
"name": TOOL_NAME,
"id": block.id,
"output": output,
})
tool_results.append({
"type": "tool_result",
"tool_use_id": block.id,
"content": output,
})
if not tool_results:
break
_anthropic_history.append({
"role": "user",
"content": tool_results,
})
_append_comms("OUT", "tool_result_send", {
"results": [
{"tool_use_id": r["tool_use_id"], "content": r["content"]}
for r in tool_results
],
})
text_parts = [
block.text
for block in response.content
if hasattr(block, "text") and block.text
]
return "\n".join(text_parts)
except ProviderError:
raise
except Exception as exc:
raise _classify_anthropic_error(exc) from exc
# ------------------------------------------------------------------ unified send
def send(
md_content: str,
user_message: str,
base_dir: str = ".",
) -> str:
"""
Send a message to the active provider.
md_content : aggregated markdown string from aggregate.run()
user_message: the user question / instruction
base_dir : project base directory (for PowerShell tool calls)
"""
if _provider == "gemini":
return _send_gemini(md_content, user_message, base_dir)
elif _provider == "anthropic":
return _send_anthropic(md_content, user_message, base_dir)
raise ValueError(f"unknown provider: {_provider}")
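

if __name__ == "__main__":
    # Smoke-test sketch: assumes a valid credentials.toml in the working
    # directory. No shell commands can actually run here because no
    # confirmation handler is registered; any tool call the model attempts
    # returns an error string instead.
    set_provider("gemini", "gemini-2.0-flash")
    try:
        print(send("# empty context", "Reply with a one-sentence greeting."))
    except ProviderError as err:
        print(err.ui_message())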