# ai_client.py
#
# Provider-agnostic AI chat client for Gemini and Anthropic with a single
# PowerShell tool.  GUI code injects callbacks (below) for confirmation,
# comms logging, and tool logging; `send()` (later in this file) dispatches
# to the active provider.
import tomllib
import json
import datetime
from pathlib import Path

import file_cache

# Active provider/model selection (mutated via set_provider()).
_provider: str = "gemini"
_model: str = "gemini-2.0-flash"

# Lazily-created per-session SDK state; cleared by reset_session().
_gemini_client = None
_gemini_chat = None
_anthropic_client = None
_anthropic_history: list[dict] = []

# Injected by gui.py - called when AI wants to run a command.
# Signature: (script: str, base_dir: str) -> str | None
confirm_and_run_callback = None
# Injected by gui.py - called whenever a comms entry is appended.
# Signature: (entry: dict) -> None
comms_log_callback = None
# Injected by gui.py - called whenever a tool call completes.
# Signature: (script: str, result: str) -> None
tool_log_callback = None

# Upper bound on model->tool->model round trips per send() call.
MAX_TOOL_ROUNDS = 5

# Maximum characters per text chunk sent to Anthropic.
# Kept well under the ~200k token API limit.
_ANTHROPIC_CHUNK_SIZE = 180_000

# System prompt used for every Anthropic request (cached via cache_control).
_ANTHROPIC_SYSTEM = (
    "You are a helpful coding assistant with access to a PowerShell tool. "
    "When asked to create or edit files, prefer targeted edits over full rewrites. "
    "Always explain what you are doing before invoking the tool."
)

# ------------------------------------------------------------------ comms log

# In-memory log of every request/response/tool exchange for the UI.
_comms_log: list[dict] = []
# NOTE(review): appears unused in this file — presumably the UI clamps
# displayed payloads to this many chars; confirm against gui.py.
COMMS_CLAMP_CHARS = 300


def _append_comms(direction: str, kind: str, payload: dict) -> None:
    """Record one comms entry and notify the UI callback, if registered.

    direction: "OUT" (to provider) or "IN" (from provider).
    kind     : e.g. "request", "response", "tool_call", "tool_result".
    payload  : kind-specific dict; stored as-is.
    """
    entry = {
        "ts": datetime.datetime.now().strftime("%H:%M:%S"),
        "direction": direction,
        "kind": kind,
        "provider": _provider,
        "model": _model,
        "payload": payload,
    }
    _comms_log.append(entry)
    if comms_log_callback is not None:
        comms_log_callback(entry)


def get_comms_log() -> list[dict]:
    """Return a shallow copy of the comms log (callers may not mutate ours)."""
    return list(_comms_log)


def clear_comms_log() -> None:
    """Drop all recorded comms entries."""
    _comms_log.clear()


def _load_credentials() -> dict:
    """Load API keys from credentials.toml in the current working directory.

    Raises FileNotFoundError / tomllib.TOMLDecodeError on a missing or
    malformed file.
    """
    with open("credentials.toml", "rb") as f:
        return tomllib.load(f)

# ------------------------------------------------------------------ provider errors


class ProviderError(Exception):
    """Wraps a provider SDK exception with a coarse classification.

    kind     : one of "quota", "rate_limit", "auth", "balance",
               "network", "unknown".
    provider : "gemini" or "anthropic".
    original : the underlying SDK exception.
    """

    def __init__(self, kind: str, provider: str, original: Exception):
        self.kind = kind
        self.provider = provider
        self.original = original
        super().__init__(str(original))

    def ui_message(self) -> str:
        """Return a human-readable banner + detail string for the UI."""
        labels = {
            "quota": "QUOTA EXHAUSTED",
            "rate_limit": "RATE LIMITED",
            "auth": "AUTH / API KEY ERROR",
            "balance": "BALANCE / BILLING ERROR",
            "network": "NETWORK / CONNECTION ERROR",
            "unknown": "API ERROR",
        }
        label = labels.get(self.kind, "API ERROR")
        return f"[{self.provider.upper()} {label}]\n\n{self.original}"


