From 34b1349c4f61a3e5ca4d834636670bc13140a998 Mon Sep 17 00:00:00 2001 From: Ed_ Date: Wed, 13 May 2026 19:06:33 -0400 Subject: [PATCH] WIP: cleaning up ai_client.py --- conductor/workflow.md | 1 - src/ai_client.py | 945 ++++++++++++++++++++++-------------------- 2 files changed, 493 insertions(+), 453 deletions(-) diff --git a/conductor/workflow.md b/conductor/workflow.md index 55af969..f74b726 100644 --- a/conductor/workflow.md +++ b/conductor/workflow.md @@ -6,7 +6,6 @@ - **1-space indentation** for ALL Python code (NO EXCEPTIONS) - **CRLF line endings** on Windows -- Use `./scripts/ai_style_formatter.py` for formatting validation - **NO COMMENTS** unless explicitly requested - Type hints required for all public functions - **ImGui Defer Patterns:** Use `imscope` context managers or `_render_window_if_open` dispatch helpers to prevent resource leaks and keep the main loop flat. See `conductor/code_styleguides/python.md` for details. diff --git a/src/ai_client.py b/src/ai_client.py index fd0b7a4..9d1ea30 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -12,31 +12,33 @@ For Gemini: injects the initial context directly into system_instruction during chat creation to avoid massive history bloat. """ # ai_client.py -import tomllib +import anthropic +from google import genai +from google.genai import types import asyncio -import json -import sys -import time import datetime -from src import performance_monitor -import hashlib import difflib -import threading -import requests # type: ignore[import-untyped] -from typing import Optional, Callable, Any, List, Union, cast, Iterable +import hashlib +import json import os +import requests # type: ignore[import-untyped] +import sys +import threading +import time +import tomllib +from collections import deque +from typing import Optional, Callable, Any, List, Union, cast, Iterable from pathlib import Path +from src.events import EventEmitter from src import project_manager from src import file_cache from src import mcp_client from src import mma_prompts +from src import performance_monitor +from src import project_manager from src.tool_bias import ToolBiasEngine from src.models import ToolPreset, BiasProfile, Tool -import anthropic from src.gemini_cli_adapter import GeminiCliAdapter as GeminiCliAdapter -from google import genai -from google.genai import types -from src.events import EventEmitter _provider: str = "gemini" _model: str = "gemini-2.5-flash-lite" @@ -49,8 +51,6 @@ _history_trunc_limit: int = 8000 # Global event emitter for API lifecycle events events: EventEmitter = EventEmitter() -_ai_proxy = None - class ProviderError(Exception): def __init__(self, kind: str, provider: str, original: Exception) -> None: self.kind = kind @@ -70,16 +70,7 @@ class ProviderError(Exception): label = labels.get(self.kind, "API ERROR") return f"[{self.provider.upper()} {label}]\n\n{self.original}" -def _get_proxy(): - global _ai_proxy - if _ai_proxy is None and os.environ.get("AI_SERVER_ENABLED"): - try: - from src.ai_client_proxy import AIProxyClient - _ai_proxy = AIProxyClient() - _ai_proxy.start_server() - except Exception: - _ai_proxy = None - return _ai_proxy +#region: Provider Configuration def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000, top_p: float = 1.0) -> None: """ @@ -199,6 +190,10 @@ _base_system_prompt_override: str = "" _use_default_base_system_prompt: bool = True _project_context_marker: str = "" +#endregion: Provider Configuration + +#region: System Prompt Management + def set_custom_system_prompt(prompt: str) -> 
None: """ Sets a custom system prompt to be combined with the default instructions. @@ -255,12 +250,14 @@ def get_combined_system_prompt(preset: Optional[ToolPreset] = None, bias: Option """ return _get_combined_system_prompt(preset, bias) -from collections import deque - _comms_log: deque[dict[str, Any]] = deque(maxlen=1000) COMMS_CLAMP_CHARS: int = 300 +#endregion: System Prompt Management + +#region: Comms Log + def _append_comms(direction: str, kind: str, payload: dict[str, Any]) -> None: """ [C: tests/test_ai_client_concurrency.py:run_t1, tests/test_ai_client_concurrency.py:run_t2, tests/test_mma_agent_focus_phase1.py:test_append_comms_has_source_tier_key, tests/test_mma_agent_focus_phase1.py:test_append_comms_source_tier_none_when_unset, tests/test_mma_agent_focus_phase1.py:test_append_comms_source_tier_set_when_current_tier_set, tests/test_mma_agent_focus_phase1.py:test_append_comms_source_tier_tier2] @@ -522,22 +519,6 @@ def reset_session() -> None: _CACHED_DEEPSEEK_TOOLS = None file_cache.reset_client() -def get_gemini_cache_stats() -> dict[str, Any]: - """ - [C: src/app_controller.py:AppController._recalculate_session_usage, src/app_controller.py:AppController._update_cached_stats, tests/test_ai_cache_tracking.py:test_gemini_cache_tracking, tests/test_gemini_metrics.py:test_get_gemini_cache_stats_with_mock_client] - """ - _ensure_gemini_client() - if not _gemini_client: - return {"cache_count": 0, "total_size_bytes": 0, "cached_files": []} - caches_iterator = _gemini_client.caches.list() - caches = list(caches_iterator) - total_size_bytes = sum(getattr(c, 'size_bytes', 0) for c in caches) - return { - "cache_count": len(caches), - "total_size_bytes": total_size_bytes, - "cached_files": _gemini_cached_file_paths, - } - def list_models(provider: str) -> list[str]: """ [C: src/app_controller.py:AppController.do_fetch, tests/test_agent_capabilities.py:test_agent_capabilities_listing, tests/test_ai_client_list_models.py:test_list_models_gemini_cli, tests/test_deepseek_infra.py:test_deepseek_model_listing, tests/test_minimax_provider.py:test_minimax_list_models] @@ -560,60 +541,14 @@ def list_models(provider: str) -> list[str]: return _list_minimax_models(creds["minimax"]["api_key"]) return [] -def _list_gemini_cli_models() -> list[str]: - return [ - "gemini-3-flash-preview", - "gemini-3.1-pro-preview", - "gemini-2.5-pro", - "gemini-2.5-flash", - "gemini-2.0-flash", - "gemini-2.5-flash-lite", - ] - -def _list_gemini_models(api_key: str) -> list[str]: - try: - client = genai.Client(api_key=api_key) - models: list[str] = [] - for m in client.models.list(): - name = m.name - if name and name.startswith("models/"): - name = name[len("models/"):] - if name and "gemini" in name.lower(): - models.append(name) - return sorted(models) - except Exception as exc: - raise _classify_gemini_error(exc) from exc - -def _list_anthropic_models() -> list[str]: - try: - creds = _load_credentials() - client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"]) - models: list[str] = [] - for m in client.models.list(): - models.append(m.id) - return sorted(models) - except Exception as exc: - raise _classify_anthropic_error(exc) from exc - -def _list_deepseek_models(api_key: str) -> list[str]: - return ["deepseek-chat", "deepseek-reasoner"] - -def _list_minimax_models(api_key: str) -> list[str]: - try: - from openai import OpenAI - client = OpenAI(api_key=api_key, base_url="https://api.minimax.io/v1") - models_list = client.models.list() - found = [m.id for m in models_list] - if found: - return 
sorted(found) - except Exception: - pass - return ["MiniMax-M2.7", "MiniMax-M2.5", "MiniMax-M2.1", "MiniMax-M2"] +#endregion: Comms Log TOOL_NAME: str = "run_powershell" _agent_tools: dict[str, bool] = {} +#region: Tool Configuration + def set_agent_tools(tools: dict[str, bool]) -> None: """ Configures which tools are enabled for the AI agent. @@ -788,6 +723,10 @@ def _gemini_tool_declaration() -> Optional[types.Tool]: )) return types.Tool(function_declarations=declarations) if declarations else None +#endregion: Tool Configuration + +#region: Tool Execution + async def _execute_tool_calls_concurrently( calls: list[Any], base_dir: str, @@ -920,6 +859,10 @@ def _truncate_tool_output(output: str) -> str: return output[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS.]" return output +#endregion: Tool Execution + +#region: File Context Building + def _reread_file_items(file_items: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: refreshed: list[dict[str, Any]] = [] changed: list[dict[str, Any]] = [] @@ -1050,6 +993,414 @@ def _content_block_to_dict(block: Any) -> dict[str, Any]: return {"type": "tool_use", "id": getattr(block, "id"), "name": getattr(block, "name"), "input": getattr(block, "input")} return {"type": "text", "text": str(block)} +#endregion: File Context Building + +#region: Token Estimation + +_CHARS_PER_TOKEN: float = 3.5 +_ANTHROPIC_MAX_PROMPT_TOKENS: int = 180_000 +_GEMINI_MAX_INPUT_TOKENS: int = 900_000 +_FILE_REFRESH_MARKER: str = _project_context_marker if _project_context_marker.strip() else "[SYSTEM: FILES UPDATED]" + +def _estimate_message_tokens(msg: dict[str, Any]) -> int: + cached = msg.get("_est_tokens") + if cached is not None: + return cast(int, cached) + total_chars = 0 + content = msg.get("content", "") + if isinstance(content, str): + total_chars += len(content) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict): + text = block.get("text", "") or block.get("content", "") + if isinstance(text, str): + total_chars += len(text) + inp = block.get("input") + if isinstance(inp, dict): + import json as _json + total_chars += len(_json.dumps(inp, ensure_ascii=False)) + elif isinstance(block, str): + total_chars += len(block) + est = max(1, int(total_chars / _CHARS_PER_TOKEN)) + msg["_est_tokens"] = est + return est + +def _invalidate_token_estimate(msg: dict[str, Any]) -> None: + msg.pop("_est_tokens", None) + +def _estimate_prompt_tokens(system_blocks: list[dict[str, Any]], history: list[dict[str, Any]]) -> int: + total = 0 + for block in system_blocks: + text = cast(str, block.get("text", "")) + total += max(1, int(len(text) / _CHARS_PER_TOKEN)) + total += 2500 + for msg in history: + total += _estimate_message_tokens(msg) + return total + +def _strip_stale_file_refreshes(history: list[dict[str, Any]]) -> None: + if len(history) < 2: + return + last_user_idx = -1 + for i in range(len(history) - 1, -1, -1): + if history[i].get("role") == "user": + last_user_idx = i + break + for i, msg in enumerate(history): + if msg.get("role") != "user" or i == last_user_idx: + continue + content = msg.get("content") + if not isinstance(content, list): + continue + cleaned: list[dict[str, Any]] = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = cast(str, block.get("text", "")) + if text.startswith(_FILE_REFRESH_MARKER): + continue + cleaned.append(block) + if len(cleaned) < len(content): + msg["content"] = cleaned + _invalidate_token_estimate(msg) + 
+def _chunk_text(text: str, chunk_size: int) -> list[str]: + """ + [C: src/rag_engine.py:RAGEngine._chunk_code, src/rag_engine.py:RAGEngine.index_file] + """ + return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)] + +def _build_chunked_context_blocks(md_content: str) -> list[dict[str, Any]]: + chunks = _chunk_text(md_content, _ANTHROPIC_CHUNK_SIZE) + blocks: list[dict[str, Any]] = [] + for i, chunk in enumerate(chunks): + block: dict[str, Any] = {"type": "text", "text": chunk} + if i == len(chunks) - 1: + block["cache_control"] = {"type": "ephemeral"} + blocks.append(block) + return blocks + +def _strip_cache_controls(history: list[dict[str, Any]]) -> None: + for msg in history: + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict): + block.pop("cache_control", None) + +def _add_history_cache_breakpoint(history: list[dict[str, Any]]) -> None: + user_indices = [i for i, m in enumerate(history) if m.get("role") == "user"] + if len(user_indices) < 2: + return + target_idx = user_indices[-2] + content = history[target_idx].get("content") + if isinstance(content, list) and content: + last_block = content[-1] + if isinstance(last_block, dict): + last_block["cache_control"] = {"type": "ephemeral"} + elif isinstance(content, str): + history[target_idx]["content"] = [ + {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}} + ] + +#endregion: Token Estimation + +#region: Anthropic Provider + +def _list_anthropic_models() -> list[str]: + try: + creds = _load_credentials() + client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"]) + models: list[str] = [] + for m in client.models.list(): + models.append(m.id) + return sorted(models) + except Exception as exc: + raise _classify_anthropic_error(exc) from exc + +def _ensure_anthropic_client() -> None: + global _anthropic_client + if _anthropic_client is None: + creds = _load_credentials() + _anthropic_client = anthropic.Anthropic( + api_key=creds["anthropic"]["api_key"], + default_headers={"anthropic-beta": "prompt-caching-2024-07-31"} + ) + +def _trim_anthropic_history(system_blocks: list[dict[str, Any]], history: list[dict[str, Any]]) -> int: + _strip_stale_file_refreshes(history) + est = _estimate_prompt_tokens(system_blocks, history) + if est <= _ANTHROPIC_MAX_PROMPT_TOKENS: + return 0 + dropped = 0 + while len(history) > 3 and est > _ANTHROPIC_MAX_PROMPT_TOKENS: + if history[1].get("role") == "assistant" and len(history) > 2 and history[2].get("role") == "user": + removed_asst = history.pop(1) + removed_user = history.pop(1) + dropped += 2 + est -= _estimate_message_tokens(removed_asst) + est -= _estimate_message_tokens(removed_user) + while len(history) > 2 and history[1].get("role") == "assistant" and history[2].get("role") == "user": + content = history[2].get("content", []) + if isinstance(content, list) and content and isinstance(content[0], dict) and content[0].get("type") == "tool_result": + r_a = history.pop(1) + r_u = history.pop(1) + dropped += 2 + est -= _estimate_message_tokens(r_a) + est -= _estimate_message_tokens(r_u) + else: + break + else: + removed = history.pop(1) + dropped += 1 + est -= _estimate_message_tokens(removed) + return dropped + +def _repair_anthropic_history(history: list[dict[str, Any]]) -> None: + if not history: + return + last = history[-1] + if last.get("role") != "assistant": + return + content = last.get("content", []) + tool_use_ids: list[str] = [] + for block in content: + if isinstance(block, dict): 
+ if block.get("type") == "tool_use": + tool_use_ids.append(cast(str, block["id"])) + if not tool_use_ids: + return + history.append({ + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tid, + "content": "Tool call was not completed (session interrupted).", + } + for tid in tool_use_ids + ], + }) + +def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_items: list[dict[str, Any]] | None = None, discussion_history: str = "", pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, qa_callback: Optional[Callable[[str], str]] = None, stream_callback: Optional[Callable[[str], None]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: + monitor = performance_monitor.get_monitor() + if monitor.enabled: monitor.start_component("ai_client._send_anthropic") + try: + _ensure_anthropic_client() + mcp_client.configure(file_items or [], [base_dir]) + stable_prompt = _get_combined_system_prompt() + stable_blocks: list[dict[str, Any]] = [{"type": "text", "text": stable_prompt, "cache_control": {"type": "ephemeral"}}] + context_text = f"\n\n\n{md_content}\n" + context_blocks = _build_chunked_context_blocks(context_text) + system_blocks = stable_blocks + context_blocks + if discussion_history and not _anthropic_history: + user_content: list[dict[str, Any]] = [{"type": "text", "text": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"}] + else: + user_content = [{"type": "text", "text": user_message}] + for msg in _anthropic_history: + if msg.get("role") == "user" and isinstance(msg.get("content"), list): + modified = False + for block in cast(List[dict[str, Any]], msg["content"]): + if isinstance(block, dict) and block.get("type") == "tool_result": + t_content = block.get("content", "") + if _history_trunc_limit > 0 and isinstance(t_content, str) and len(t_content) > _history_trunc_limit: + block["content"] = t_content[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS. Original output was too large.]" + modified = True + if modified: + _invalidate_token_estimate(msg) + _strip_cache_controls(_anthropic_history) + _repair_anthropic_history(_anthropic_history) + _anthropic_history.append({"role": "user", "content": user_content}) + _add_history_cache_breakpoint(_anthropic_history) + all_text_parts: list[str] = [] + _cumulative_tool_bytes = 0 + + def _strip_private_keys(history: list[dict[str, Any]]) -> list[dict[str, Any]]: + return [{k: v for k, v in m.items() if not k.startswith("_")} for m in history] + + for round_idx in range(MAX_TOOL_ROUNDS + 2): + response: Any = None + dropped = _trim_anthropic_history(system_blocks, _anthropic_history) + if dropped > 0: + est_tokens = _estimate_prompt_tokens(system_blocks, _anthropic_history) + _append_comms("OUT", "request", { + "message": ( + f"[HISTORY TRIMMED: dropped {dropped} old messages to fit token budget. " + f"Estimated {est_tokens} tokens remaining. 
{len(_anthropic_history)} messages in history.]" + ), + }) + + events.emit("request_start", payload={"provider": "anthropic", "model": _model, "round": round_idx}) + assert _anthropic_client is not None + if stream_callback: + with _anthropic_client.messages.stream( + model=_model, + max_tokens=_max_tokens, + temperature=_temperature, + top_p=_top_p, + system=cast(Iterable[anthropic.types.TextBlockParam], system_blocks), + tools=cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()), + messages=cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)), + ) as stream: + for event in stream: + if isinstance(event, anthropic.types.ContentBlockDeltaEvent) and event.delta.type == "text_delta": + stream_callback(event.delta.text) + response = stream.get_final_message() + else: + response = _anthropic_client.messages.create( + model=_model, + max_tokens=_max_tokens, + temperature=_temperature, + top_p=_top_p, + system=cast(Iterable[anthropic.types.TextBlockParam], system_blocks), + tools=cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()), + messages=cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)), + ) + serialised_content = [_content_block_to_dict(b) for b in response.content] + _anthropic_history.append({ + "role": "assistant", + "content": serialised_content, + }) + text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text] + if text_blocks: + all_text_parts.append("\n".join(text_blocks)) + tool_use_blocks = [ + {"id": getattr(b, "id"), "name": getattr(b, "name"), "input": getattr(b, "input")} + for b in response.content + if getattr(b, "type", None) == "tool_use" + ] + usage_dict: dict[str, Any] = {} + if response.usage: + usage_dict["input_tokens"] = response.usage.input_tokens + usage_dict["output_tokens"] = response.usage.output_tokens + cache_creation = getattr(response.usage, "cache_creation_input_tokens", None) + cache_read = getattr(response.usage, "cache_read_input_tokens", None) + if cache_creation is not None: + usage_dict["cache_creation_input_tokens"] = cache_creation + if cache_read is not None: + usage_dict["cache_read_input_tokens"] = cache_read + events.emit("response_received", payload={"provider": "anthropic", "model": _model, "usage": usage_dict, "round": round_idx}) + _append_comms("IN", "response", { + "round": round_idx, + "stop_reason": response.stop_reason, + "text": "\n".join(text_blocks), + "tool_calls": tool_use_blocks, + "usage": usage_dict, + }) + if response.stop_reason != "tool_use" or not tool_use_blocks: + break + if round_idx > MAX_TOOL_ROUNDS: + break + + # Execute tools concurrently + try: + loop = asyncio.get_running_loop() + results = asyncio.run_coroutine_threadsafe( + _execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback), + loop + ).result() + except RuntimeError: + results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback)) + + tool_results: list[dict[str, Any]] = [] + for i, (name, call_id, out, _) in enumerate(results): + truncated = _truncate_tool_output(out) + _cumulative_tool_bytes += len(truncated) + tool_results.append({ + "type": "tool_result", + "tool_use_id": call_id, + "content": truncated, + }) + _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out}) + events.emit("tool_execution", payload={"status": "completed", "tool": name, 
"result": out, "round": round_idx}) + + if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES: + tool_results.append({ + "type": "text", + "text": f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now." + }) + _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"}) + if file_items: + file_items, changed = _reread_file_items(file_items) + refreshed_ctx = _build_file_diff_text(changed) + if refreshed_ctx: + tool_results.append({ + "type": "text", + "text": ( + f"{_get_context_marker()}\n\n" + + refreshed_ctx + ), + }) + if round_idx == MAX_TOOL_ROUNDS: + tool_results.append({ + "type": "text", + "text": "SYSTEM WARNING: MAX TOOL ROUNDS REACHED. YOU MUST PROVIDE YOUR FINAL ANSWER NOW WITHOUT CALLING ANY MORE TOOLS." + }) + _anthropic_history.append({ + "role": "user", + "content": tool_results, + }) + _append_comms("OUT", "tool_result_send", { + "results": [ + {"tool_use_id": r["tool_use_id"], "content": r["content"]} + for r in tool_results if r.get("type") == "tool_result" + ], + }) + final_text = "\n\n".join(all_text_parts) + res = final_text if final_text.strip() else "(No text returned by the model)" + if monitor.enabled: monitor.end_component("ai_client._send_anthropic") + return res + except ProviderError: + if monitor.enabled: monitor.end_component("ai_client._send_anthropic") + raise + except Exception as exc: + if monitor.enabled: monitor.end_component("ai_client._send_anthropic") + raise _classify_anthropic_error(exc) from exc + +#endregion: Anthropic Provider + +#region: Gemini Provider + +def get_gemini_cache_stats() -> dict[str, Any]: + """ + [C: src/app_controller.py:AppController._recalculate_session_usage, src/app_controller.py:AppController._update_cached_stats, tests/test_ai_cache_tracking.py:test_gemini_cache_tracking, tests/test_gemini_metrics.py:test_get_gemini_cache_stats_with_mock_client] + """ + _ensure_gemini_client() + if not _gemini_client: + return {"cache_count": 0, "total_size_bytes": 0, "cached_files": []} + caches_iterator = _gemini_client.caches.list() + caches = list(caches_iterator) + total_size_bytes = sum(getattr(c, 'size_bytes', 0) for c in caches) + return { + "cache_count": len(caches), + "total_size_bytes": total_size_bytes, + "cached_files": _gemini_cached_file_paths, + } + +def _list_gemini_cli_models() -> list[str]: + return [ + "gemini-3-flash-preview", + "gemini-3.1-pro-preview", + "gemini-2.5-pro", + "gemini-2.5-flash", + "gemini-2.0-flash", + "gemini-2.5-flash-lite", + ] + +def _list_gemini_models(api_key: str) -> list[str]: + try: + client = genai.Client(api_key=api_key) + models: list[str] = [] + for m in client.models.list(): + name = m.name + if name and name.startswith("models/"): + name = name[len("models/"):] + if name and "gemini" in name.lower(): + models.append(name) + return sorted(models) + except Exception as exc: + raise _classify_gemini_error(exc) from exc + def _ensure_gemini_client() -> None: """ [C: src/rag_engine.py:GeminiEmbeddingProvider.embed] @@ -1414,366 +1765,12 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, except Exception as e: raise ProviderError("unknown", "gemini_cli", e) -_CHARS_PER_TOKEN: float = 3.5 -_ANTHROPIC_MAX_PROMPT_TOKENS: int = 180_000 -_GEMINI_MAX_INPUT_TOKENS: int = 900_000 -_FILE_REFRESH_MARKER: str = _project_context_marker if _project_context_marker.strip() else "[SYSTEM: FILES UPDATED]" +#endregion: Gemini Provider -def _estimate_message_tokens(msg: 
dict[str, Any]) -> int: - cached = msg.get("_est_tokens") - if cached is not None: - return cast(int, cached) - total_chars = 0 - content = msg.get("content", "") - if isinstance(content, str): - total_chars += len(content) - elif isinstance(content, list): - for block in content: - if isinstance(block, dict): - text = block.get("text", "") or block.get("content", "") - if isinstance(text, str): - total_chars += len(text) - inp = block.get("input") - if isinstance(inp, dict): - import json as _json - total_chars += len(_json.dumps(inp, ensure_ascii=False)) - elif isinstance(block, str): - total_chars += len(block) - est = max(1, int(total_chars / _CHARS_PER_TOKEN)) - msg["_est_tokens"] = est - return est +#region: DeepSeek Provider -def _invalidate_token_estimate(msg: dict[str, Any]) -> None: - msg.pop("_est_tokens", None) - -def _estimate_prompt_tokens(system_blocks: list[dict[str, Any]], history: list[dict[str, Any]]) -> int: - total = 0 - for block in system_blocks: - text = cast(str, block.get("text", "")) - total += max(1, int(len(text) / _CHARS_PER_TOKEN)) - total += 2500 - for msg in history: - total += _estimate_message_tokens(msg) - return total - -def _strip_stale_file_refreshes(history: list[dict[str, Any]]) -> None: - if len(history) < 2: - return - last_user_idx = -1 - for i in range(len(history) - 1, -1, -1): - if history[i].get("role") == "user": - last_user_idx = i - break - for i, msg in enumerate(history): - if msg.get("role") != "user" or i == last_user_idx: - continue - content = msg.get("content") - if not isinstance(content, list): - continue - cleaned: list[dict[str, Any]] = [] - for block in content: - if isinstance(block, dict) and block.get("type") == "text": - text = cast(str, block.get("text", "")) - if text.startswith(_FILE_REFRESH_MARKER): - continue - cleaned.append(block) - if len(cleaned) < len(content): - msg["content"] = cleaned - _invalidate_token_estimate(msg) - -def _trim_anthropic_history(system_blocks: list[dict[str, Any]], history: list[dict[str, Any]]) -> int: - _strip_stale_file_refreshes(history) - est = _estimate_prompt_tokens(system_blocks, history) - if est <= _ANTHROPIC_MAX_PROMPT_TOKENS: - return 0 - dropped = 0 - while len(history) > 3 and est > _ANTHROPIC_MAX_PROMPT_TOKENS: - if history[1].get("role") == "assistant" and len(history) > 2 and history[2].get("role") == "user": - removed_asst = history.pop(1) - removed_user = history.pop(1) - dropped += 2 - est -= _estimate_message_tokens(removed_asst) - est -= _estimate_message_tokens(removed_user) - while len(history) > 2 and history[1].get("role") == "assistant" and history[2].get("role") == "user": - content = history[2].get("content", []) - if isinstance(content, list) and content and isinstance(content[0], dict) and content[0].get("type") == "tool_result": - r_a = history.pop(1) - r_u = history.pop(1) - dropped += 2 - est -= _estimate_message_tokens(r_a) - est -= _estimate_message_tokens(r_u) - else: - break - else: - removed = history.pop(1) - dropped += 1 - est -= _estimate_message_tokens(removed) - return dropped - -def _ensure_anthropic_client() -> None: - global _anthropic_client - if _anthropic_client is None: - creds = _load_credentials() - _anthropic_client = anthropic.Anthropic( - api_key=creds["anthropic"]["api_key"], - default_headers={"anthropic-beta": "prompt-caching-2024-07-31"} - ) - -def _chunk_text(text: str, chunk_size: int) -> list[str]: - """ - [C: src/rag_engine.py:RAGEngine._chunk_code, src/rag_engine.py:RAGEngine.index_file] - """ - return [text[i:i + chunk_size] 
for i in range(0, len(text), chunk_size)] - -def _build_chunked_context_blocks(md_content: str) -> list[dict[str, Any]]: - chunks = _chunk_text(md_content, _ANTHROPIC_CHUNK_SIZE) - blocks: list[dict[str, Any]] = [] - for i, chunk in enumerate(chunks): - block: dict[str, Any] = {"type": "text", "text": chunk} - if i == len(chunks) - 1: - block["cache_control"] = {"type": "ephemeral"} - blocks.append(block) - return blocks - -def _strip_cache_controls(history: list[dict[str, Any]]) -> None: - for msg in history: - content = msg.get("content") - if isinstance(content, list): - for block in content: - if isinstance(block, dict): - block.pop("cache_control", None) - -def _add_history_cache_breakpoint(history: list[dict[str, Any]]) -> None: - user_indices = [i for i, m in enumerate(history) if m.get("role") == "user"] - if len(user_indices) < 2: - return - target_idx = user_indices[-2] - content = history[target_idx].get("content") - if isinstance(content, list) and content: - last_block = content[-1] - if isinstance(last_block, dict): - last_block["cache_control"] = {"type": "ephemeral"} - elif isinstance(content, str): - history[target_idx]["content"] = [ - {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}} - ] - -def _repair_anthropic_history(history: list[dict[str, Any]]) -> None: - if not history: - return - last = history[-1] - if last.get("role") != "assistant": - return - content = last.get("content", []) - tool_use_ids: list[str] = [] - for block in content: - if isinstance(block, dict): - if block.get("type") == "tool_use": - tool_use_ids.append(cast(str, block["id"])) - if not tool_use_ids: - return - history.append({ - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": tid, - "content": "Tool call was not completed (session interrupted).", - } - for tid in tool_use_ids - ], - }) - -def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_items: list[dict[str, Any]] | None = None, discussion_history: str = "", pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, qa_callback: Optional[Callable[[str], str]] = None, stream_callback: Optional[Callable[[str], None]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: - monitor = performance_monitor.get_monitor() - if monitor.enabled: monitor.start_component("ai_client._send_anthropic") - try: - _ensure_anthropic_client() - mcp_client.configure(file_items or [], [base_dir]) - stable_prompt = _get_combined_system_prompt() - stable_blocks: list[dict[str, Any]] = [{"type": "text", "text": stable_prompt, "cache_control": {"type": "ephemeral"}}] - context_text = f"\n\n\n{md_content}\n" - context_blocks = _build_chunked_context_blocks(context_text) - system_blocks = stable_blocks + context_blocks - if discussion_history and not _anthropic_history: - user_content: list[dict[str, Any]] = [{"type": "text", "text": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"}] - else: - user_content = [{"type": "text", "text": user_message}] - for msg in _anthropic_history: - if msg.get("role") == "user" and isinstance(msg.get("content"), list): - modified = False - for block in cast(List[dict[str, Any]], msg["content"]): - if isinstance(block, dict) and block.get("type") == "tool_result": - t_content = block.get("content", "") - if _history_trunc_limit > 0 and isinstance(t_content, str) and len(t_content) > _history_trunc_limit: - block["content"] = t_content[:_history_trunc_limit] + 
"\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS. Original output was too large.]" - modified = True - if modified: - _invalidate_token_estimate(msg) - _strip_cache_controls(_anthropic_history) - _repair_anthropic_history(_anthropic_history) - _anthropic_history.append({"role": "user", "content": user_content}) - _add_history_cache_breakpoint(_anthropic_history) - all_text_parts: list[str] = [] - _cumulative_tool_bytes = 0 - - def _strip_private_keys(history: list[dict[str, Any]]) -> list[dict[str, Any]]: - return [{k: v for k, v in m.items() if not k.startswith("_")} for m in history] - - for round_idx in range(MAX_TOOL_ROUNDS + 2): - response: Any = None - dropped = _trim_anthropic_history(system_blocks, _anthropic_history) - if dropped > 0: - est_tokens = _estimate_prompt_tokens(system_blocks, _anthropic_history) - _append_comms("OUT", "request", { - "message": ( - f"[HISTORY TRIMMED: dropped {dropped} old messages to fit token budget. " - f"Estimated {est_tokens} tokens remaining. {len(_anthropic_history)} messages in history.]" - ), - }) - - events.emit("request_start", payload={"provider": "anthropic", "model": _model, "round": round_idx}) - assert _anthropic_client is not None - if stream_callback: - with _anthropic_client.messages.stream( - model=_model, - max_tokens=_max_tokens, - temperature=_temperature, - top_p=_top_p, - system=cast(Iterable[anthropic.types.TextBlockParam], system_blocks), - tools=cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()), - messages=cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)), - ) as stream: - for event in stream: - if isinstance(event, anthropic.types.ContentBlockDeltaEvent) and event.delta.type == "text_delta": - stream_callback(event.delta.text) - response = stream.get_final_message() - else: - response = _anthropic_client.messages.create( - model=_model, - max_tokens=_max_tokens, - temperature=_temperature, - top_p=_top_p, - system=cast(Iterable[anthropic.types.TextBlockParam], system_blocks), - tools=cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()), - messages=cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)), - ) - serialised_content = [_content_block_to_dict(b) for b in response.content] - _anthropic_history.append({ - "role": "assistant", - "content": serialised_content, - }) - text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text] - if text_blocks: - all_text_parts.append("\n".join(text_blocks)) - tool_use_blocks = [ - {"id": getattr(b, "id"), "name": getattr(b, "name"), "input": getattr(b, "input")} - for b in response.content - if getattr(b, "type", None) == "tool_use" - ] - usage_dict: dict[str, Any] = {} - if response.usage: - usage_dict["input_tokens"] = response.usage.input_tokens - usage_dict["output_tokens"] = response.usage.output_tokens - cache_creation = getattr(response.usage, "cache_creation_input_tokens", None) - cache_read = getattr(response.usage, "cache_read_input_tokens", None) - if cache_creation is not None: - usage_dict["cache_creation_input_tokens"] = cache_creation - if cache_read is not None: - usage_dict["cache_read_input_tokens"] = cache_read - events.emit("response_received", payload={"provider": "anthropic", "model": _model, "usage": usage_dict, "round": round_idx}) - _append_comms("IN", "response", { - "round": round_idx, - "stop_reason": response.stop_reason, - "text": "\n".join(text_blocks), - "tool_calls": tool_use_blocks, - "usage": usage_dict, - }) - if response.stop_reason != 
"tool_use" or not tool_use_blocks: - break - if round_idx > MAX_TOOL_ROUNDS: - break - - # Execute tools concurrently - try: - loop = asyncio.get_running_loop() - results = asyncio.run_coroutine_threadsafe( - _execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback), - loop - ).result() - except RuntimeError: - results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback)) - - tool_results: list[dict[str, Any]] = [] - for i, (name, call_id, out, _) in enumerate(results): - truncated = _truncate_tool_output(out) - _cumulative_tool_bytes += len(truncated) - tool_results.append({ - "type": "tool_result", - "tool_use_id": call_id, - "content": truncated, - }) - _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out}) - events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx}) - - if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES: - tool_results.append({ - "type": "text", - "text": f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now." - }) - _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"}) - if file_items: - file_items, changed = _reread_file_items(file_items) - refreshed_ctx = _build_file_diff_text(changed) - if refreshed_ctx: - tool_results.append({ - "type": "text", - "text": ( - f"{_get_context_marker()}\n\n" - + refreshed_ctx - ), - }) - if round_idx == MAX_TOOL_ROUNDS: - tool_results.append({ - "type": "text", - "text": "SYSTEM WARNING: MAX TOOL ROUNDS REACHED. YOU MUST PROVIDE YOUR FINAL ANSWER NOW WITHOUT CALLING ANY MORE TOOLS." 
- }) - _anthropic_history.append({ - "role": "user", - "content": tool_results, - }) - _append_comms("OUT", "tool_result_send", { - "results": [ - {"tool_use_id": r["tool_use_id"], "content": r["content"]} - for r in tool_results if r.get("type") == "tool_result" - ], - }) - final_text = "\n\n".join(all_text_parts) - res = final_text if final_text.strip() else "(No text returned by the model)" - if monitor.enabled: monitor.end_component("ai_client._send_anthropic") - return res - except ProviderError: - if monitor.enabled: monitor.end_component("ai_client._send_anthropic") - raise - except Exception as exc: - if monitor.enabled: monitor.end_component("ai_client._send_anthropic") - raise _classify_anthropic_error(exc) from exc - -def _ensure_deepseek_client() -> None: - global _deepseek_client - if _deepseek_client is None: - _load_credentials() - pass - -def _ensure_minimax_client() -> None: - global _minimax_client - if _minimax_client is None: - from openai import OpenAI - creds = _load_credentials() - api_key = creds.get("minimax", {}).get("api_key") - if not api_key: - raise ValueError("MiniMax API key not found in credentials.toml") - _minimax_client = OpenAI(api_key=api_key, base_url="https://api.minimax.chat/v1") +def _list_deepseek_models(api_key: str) -> list[str]: + return ["deepseek-chat", "deepseek-reasoner"] def _repair_deepseek_history(history: list[dict[str, Any]]) -> None: if not history: @@ -1791,10 +1788,17 @@ def _repair_deepseek_history(history: list[dict[str, Any]]) -> None: if not already_has: history.append({ "role": "tool", + "tool_call_id": cid, "content": "ERROR: Session was interrupted before tool result was recorded.", }) +def _ensure_deepseek_client() -> None: + global _deepseek_client + if _deepseek_client is None: + _load_credentials() + pass + def _send_deepseek(md_content: str, user_message: str, base_dir: str, file_items: list[dict[str, Any]] | None = None, discussion_history: str = "", @@ -2048,6 +2052,32 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str, if monitor.enabled: monitor.end_component("ai_client._send_deepseek") raise _classify_deepseek_error(e) from e +#endregion: DeepSeek Provider + +#region: MiniMax Provider + +def _list_minimax_models(api_key: str) -> list[str]: + try: + from openai import OpenAI + client = OpenAI(api_key=api_key, base_url="https://api.minimax.io/v1") + models_list = client.models.list() + found = [m.id for m in models_list] + if found: + return sorted(found) + except Exception: + pass + return ["MiniMax-M2.7", "MiniMax-M2.5", "MiniMax-M2.1", "MiniMax-M2"] + +def _ensure_minimax_client() -> None: + global _minimax_client + if _minimax_client is None: + from openai import OpenAI + creds = _load_credentials() + api_key = creds.get("minimax", {}).get("api_key") + if not api_key: + raise ValueError("MiniMax API key not found in credentials.toml") + _minimax_client = OpenAI(api_key=api_key, base_url="https://api.minimax.chat/v1") + def _send_minimax(md_content: str, user_message: str, base_dir: str, file_items: list[dict[str, Any]] | None = None, discussion_history: str = "", @@ -2266,6 +2296,9 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str, except Exception as e: raise _classify_minimax_error(e) from e +#endregion: MiniMax Provider + +#region: Tier 4 Analysis def run_tier4_analysis(stderr: str) -> str: """ @@ -2297,10 +2330,12 @@ def run_tier4_analysis(stderr: str) -> str: except Exception as e: return f"[QA ANALYSIS FAILED] {e}" +#endregion: Tier 4 Analysis + +#region: Session & 
Public API def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]: try: - from src import project_manager file_items = project_manager.get_current_file_items() file_context = "" for item in file_items[:5]: @@ -2314,7 +2349,6 @@ def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]: except Exception as e: return None - def run_tier4_patch_generation(error: str, file_context: str) -> str: """ [C: src/gui_2.py:App.request_patch_from_tier4, src/native_orchestrator.py:NativeOrchestrator.run_tier4_patch, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_calls_ai, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_empty_error, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_returns_diff] @@ -2457,6 +2491,7 @@ def _add_bleed_derived(d: dict[str, Any], sys_tok: int = 0, tool_tok: int = 0) - d["tool_tokens"] = tool_tok d["history_tokens"] = max(0, cur - sys_tok - tool_tok) return d + # Check for tool preset in environment variable (headless mode) if os.environ.get("SLOP_TOOL_PRESET"): try: @@ -2464,6 +2499,10 @@ if os.environ.get("SLOP_TOOL_PRESET"): except Exception: pass +#endregion: Session & Public API + +#region: Subagent Summarization + def run_subagent_summarization(file_path: str, content: str, is_code: bool, outline: str) -> str: """ Performs a stateless summarization request using a sub-agent prompt. [C: src/summarize.py:summarise_file, tests/test_subagent_summarization.py:test_run_subagent_summarization_anthropic, tests/test_subagent_summarization.py:test_run_subagent_summarization_gemini] @@ -2514,3 +2553,5 @@ def run_subagent_summarization(file_path: str, content: str, is_code: bool, outl resp_data = adapter.send(prompt, model=_model) return resp_data.get("text", "") return "ERROR: Unsupported provider for sub-agent summarization" + +#endregion: Subagent Summarization
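
Note on the Token Estimation region consolidated above: these helpers budget prompts with a flat characters-per-token ratio rather than a real tokenizer. The standalone sketch below (not part of the patch) mirrors that heuristic as a minimal example: character count divided by _CHARS_PER_TOKEN (3.5), memoised on the message under a private "_est_tokens" key so repeated trimming passes stay cheap. The sample message at the bottom is hypothetical.

# Minimal sketch of the chars-per-token heuristic used by
# _estimate_message_tokens; names mirror the patch, inputs are invented.
import json
from typing import Any

CHARS_PER_TOKEN = 3.5

def estimate_message_tokens(msg: dict[str, Any]) -> int:
    cached = msg.get("_est_tokens")
    if cached is not None:
        return cached
    chars = 0
    content = msg.get("content", "")
    if isinstance(content, str):
        chars = len(content)
    elif isinstance(content, list):
        for block in content:
            if isinstance(block, dict):
                text = block.get("text", "") or block.get("content", "")
                if isinstance(text, str):
                    chars += len(text)
                if isinstance(block.get("input"), dict):
                    chars += len(json.dumps(block["input"], ensure_ascii=False))
    est = max(1, int(chars / CHARS_PER_TOKEN))
    msg["_est_tokens"] = est  # memoise so later trim passes reuse the estimate
    return est

if __name__ == "__main__":
    msg = {"role": "user", "content": "Summarise ai_client.py for me."}
    print(estimate_message_tokens(msg))  # 8 with the 3.5 chars/token heuristic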
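
Note on the Anthropic region: _build_chunked_context_blocks and _add_history_cache_breakpoint place the prompt-caching markers that _send_anthropic relies on. The sketch below illustrates that layout under stated assumptions: the 20,000-character chunk size is a placeholder (the real value comes from _ANTHROPIC_CHUNK_SIZE, defined elsewhere in ai_client.py), and the history is invented. It is an illustration of the technique, not the module's implementation.

# Sketch: ephemeral cache_control goes on the stable system prompt, on the
# last context chunk, and on the second-most-recent user turn in history.
from typing import Any

CHUNK_SIZE = 20_000  # placeholder for _ANTHROPIC_CHUNK_SIZE

def build_chunked_context_blocks(md_content: str) -> list[dict[str, Any]]:
    chunks = [md_content[i:i + CHUNK_SIZE] for i in range(0, len(md_content), CHUNK_SIZE)]
    blocks: list[dict[str, Any]] = [{"type": "text", "text": c} for c in chunks]
    if blocks:
        blocks[-1]["cache_control"] = {"type": "ephemeral"}
    return blocks

def add_history_cache_breakpoint(history: list[dict[str, Any]]) -> None:
    user_indices = [i for i, m in enumerate(history) if m.get("role") == "user"]
    if len(user_indices) < 2:
        return
    msg = history[user_indices[-2]]
    content = msg.get("content")
    if isinstance(content, str):
        msg["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
    elif isinstance(content, list) and content and isinstance(content[-1], dict):
        content[-1]["cache_control"] = {"type": "ephemeral"}

system_blocks: list[dict[str, Any]] = [
    {"type": "text", "text": "stable system prompt", "cache_control": {"type": "ephemeral"}}
]
system_blocks += build_chunked_context_blocks("project context " * 2000)
history: list[dict[str, Any]] = [
    {"role": "user", "content": "first question"},
    {"role": "assistant", "content": "first answer"},
    {"role": "user", "content": "second question"},
]
add_history_cache_breakpoint(history)
print(len(system_blocks), history[0]["content"][0]["cache_control"])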
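
Note on the one behavioural change in the DeepSeek hunk: _repair_deepseek_history now sets "tool_call_id" on the synthetic tool reply, since an OpenAI-style chat history is rejected when an assistant turn with tool_calls is not followed by role:"tool" messages referencing each call id. A simplified sketch of that repair (the real helper also skips ids that already have a recorded result; the history below is hypothetical):

# Sketch: give every dangling tool call an error result carrying its call id.
from typing import Any

def repair_history(history: list[dict[str, Any]]) -> None:
    if not history or history[-1].get("role") != "assistant":
        return
    pending = [c["id"] for c in history[-1].get("tool_calls", []) or []]
    for cid in pending:
        history.append({
            "role": "tool",
            "tool_call_id": cid,  # the field this patch adds
            "content": "ERROR: Session was interrupted before tool result was recorded.",
        })

history = [{"role": "assistant", "content": "",
            "tool_calls": [{"id": "call_1", "type": "function",
                            "function": {"name": "run_powershell", "arguments": "{}"}}]}]
repair_history(history)
print(history[-1]["tool_call_id"])  # call_1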