# ai_client.py import tomllib import json import datetime from pathlib import Path _provider: str = "gemini" _model: str = "gemini-2.0-flash" _gemini_client = None _gemini_chat = None _anthropic_client = None _anthropic_history: list[dict] = [] # Injected by gui.py - called when AI wants to run a command. # Signature: (script: str) -> str | None # Returns the output string if approved, None if rejected. confirm_and_run_callback = None # Injected by gui.py - called whenever a comms entry is appended. # Signature: (entry: dict) -> None comms_log_callback = None MAX_TOOL_ROUNDS = 5 # ------------------------------------------------------------------ comms log _comms_log: list[dict] = [] MAX_FIELD_CHARS = 400 # beyond this we show a truncated preview in the UI def _clamp(value, max_chars: int = MAX_FIELD_CHARS) -> tuple[str, bool]: """Return (display_str, was_truncated).""" if isinstance(value, (dict, list)): s = json.dumps(value, ensure_ascii=False, indent=2) else: s = str(value) if len(s) > max_chars: return s[:max_chars], True return s, False def _append_comms(direction: str, kind: str, payload: dict): """ direction : "OUT" | "IN" kind : "request" | "response" | "tool_call" | "tool_result" payload : raw dict describing the event """ entry = { "ts": datetime.datetime.now().strftime("%H:%M:%S"), "direction": direction, "kind": kind, "provider": _provider, "model": _model, "payload": payload, } _comms_log.append(entry) if comms_log_callback is not None: comms_log_callback(entry) def get_comms_log() -> list[dict]: return list(_comms_log) def clear_comms_log(): _comms_log.clear() def _load_credentials() -> dict: with open("credentials.toml", "rb") as f: return tomllib.load(f) # ------------------------------------------------------------------ provider errors class ProviderError(Exception): """ Raised when the upstream API returns a hard error we want to surface distinctly in the UI (quota, rate-limit, auth, balance, etc.). Attributes ---------- kind : str One of: "quota", "rate_limit", "auth", "balance", "network", "unknown" provider : str "gemini" or "anthropic" original : Exception The underlying SDK exception. """ def __init__(self, kind: str, provider: str, original: Exception): self.kind = kind self.provider = provider self.original = original super().__init__(str(original)) # Human-readable banner shown in the Response panel def ui_message(self) -> str: labels = { "quota": "QUOTA EXHAUSTED", "rate_limit": "RATE LIMITED", "auth": "AUTH / API KEY ERROR", "balance": "BALANCE / BILLING ERROR", "network": "NETWORK / CONNECTION ERROR", "unknown": "API ERROR", } label = labels.get(self.kind, "API ERROR") return f"[{self.provider.upper()} {label}]\n\n{self.original}" def _classify_anthropic_error(exc: Exception) -> ProviderError: """Map an anthropic SDK exception to a ProviderError.""" try: import anthropic if isinstance(exc, anthropic.RateLimitError): return ProviderError("rate_limit", "anthropic", exc) if isinstance(exc, anthropic.AuthenticationError): return ProviderError("auth", "anthropic", exc) if isinstance(exc, anthropic.PermissionDeniedError): return ProviderError("auth", "anthropic", exc) if isinstance(exc, anthropic.APIConnectionError): return ProviderError("network", "anthropic", exc) if isinstance(exc, anthropic.APIStatusError): status = getattr(exc, "status_code", 0) body = str(exc).lower() if status == 429: return ProviderError("rate_limit", "anthropic", exc) if status in (401, 403): return ProviderError("auth", "anthropic", exc) if status == 402: return ProviderError("balance", "anthropic", exc) # Anthropic puts credit-balance errors in the body at 400 if "credit" in body or "balance" in body or "billing" in body: return ProviderError("balance", "anthropic", exc) if "quota" in body or "limit" in body or "exceeded" in body: return ProviderError("quota", "anthropic", exc) except ImportError: pass return ProviderError("unknown", "anthropic", exc) def _classify_gemini_error(exc: Exception) -> ProviderError: """Map a google-genai SDK exception to a ProviderError.""" body = str(exc).lower() # google-genai surfaces HTTP errors as google.api_core exceptions or # google.genai exceptions; inspect the message text as a reliable fallback. try: from google.api_core import exceptions as gac if isinstance(exc, gac.ResourceExhausted): return ProviderError("quota", "gemini", exc) if isinstance(exc, gac.TooManyRequests): return ProviderError("rate_limit", "gemini", exc) if isinstance(exc, (gac.Unauthenticated, gac.PermissionDenied)): return ProviderError("auth", "gemini", exc) if isinstance(exc, gac.ServiceUnavailable): return ProviderError("network", "gemini", exc) except ImportError: pass # Fallback: parse status code / message string if "429" in body or "quota" in body or "resource exhausted" in body: return ProviderError("quota", "gemini", exc) if "rate" in body and "limit" in body: return ProviderError("rate_limit", "gemini", exc) if "401" in body or "403" in body or "api key" in body or "unauthenticated" in body: return ProviderError("auth", "gemini", exc) if "402" in body or "billing" in body or "balance" in body or "payment" in body: return ProviderError("balance", "gemini", exc) if "connection" in body or "timeout" in body or "unreachable" in body: return ProviderError("network", "gemini", exc) return ProviderError("unknown", "gemini", exc) # ------------------------------------------------------------------ provider setup def set_provider(provider: str, model: str): global _provider, _model _provider = provider _model = model def reset_session(): global _gemini_client, _gemini_chat global _anthropic_client, _anthropic_history _gemini_client = None _gemini_chat = None _anthropic_client = None _anthropic_history = [] # ------------------------------------------------------------------ model listing def list_models(provider: str) -> list[str]: creds = _load_credentials() if provider == "gemini": return _list_gemini_models(creds["gemini"]["api_key"]) elif provider == "anthropic": return _list_anthropic_models() return [] def _list_gemini_models(api_key: str) -> list[str]: from google import genai try: client = genai.Client(api_key=api_key) models = [] for m in client.models.list(): name = m.name if name.startswith("models/"): name = name[len("models/"):] if "gemini" in name.lower(): models.append(name) return sorted(models) except Exception as exc: raise _classify_gemini_error(exc) from exc def _list_anthropic_models() -> list[str]: import anthropic try: creds = _load_credentials() client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"]) models = [] for m in client.models.list(): models.append(m.id) return sorted(models) except Exception as exc: raise _classify_anthropic_error(exc) from exc # --------------------------------------------------------- tool definition TOOL_NAME = "run_powershell" _ANTHROPIC_TOOLS = [ { "name": TOOL_NAME, "description": ( "Run a PowerShell script within the project base_dir. " "Use this to create, edit, rename, or delete files and directories. " "The working directory is set to base_dir automatically. " "Always prefer targeted edits over full rewrites where possible. " "stdout and stderr are returned to you as the result." ), "input_schema": { "type": "object", "properties": { "script": { "type": "string", "description": "The PowerShell script to execute." } }, "required": ["script"] } } ] def _gemini_tool_declaration(): from google.genai import types return types.Tool( function_declarations=[ types.FunctionDeclaration( name=TOOL_NAME, description=( "Run a PowerShell script within the project base_dir. " "Use this to create, edit, rename, or delete files and directories. " "The working directory is set to base_dir automatically. " "stdout and stderr are returned to you as the result." ), parameters=types.Schema( type=types.Type.OBJECT, properties={ "script": types.Schema( type=types.Type.STRING, description="The PowerShell script to execute." ) }, required=["script"] ) ) ] ) def _run_script(script: str, base_dir: str) -> str: """ Delegate to the GUI confirmation callback. Returns result string (stdout/stderr) or a rejection message. """ if confirm_and_run_callback is None: return "ERROR: no confirmation handler registered" result = confirm_and_run_callback(script, base_dir) if result is None: return "USER REJECTED: command was not executed" return result # ------------------------------------------------------------------ gemini def _ensure_gemini_client(): global _gemini_client if _gemini_client is None: from google import genai creds = _load_credentials() _gemini_client = genai.Client(api_key=creds["gemini"]["api_key"]) def _send_gemini(md_content: str, user_message: str, base_dir: str) -> str: global _gemini_chat from google import genai from google.genai import types try: _ensure_gemini_client() if _gemini_chat is None: _gemini_chat = _gemini_client.chats.create( model=_model, config=types.GenerateContentConfig( tools=[_gemini_tool_declaration()] ) ) full_message = f"\n{md_content}\n\n\n{user_message}" _append_comms("OUT", "request", { "message": full_message, }) response = _gemini_chat.send_message(full_message) for round_idx in range(MAX_TOOL_ROUNDS): # Log the raw response candidates as text summary text_parts_raw = [ part.text for candidate in response.candidates for part in candidate.content.parts if hasattr(part, "text") and part.text ] tool_calls = [ part.function_call for candidate in response.candidates for part in candidate.content.parts if part.function_call is not None ] _append_comms("IN", "response", { "round": round_idx, "text": "\n".join(text_parts_raw), "tool_calls": [{"name": fc.name, "args": dict(fc.args)} for fc in tool_calls], }) if not tool_calls: break function_responses = [] for fc in tool_calls: if fc.name == TOOL_NAME: script = fc.args.get("script", "") _append_comms("OUT", "tool_call", { "name": TOOL_NAME, "script": script, }) output = _run_script(script, base_dir) _append_comms("IN", "tool_result", { "name": TOOL_NAME, "output": output, }) function_responses.append( types.Part.from_function_response( name=TOOL_NAME, response={"output": output} ) ) if not function_responses: break response = _gemini_chat.send_message(function_responses) text_parts = [ part.text for candidate in response.candidates for part in candidate.content.parts if hasattr(part, "text") and part.text ] return "\n".join(text_parts) except ProviderError: raise except Exception as exc: raise _classify_gemini_error(exc) from exc # ------------------------------------------------------------------ anthropic def _ensure_anthropic_client(): global _anthropic_client if _anthropic_client is None: import anthropic creds = _load_credentials() _anthropic_client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"]) def _send_anthropic(md_content: str, user_message: str, base_dir: str) -> str: global _anthropic_history import anthropic try: _ensure_anthropic_client() full_message = f"\n{md_content}\n\n\n{user_message}" _anthropic_history.append({"role": "user", "content": full_message}) _append_comms("OUT", "request", { "message": full_message, }) for round_idx in range(MAX_TOOL_ROUNDS): response = _anthropic_client.messages.create( model=_model, max_tokens=8096, tools=_ANTHROPIC_TOOLS, messages=_anthropic_history ) _anthropic_history.append({ "role": "assistant", "content": response.content }) # Summarise the response content for the log text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text] tool_use_blocks = [ {"id": b.id, "name": b.name, "input": b.input} for b in response.content if b.type == "tool_use" ] _append_comms("IN", "response", { "round": round_idx, "stop_reason": response.stop_reason, "text": "\n".join(text_blocks), "tool_calls": tool_use_blocks, "usage": { "input_tokens": response.usage.input_tokens, "output_tokens": response.usage.output_tokens, } if response.usage else {}, }) if response.stop_reason != "tool_use": break tool_results = [] for block in response.content: if block.type == "tool_use" and block.name == TOOL_NAME: script = block.input.get("script", "") _append_comms("OUT", "tool_call", { "name": TOOL_NAME, "id": block.id, "script": script, }) output = _run_script(script, base_dir) _append_comms("IN", "tool_result", { "name": TOOL_NAME, "id": block.id, "output": output, }) tool_results.append({ "type": "tool_result", "tool_use_id": block.id, "content": output }) if not tool_results: break _anthropic_history.append({ "role": "user", "content": tool_results }) _append_comms("OUT", "tool_result_send", { "results": [{"tool_use_id": r["tool_use_id"], "content": r["content"]} for r in tool_results], }) text_parts = [ block.text for block in response.content if hasattr(block, "text") and block.text ] return "\n".join(text_parts) except ProviderError: raise except Exception as exc: raise _classify_anthropic_error(exc) from exc # ------------------------------------------------------------------ unified send def send(md_content: str, user_message: str, base_dir: str = ".") -> str: if _provider == "gemini": return _send_gemini(md_content, user_message, base_dir) elif _provider == "anthropic": return _send_anthropic(md_content, user_message, base_dir) raise ValueError(f"unknown provider: {_provider}")