def _classify_anthropic_error(exc: Exception) -> ProviderError:
    """Map an exception raised by the Anthropic SDK to a ProviderError.

    Check order matters: specific SDK exception types first, then HTTP
    status codes, then message-substring heuristics.  Falls back to
    "unknown" (also when the anthropic package is not importable).
    """
    try:
        import anthropic
        if isinstance(exc, anthropic.RateLimitError):
            return ProviderError("rate_limit", "anthropic", exc)
        if isinstance(exc, anthropic.AuthenticationError):
            return ProviderError("auth", "anthropic", exc)
        if isinstance(exc, anthropic.PermissionDeniedError):
            return ProviderError("auth", "anthropic", exc)
        if isinstance(exc, anthropic.APIConnectionError):
            return ProviderError("network", "anthropic", exc)
        if isinstance(exc, anthropic.APIStatusError):
            status = getattr(exc, "status_code", 0)
            body = str(exc).lower()
            if status == 429:
                return ProviderError("rate_limit", "anthropic", exc)
            if status in (401, 403):
                return ProviderError("auth", "anthropic", exc)
            if status == 402:
                return ProviderError("balance", "anthropic", exc)
            if "credit" in body or "balance" in body or "billing" in body:
                return ProviderError("balance", "anthropic", exc)
            if "quota" in body or "limit" in body or "exceeded" in body:
                return ProviderError("quota", "anthropic", exc)
    except ImportError:
        pass
    return ProviderError("unknown", "anthropic", exc)


def _classify_gemini_error(exc: Exception) -> ProviderError:
    """Map an exception raised by the Gemini SDK to a ProviderError.

    Tries google.api_core exception types first; if that package is not
    importable (or nothing matched), falls back to message-substring
    heuristics on the stringified exception.
    """
    body = str(exc).lower()
    try:
        from google.api_core import exceptions as gac
        if isinstance(exc, gac.ResourceExhausted):
            return ProviderError("quota", "gemini", exc)
        if isinstance(exc, gac.TooManyRequests):
            return ProviderError("rate_limit", "gemini", exc)
        if isinstance(exc, (gac.Unauthenticated, gac.PermissionDenied)):
            return ProviderError("auth", "gemini", exc)
        if isinstance(exc, gac.ServiceUnavailable):
            return ProviderError("network", "gemini", exc)
    except ImportError:
        pass
    # Heuristic fallbacks; order matters ("quota" before the generic
    # rate/limit pairing, auth before billing).
    if "429" in body or "quota" in body or "resource exhausted" in body:
        return ProviderError("quota", "gemini", exc)
    if "rate" in body and "limit" in body:
        return ProviderError("rate_limit", "gemini", exc)
    if "401" in body or "403" in body or "api key" in body or "unauthenticated" in body:
        return ProviderError("auth", "gemini", exc)
    if "402" in body or "billing" in body or "balance" in body or "payment" in body:
        return ProviderError("balance", "gemini", exc)
    if "connection" in body or "timeout" in body or "unreachable" in body:
        return ProviderError("network", "gemini", exc)
    return ProviderError("unknown", "gemini", exc)

# ------------------------------------------------------------------ provider setup


def set_provider(provider: str, model: str) -> None:
    """Select the active provider ("gemini" / "anthropic") and model id."""
    global _provider, _model
    _provider = provider
    _model = model


def reset_session() -> None:
    """Discard all per-session SDK clients, chat state, and history."""
    global _gemini_client, _gemini_chat
    global _anthropic_client, _anthropic_history
    _gemini_client = None
    _gemini_chat = None
    _anthropic_client = None
    _anthropic_history = []
    file_cache.reset_client()
# ------------------------------------------------------------------ model listing


def list_models(provider: str) -> list[str]:
    """Return the sorted model ids available for *provider*.

    Unknown provider names yield an empty list.  Raises ProviderError on
    API failure.
    """
    creds = _load_credentials()
    if provider == "gemini":
        return _list_gemini_models(creds["gemini"]["api_key"])
    elif provider == "anthropic":
        return _list_anthropic_models()
    return []


def _list_gemini_models(api_key: str) -> list[str]:
    """List Gemini model names (with the "models/" prefix stripped)."""
    from google import genai
    try:
        client = genai.Client(api_key=api_key)
        models = []
        for m in client.models.list():
            name = m.name
            if name.startswith("models/"):
                name = name[len("models/"):]
            # Only surface Gemini-family models in the picker.
            if "gemini" in name.lower():
                models.append(name)
        return sorted(models)
    except Exception as exc:
        raise _classify_gemini_error(exc) from exc


def _list_anthropic_models() -> list[str]:
    """List Anthropic model ids."""
    import anthropic
    try:
        creds = _load_credentials()
        client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"])
        models = []
        for m in client.models.list():
            models.append(m.id)
        return sorted(models)
    except Exception as exc:
        raise _classify_anthropic_error(exc) from exc

