diff --git a/src/ai_client.py b/src/ai_client.py index c7c3aa8a..8a257f6c 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -5,25 +5,26 @@ Note(Gemini): Acts as the unified interface for multiple LLM providers (Anthropic, Gemini). Abstracts away the differences in how they handle tool schemas, history, and caching. -For Anthropic: aggressively manages the ~200k token limit by manually culling -stale [FILES UPDATED] entries and dropping the oldest message pairs. +For Anthropic: aggressively manages the ~200k token limit by manually culling +stale [FILES UPDATED] entries and dropping the oldest message pairs. -For Gemini: injects the initial context directly into system_instruction +For Gemini: injects the initial context directly into system_instruction during chat creation to avoid massive history bloat. + +HEAVY IMPORTS (startup_speedup_20260606): The heavy SDKs (anthropic, +google.genai, openai, google.genai.types, requests) are NOT imported +at module level. They are warmed on AppController's _io_pool at +startup and accessed via _require_warmed() below. This keeps the +main thread's import chain lean and the GUI responsive on startup. """ -import anthropic -from google import genai -from openai import OpenAI - -from google.genai import types +import importlib import asyncio import datetime import difflib import hashlib import json import os -import requests # type: ignore[import-untyped] import sys import threading import time @@ -51,6 +52,26 @@ from src.tool_bias import ToolBiasEngine from src.tool_presets import ToolPresetManager +def _require_warmed(name: str) -> Any: + """Return a heavy module that the AppController's warmup should have loaded. + + Heavy SDKs (anthropic, google.genai, openai, google.genai.types, + requests) are warmed on AppController's _io_pool at startup. This + function expects them to already be in sys.modules and just returns + the cached module object. If the module is NOT in sys.modules (e.g. + in tests where warmup didn't run), falls back to importlib so the + call still works. + + In production: this is an O(1) sys.modules lookup. The 1+ second + import cost is paid during startup on a bg thread, NOT on the first + user-triggered AI call. + """ + mod = sys.modules.get(name) + if mod is not None: + return mod + return importlib.import_module(name) + + _provider: str = "gemini" _model: str = "gemini-2.5-flash-lite" _temperature: float = 0.0 @@ -333,11 +354,12 @@ def _load_credentials() -> dict[str, Any]: def _classify_anthropic_error(exc: Exception) -> ProviderError: try: + anthropic = _require_warmed("anthropic") if isinstance(exc, anthropic.RateLimitError): return ProviderError("rate_limit", "anthropic", exc) if isinstance(exc, anthropic.AuthenticationError): return ProviderError("auth", "anthropic", exc) if isinstance(exc, anthropic.PermissionDeniedError): return ProviderError("auth", "anthropic", exc) if isinstance(exc, anthropic.APIConnectionError): return ProviderError("network", "anthropic", exc) - if isinstance(exc, anthropic.APIStatusError): + if isinstance(exc, anthropic.APIStatusError): status = getattr(exc, "status_code", 0) body = str(exc).lower() if status == 429: return ProviderError("rate_limit", "anthropic", exc) @@ -366,6 +388,7 @@ def _classify_gemini_error(exc: Exception) -> ProviderError: return ProviderError("unknown", "gemini", exc) def _classify_deepseek_error(exc: Exception) -> ProviderError: + requests = _require_warmed("requests") body = "" if isinstance(exc, requests.exceptions.HTTPError) and exc.response is not None: try: @@ -389,6 +412,7 @@ def _classify_deepseek_error(exc: Exception) -> ProviderError: return ProviderError("unknown", "deepseek", Exception(body)) def _classify_minimax_error(exc: Exception) -> ProviderError: + requests = _require_warmed("requests") body = "" if isinstance(exc, requests.exceptions.HTTPError) and exc.response is not None: try: @@ -637,6 +661,7 @@ def _gemini_tool_declaration() -> Optional[types.Tool]: """ [C: tests/test_tool_access_exclusion.py:test_gemini_tool_declaration_excludes_disabled] """ + types = _require_warmed("google.genai.types") raw_tools: list[dict[str, Any]] = [] for spec in mcp_client.get_tool_schemas(): if _agent_tools.get(spec["name"], True): @@ -1075,6 +1100,7 @@ def _add_history_cache_breakpoint(history: list[dict[str, Any]]) -> None: def _list_anthropic_models() -> list[str]: try: + anthropic = _require_warmed("anthropic") creds = _load_credentials() client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"]) models: list[str] = [] @@ -1086,6 +1112,7 @@ def _list_anthropic_models() -> list[str]: def _ensure_anthropic_client() -> None: global _anthropic_client + anthropic = _require_warmed("anthropic") if _anthropic_client is None: creds = _load_credentials() _anthropic_client = anthropic.Anthropic( @@ -1150,8 +1177,10 @@ def _repair_anthropic_history(history: list[dict[str, Any]]) -> None: def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_items: list[dict[str, Any]] | None = None, discussion_history: str = "", pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, qa_callback: Optional[Callable[[str], str]] = None, stream_callback: Optional[Callable[[str], None]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: """ - [C: src/ai_server.py:_handle_send] + [C: src/ai_server.py:_handle_send] """ + anthropic = _require_warmed("anthropic") + types = _require_warmed("google.genai.types") monitor = performance_monitor.get_monitor() if monitor.enabled: monitor.start_component("ai_client._send_anthropic") try: @@ -1358,6 +1387,7 @@ def _list_gemini_cli_models() -> list[str]: def _list_gemini_models(api_key: str) -> list[str]: try: + genai = _require_warmed("google.genai") client = genai.Client(api_key=api_key) models: list[str] = [] for m in client.models.list(): @@ -1371,12 +1401,13 @@ def _list_gemini_models(api_key: str) -> list[str]: raise _classify_gemini_error(exc) from exc def _ensure_gemini_client() -> None: - """ - [C: src/rag_engine.py:GeminiEmbeddingProvider.embed] - """ - global _gemini_client - if _gemini_client is None: - creds = _load_credentials() + """ + [C: src/rag_engine.py:GeminiEmbeddingProvider.embed] + """ + global _gemini_client + genai = _require_warmed("google.genai") + if _gemini_client is None: + creds = _load_credentials() _gemini_client = genai.Client(api_key=creds["gemini"]["api_key"]) def _get_gemini_history_list(chat: Any | None) -> list[Any]: @@ -1401,6 +1432,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, [C: src/ai_server.py:_handle_send, tests/test_tier4_interceptor.py:test_gemini_provider_passes_qa_callback_to_run_script] """ global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths + types = _require_warmed("google.genai.types") monitor = performance_monitor.get_monitor() if monitor.enabled: monitor.start_component("ai_client._send_gemini") try: @@ -1782,6 +1814,7 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str, """ [C: src/ai_server.py:_handle_send] """ + requests = _require_warmed("requests") monitor = performance_monitor.get_monitor() if monitor.enabled: monitor.start_component("ai_client._send_deepseek") try: @@ -2033,6 +2066,8 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str, def _list_minimax_models(api_key: str) -> list[str]: try: + openai = _require_warmed("openai") + OpenAI = openai.OpenAI client = OpenAI(api_key=api_key, base_url="https://api.minimax.io/v1") models_list = client.models.list() found = [m.id for m in models_list] @@ -2093,6 +2128,7 @@ def _trim_minimax_history(system_blocks: list[dict[str, Any]], history: list[dic def _ensure_minimax_client() -> None: global _minimax_client + openai = _require_warmed("openai") if _minimax_client is None: creds = _load_credentials() api_key = creds.get("minimax", {}).get("api_key") @@ -2111,6 +2147,8 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str, """ [C: src/ai_server.py:_handle_send] """ + openai = _require_warmed("openai") + requests = _require_warmed("requests") try: mcp_client.configure(file_items or [], [base_dir]) creds = _load_credentials() @@ -2332,6 +2370,7 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str, def run_tier4_analysis(stderr: str) -> str: """ """ + types = _require_warmed("google.genai.types") if not stderr or not stderr.strip(): return "" try: @@ -2381,6 +2420,7 @@ def run_tier4_patch_generation(error: str, file_context: str) -> str: """ [C: src/gui_2.py:App.request_patch_from_tier4, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_calls_ai, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_empty_error, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_returns_diff] """ + types = _require_warmed("google.genai.types") if not error or not error.strip(): return "" try: @@ -2537,6 +2577,8 @@ def run_subagent_summarization(file_path: str, content: str, is_code: bool, outl """ [C: src/summarize.py:summarise_file, tests/test_subagent_summarization.py:test_run_subagent_summarization_anthropic, tests/test_subagent_summarization.py:test_run_subagent_summarization_gemini] """ + requests = _require_warmed("requests") + types = _require_warmed("google.genai.types") prompt_tmpl = mma_prompts.TIER4_SUMMARIZE_CODE_PROMPT if is_code else mma_prompts.TIER4_SUMMARIZE_TEXT_PROMPT prompt = prompt_tmpl.format(file_path=file_path, outline=outline, content=content) if _provider == "gemini": @@ -2584,6 +2626,8 @@ def run_subagent_summarization(file_path: str, content: str, is_code: bool, outl return "ERROR: Unsupported provider for sub-agent summarization" def run_discussion_compression(discussion_text: str) -> str: + types = _require_warmed("google.genai.types") + requests = _require_warmed("requests") # Robustly identify the provider string (handles case and whitespace) p = str(get_provider()).lower().strip() prompt = f"The following is a long conversation history.\n\nPlease provide a highly compact, dense summary of the key facts, decisions, bugs encountered, and outcomes that should be retained for context going forward. Categorize into User intent, Tool outputs, and AI reasoning. Omit pleasantries and redundant thoughts.\n\n[HISTORY]\n{discussion_text}" diff --git a/tests/test_tier4_patch_generation.py b/tests/test_tier4_patch_generation.py index b86bfef1..baf2f995 100644 --- a/tests/test_tier4_patch_generation.py +++ b/tests/test_tier4_patch_generation.py @@ -36,31 +36,33 @@ def test_run_tier4_patch_generation_empty_error() -> None: def test_run_tier4_patch_generation_calls_ai() -> None: """Test that run_tier4_patch_generation calls the AI with the correct prompt.""" + mock_types = MagicMock() + mock_types.GenerateContentConfig = MagicMock() with patch("src.ai_client._ensure_gemini_client"), \ patch("src.ai_client._gemini_client", create=True) as mock_client, \ - patch("src.ai_client.types") as mock_types: + patch("src.ai_client._require_warmed", return_value=mock_types): mock_resp = MagicMock() mock_resp.text = "--- a/test.py\n+++ b/test.py\n@@ -1 +1 @@\n-old\n+new" mock_client.models.generate_content.return_value = mock_resp - mock_types.GenerateContentConfig = MagicMock() - + error = "TypeError: unsupported operand" file_context = "def foo():\n pass" result = ai_client.run_tier4_patch_generation(error, file_context) - + mock_client.models.generate_content.assert_called() def test_run_tier4_patch_generation_returns_diff() -> None: """Test that run_tier4_patch_generation returns diff text.""" + mock_types = MagicMock() + mock_types.GenerateContentConfig = MagicMock() with patch("src.ai_client._ensure_gemini_client"), \ patch("src.ai_client._gemini_client", create=True) as mock_client, \ - patch("src.ai_client.types") as mock_types: + patch("src.ai_client._require_warmed", return_value=mock_types): expected_diff = "--- a/src/test.py\n+++ b/src/test.py\n@@ -10,5 +10,6 @@\n def test_func():\n- old_value = 1\n+ old_value = 1\n+ new_value = 2" mock_resp = MagicMock() mock_resp.text = expected_diff mock_client.models.generate_content.return_value = mock_resp - mock_types.GenerateContentConfig = MagicMock() - + result = ai_client.run_tier4_patch_generation("error", "context") assert "---" in result assert "+++" in result