# ------------------------------------------------------------------ tool definition

# The single tool exposed to both providers.
TOOL_NAME = "run_powershell"

# Anthropic tool schema.  cache_control on the (last) tool definition lets
# the tools block participate in prompt caching.
_ANTHROPIC_TOOLS = [
    {
        "name": TOOL_NAME,
        "description": (
            "Run a PowerShell script within the project base_dir. "
            "Use this to create, edit, rename, or delete files and directories. "
            "The working directory is set to base_dir automatically. "
            "Always prefer targeted edits over full rewrites where possible. "
            "stdout and stderr are returned to you as the result."
        ),
        "input_schema": {
            "type": "object",
            "properties": {
                "script": {
                    "type": "string",
                    "description": "The PowerShell script to execute."
                }
            },
            "required": ["script"]
        },
        "cache_control": {"type": "ephemeral"},
    }
]


def _gemini_tool_declaration():
    """Build the Gemini function-declaration equivalent of _ANTHROPIC_TOOLS."""
    from google.genai import types
    return types.Tool(
        function_declarations=[
            types.FunctionDeclaration(
                name=TOOL_NAME,
                description=(
                    "Run a PowerShell script within the project base_dir. "
                    "Use this to create, edit, rename, or delete files and directories. "
                    "The working directory is set to base_dir automatically. "
                    "stdout and stderr are returned to you as the result."
                ),
                parameters=types.Schema(
                    type=types.Type.OBJECT,
                    properties={
                        "script": types.Schema(
                            type=types.Type.STRING,
                            description="The PowerShell script to execute."
                        )
                    },
                    required=["script"]
                )
            )
        ]
    )


def _run_script(script: str, base_dir: str) -> str:
    """Run *script* through the GUI confirmation callback.

    Returns the script output, or an explanatory string when no handler is
    registered or the user rejected the command.  Always notifies
    tool_log_callback with whatever string is returned to the model.
    """
    if confirm_and_run_callback is None:
        return "ERROR: no confirmation handler registered"
    result = confirm_and_run_callback(script, base_dir)
    if result is None:
        output = "USER REJECTED: command was not executed"
    else:
        output = result
    if tool_log_callback is not None:
        tool_log_callback(script, output)
    return output

# ------------------------------------------------------------------ content block serialisation


def _content_block_to_dict(block) -> dict:
    """
    Convert an Anthropic SDK content block object to a plain dict.

    This ensures history entries are always JSON-serialisable dicts,
    not opaque SDK objects that may fail on re-serialisation.
    """
    if isinstance(block, dict):
        return block
    if hasattr(block, "model_dump"):
        return block.model_dump()
    if hasattr(block, "to_dict"):
        return block.to_dict()
    # Fallback: manually construct based on type
    block_type = getattr(block, "type", None)
    if block_type == "text":
        return {"type": "text", "text": block.text}
    if block_type == "tool_use":
        return {"type": "tool_use", "id": block.id, "name": block.name, "input": block.input}
    return {"type": "text", "text": str(block)}

# ------------------------------------------------------------------ gemini


def _ensure_gemini_client():
    """Create the Gemini client on first use."""
    global _gemini_client
    if _gemini_client is None:
        from google import genai
        creds = _load_credentials()
        _gemini_client = genai.Client(api_key=creds["gemini"]["api_key"])


def _gemini_text_parts(response) -> list[str]:
    """Collect all non-empty text parts from a Gemini response."""
    return [
        part.text
        for candidate in response.candidates
        for part in candidate.content.parts
        if hasattr(part, "text") and part.text
    ]


def _send_gemini(md_content: str, user_message: str, base_dir: str) -> str:
    """Send one turn to Gemini, running tool calls up to MAX_TOOL_ROUNDS.

    md_content is prepended to the user message as context.  Returns the
    final text reply.  Raises ProviderError on API failure.
    """
    global _gemini_chat
    from google import genai
    from google.genai import types
    try:
        _ensure_gemini_client()
        if _gemini_chat is None:
            _gemini_chat = _gemini_client.chats.create(
                model=_model,
                config=types.GenerateContentConfig(
                    tools=[_gemini_tool_declaration()]
                )
            )

        full_message = f"\n{md_content}\n\n\n{user_message}"
        _append_comms("OUT", "request", {
            "message": f"[context {len(md_content)} chars + user message {len(user_message)} chars]",
        })
        response = _gemini_chat.send_message(full_message)

        for round_idx in range(MAX_TOOL_ROUNDS):
            tool_calls = [
                part.function_call
                for candidate in response.candidates
                for part in candidate.content.parts
                if hasattr(part, "function_call") and part.function_call is not None
            ]
            _append_comms("IN", "response", {
                "round": round_idx,
                "text": "\n".join(_gemini_text_parts(response)),
                "tool_calls": [{"name": fc.name, "args": dict(fc.args)} for fc in tool_calls],
            })
            if not tool_calls:
                break

            function_responses = []
            for fc in tool_calls:
                if fc.name == TOOL_NAME:
                    script = fc.args.get("script", "")
                    _append_comms("OUT", "tool_call", {
                        "name": TOOL_NAME,
                        "script": script,
                    })
                    output = _run_script(script, base_dir)
                    _append_comms("IN", "tool_result", {
                        "name": TOOL_NAME,
                        "output": output,
                    })
                    function_responses.append(
                        types.Part.from_function_response(
                            name=TOOL_NAME,
                            response={"output": output}
                        )
                    )
                else:
                    # FIX: previously calls to unrecognised tool names were
                    # silently dropped, leaving a function call unanswered.
                    # Every function_call must receive a function_response.
                    function_responses.append(
                        types.Part.from_function_response(
                            name=fc.name,
                            response={"output": f"ERROR: unknown tool '{fc.name}'"}
                        )
                    )
            if not function_responses:
                break
            response = _gemini_chat.send_message(function_responses)

        return "\n".join(_gemini_text_parts(response))
    except ProviderError:
        raise
    except Exception as exc:
        raise _classify_gemini_error(exc) from exc

# ------------------------------------------------------------------ anthropic


def _ensure_anthropic_client():
    """Create the Anthropic client on first use."""
    global _anthropic_client
    if _anthropic_client is None:
        import anthropic
        creds = _load_credentials()
        _anthropic_client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"])


def _chunk_text(text: str, chunk_size: int) -> list[str]:
    """Split *text* into consecutive chunks of at most *chunk_size* chars."""
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]


def _build_chunked_context_blocks(md_content: str) -> list[dict]:
    """
    Split md_content into <=_ANTHROPIC_CHUNK_SIZE char chunks.
    cache_control:ephemeral is placed only on the LAST block so the
    whole prefix is cached as one unit.
    """
    chunks = _chunk_text(md_content, _ANTHROPIC_CHUNK_SIZE)
    blocks = []
    for i, chunk in enumerate(chunks):
        block: dict = {"type": "text", "text": chunk}
        if i == len(chunks) - 1:
            block["cache_control"] = {"type": "ephemeral"}
        blocks.append(block)
    return blocks


def _strip_cache_controls(history: list[dict]):
    """
    Remove cache_control from all content blocks in message history.
    Anthropic allows max 4 cache_control blocks total across system +
    tools + messages. We reserve those slots for the stable system/tools
    prefix and the current turn's context block, so all older history
    entries must be clean.
    """
    for msg in history:
        content = msg.get("content")
        if isinstance(content, list):
            for block in content:
                if isinstance(block, dict):
                    block.pop("cache_control", None)


def _repair_anthropic_history(history: list[dict]):
    """
    If history ends with an assistant message that contains tool_use
    blocks without a following user tool_result message, append a
    synthetic tool_result message so the history is valid before the
    next request.
    """
    if not history:
        return
    last = history[-1]
    if last.get("role") != "assistant":
        return
    content = last.get("content", [])
    tool_use_ids = []
    for block in content:
        if isinstance(block, dict):
            if block.get("type") == "tool_use":
                tool_use_ids.append(block["id"])
    if not tool_use_ids:
        return
    history.append({
        "role": "user",
        "content": [
            {
                "type": "tool_result",
                "tool_use_id": tid,
                "content": "Tool call was not completed (session interrupted).",
            }
            for tid in tool_use_ids
        ],
    })


def _send_anthropic(md_content: str, user_message: str, base_dir: str) -> str:
    """Send one turn to Anthropic, running tool calls up to MAX_TOOL_ROUNDS.

    Maintains _anthropic_history across calls (stripping stale
    cache_control markers and repairing dangling tool_use entries first).
    Returns the final text reply.  Raises ProviderError on API failure.
    """
    try:
        _ensure_anthropic_client()
        context_blocks = _build_chunked_context_blocks(md_content)
        user_content = context_blocks + [
            {"type": "text", "text": user_message}
        ]
        _strip_cache_controls(_anthropic_history)
        _repair_anthropic_history(_anthropic_history)
        _anthropic_history.append({"role": "user", "content": user_content})

        n_chunks = len(context_blocks)
        _append_comms("OUT", "request", {
            "message": (
                f"[{n_chunks} chunk(s), {len(md_content)} chars context] "
                f"{user_message[:200]}{'...' if len(user_message) > 200 else ''}"
            ),
        })

        for round_idx in range(MAX_TOOL_ROUNDS):
            response = _anthropic_client.messages.create(
                model=_model,
                # NOTE(review): 8096 looks like a typo for 8192; harmless as
                # written (still a valid limit) — confirm intent.
                max_tokens=8096,
                system=[
                    {
                        "type": "text",
                        "text": _ANTHROPIC_SYSTEM,
                        "cache_control": {"type": "ephemeral"},
                    }
                ],
                tools=_ANTHROPIC_TOOLS,
                messages=_anthropic_history,
            )

            # Convert SDK content block objects to plain dicts before storing in history
            serialised_content = [_content_block_to_dict(b) for b in response.content]
            _anthropic_history.append({
                "role": "assistant",
                "content": serialised_content,
            })

            text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text]
            tool_use_blocks = [
                {"id": b.id, "name": b.name, "input": b.input}
                for b in response.content
                if getattr(b, "type", None) == "tool_use"
            ]
            usage_dict: dict = {}
            if response.usage:
                usage_dict["input_tokens"] = response.usage.input_tokens
                usage_dict["output_tokens"] = response.usage.output_tokens
                cache_creation = getattr(response.usage, "cache_creation_input_tokens", None)
                cache_read = getattr(response.usage, "cache_read_input_tokens", None)
                if cache_creation is not None:
                    usage_dict["cache_creation_input_tokens"] = cache_creation
                if cache_read is not None:
                    usage_dict["cache_read_input_tokens"] = cache_read
            _append_comms("IN", "response", {
                "round": round_idx,
                "stop_reason": response.stop_reason,
                "text": "\n".join(text_blocks),
                "tool_calls": tool_use_blocks,
                "usage": usage_dict,
            })

            if response.stop_reason != "tool_use":
                break

            tool_results = []
            for block in response.content:
                if getattr(block, "type", None) != "tool_use":
                    continue
                if getattr(block, "name", None) == TOOL_NAME:
                    script = block.input.get("script", "")
                    _append_comms("OUT", "tool_call", {
                        "name": TOOL_NAME,
                        "id": block.id,
                        "script": script,
                    })
                    output = _run_script(script, base_dir)
                    _append_comms("IN", "tool_result", {
                        "name": TOOL_NAME,
                        "id": block.id,
                        "output": output,
                    })
                else:
                    # FIX: previously tool_use blocks with an unknown tool
                    # name got no tool_result, leaving invalid history that
                    # the API rejects on the next request in this loop.
                    # Every tool_use id must be answered.
                    output = f"ERROR: unknown tool '{block.name}'"
                tool_results.append({
                    "type": "tool_result",
                    "tool_use_id": block.id,
                    "content": output,
                })
            if not tool_results:
                break
            _anthropic_history.append({
                "role": "user",
                "content": tool_results,
            })
            _append_comms("OUT", "tool_result_send", {
                "results": [
                    {"tool_use_id": r["tool_use_id"], "content": r["content"]}
                    for r in tool_results
                ],
            })

        text_parts = [
            block.text for block in response.content
            if hasattr(block, "text") and block.text
        ]
        return "\n".join(text_parts)
    except ProviderError:
        raise
    except Exception as exc:
        raise _classify_anthropic_error(exc) from exc

# ------------------------------------------------------------------ unified send


def send(
    md_content: str,
    user_message: str,
    base_dir: str = ".",
) -> str:
    """
    Send a message to the active provider.

    md_content  : aggregated markdown string from aggregate.run()
    user_message: the user question / instruction
    base_dir    : project base directory (for PowerShell tool calls)

    Raises ValueError for an unknown provider, ProviderError on API failure.
    """
    if _provider == "gemini":
        return _send_gemini(md_content, user_message, base_dir)
    elif _provider == "anthropic":
        return _send_anthropic(md_content, user_message, base_dir)
    raise ValueError(f"unknown provider: {_provider}")