refactor(ai_client): remove top-level SDK imports; use _require_warmed
Phase 3 T3.2 + T3.3 of startup_speedup_20260606 track. The 5 heavy SDKs (anthropic, google.genai, openai, google.genai.types, requests) are no longer imported at module level. Each function that needs them now calls _require_warmed(name) to get the module from sys.modules (populated by AppController's warmup on _io_pool). This is the load-bearing wall of the Main Thread Purity Invariant: heavy modules are never in the main thread's import chain. run_discussion_compression now uses _require_warmed for both google.genai.types (gemini branch) and requests (deepseek branch). Tests/test_tier4_patch_generation.py adapted: the 2 tests that mocked 'src.ai_client.types' (no longer a module-level attr) now mock 'src.ai_client._require_warmed' (the new public mechanism). T3.1 tests now pass (9/9). T3.3 breakage fixed. All 25 ai_client + tier4 tests pass.
This commit is contained in:
+61
-17
@@ -5,25 +5,26 @@ Note(Gemini):
|
||||
Acts as the unified interface for multiple LLM providers (Anthropic, Gemini).
|
||||
Abstracts away the differences in how they handle tool schemas, history, and caching.
|
||||
|
||||
For Anthropic: aggressively manages the ~200k token limit by manually culling
|
||||
stale [FILES UPDATED] entries and dropping the oldest message pairs.
|
||||
For Anthropic: aggressively manages the ~200k token limit by manually culling
|
||||
stale [FILES UPDATED] entries and dropping the oldest message pairs.
|
||||
|
||||
For Gemini: injects the initial context directly into system_instruction
|
||||
For Gemini: injects the initial context directly into system_instruction
|
||||
during chat creation to avoid massive history bloat.
|
||||
|
||||
HEAVY IMPORTS (startup_speedup_20260606): The heavy SDKs (anthropic,
|
||||
google.genai, openai, google.genai.types, requests) are NOT imported
|
||||
at module level. They are warmed on AppController's _io_pool at
|
||||
startup and accessed via _require_warmed() below. This keeps the
|
||||
main thread's import chain lean and the GUI responsive on startup.
|
||||
"""
|
||||
import anthropic
|
||||
from google import genai
|
||||
from openai import OpenAI
|
||||
|
||||
from google.genai import types
|
||||
|
||||
import importlib
|
||||
import asyncio
|
||||
import datetime
|
||||
import difflib
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import requests # type: ignore[import-untyped]
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -51,6 +52,26 @@ from src.tool_bias import ToolBiasEngine
|
||||
from src.tool_presets import ToolPresetManager
|
||||
|
||||
|
||||
def _require_warmed(name: str) -> Any:
|
||||
"""Return a heavy module that the AppController's warmup should have loaded.
|
||||
|
||||
Heavy SDKs (anthropic, google.genai, openai, google.genai.types,
|
||||
requests) are warmed on AppController's _io_pool at startup. This
|
||||
function expects them to already be in sys.modules and just returns
|
||||
the cached module object. If the module is NOT in sys.modules (e.g.
|
||||
in tests where warmup didn't run), falls back to importlib so the
|
||||
call still works.
|
||||
|
||||
In production: this is an O(1) sys.modules lookup. The 1+ second
|
||||
import cost is paid during startup on a bg thread, NOT on the first
|
||||
user-triggered AI call.
|
||||
"""
|
||||
mod = sys.modules.get(name)
|
||||
if mod is not None:
|
||||
return mod
|
||||
return importlib.import_module(name)
|
||||
|
||||
|
||||
_provider: str = "gemini"
|
||||
_model: str = "gemini-2.5-flash-lite"
|
||||
_temperature: float = 0.0
|
||||
@@ -333,11 +354,12 @@ def _load_credentials() -> dict[str, Any]:
|
||||
|
||||
def _classify_anthropic_error(exc: Exception) -> ProviderError:
|
||||
try:
|
||||
anthropic = _require_warmed("anthropic")
|
||||
if isinstance(exc, anthropic.RateLimitError): return ProviderError("rate_limit", "anthropic", exc)
|
||||
if isinstance(exc, anthropic.AuthenticationError): return ProviderError("auth", "anthropic", exc)
|
||||
if isinstance(exc, anthropic.PermissionDeniedError): return ProviderError("auth", "anthropic", exc)
|
||||
if isinstance(exc, anthropic.APIConnectionError): return ProviderError("network", "anthropic", exc)
|
||||
if isinstance(exc, anthropic.APIStatusError):
|
||||
if isinstance(exc, anthropic.APIStatusError):
|
||||
status = getattr(exc, "status_code", 0)
|
||||
body = str(exc).lower()
|
||||
if status == 429: return ProviderError("rate_limit", "anthropic", exc)
|
||||
@@ -366,6 +388,7 @@ def _classify_gemini_error(exc: Exception) -> ProviderError:
|
||||
return ProviderError("unknown", "gemini", exc)
|
||||
|
||||
def _classify_deepseek_error(exc: Exception) -> ProviderError:
|
||||
requests = _require_warmed("requests")
|
||||
body = ""
|
||||
if isinstance(exc, requests.exceptions.HTTPError) and exc.response is not None:
|
||||
try:
|
||||
@@ -389,6 +412,7 @@ def _classify_deepseek_error(exc: Exception) -> ProviderError:
|
||||
return ProviderError("unknown", "deepseek", Exception(body))
|
||||
|
||||
def _classify_minimax_error(exc: Exception) -> ProviderError:
|
||||
requests = _require_warmed("requests")
|
||||
body = ""
|
||||
if isinstance(exc, requests.exceptions.HTTPError) and exc.response is not None:
|
||||
try:
|
||||
@@ -637,6 +661,7 @@ def _gemini_tool_declaration() -> Optional[types.Tool]:
|
||||
"""
|
||||
[C: tests/test_tool_access_exclusion.py:test_gemini_tool_declaration_excludes_disabled]
|
||||
"""
|
||||
types = _require_warmed("google.genai.types")
|
||||
raw_tools: list[dict[str, Any]] = []
|
||||
for spec in mcp_client.get_tool_schemas():
|
||||
if _agent_tools.get(spec["name"], True):
|
||||
@@ -1075,6 +1100,7 @@ def _add_history_cache_breakpoint(history: list[dict[str, Any]]) -> None:
|
||||
|
||||
def _list_anthropic_models() -> list[str]:
|
||||
try:
|
||||
anthropic = _require_warmed("anthropic")
|
||||
creds = _load_credentials()
|
||||
client = anthropic.Anthropic(api_key=creds["anthropic"]["api_key"])
|
||||
models: list[str] = []
|
||||
@@ -1086,6 +1112,7 @@ def _list_anthropic_models() -> list[str]:
|
||||
|
||||
def _ensure_anthropic_client() -> None:
|
||||
global _anthropic_client
|
||||
anthropic = _require_warmed("anthropic")
|
||||
if _anthropic_client is None:
|
||||
creds = _load_credentials()
|
||||
_anthropic_client = anthropic.Anthropic(
|
||||
@@ -1150,8 +1177,10 @@ def _repair_anthropic_history(history: list[dict[str, Any]]) -> None:
|
||||
|
||||
def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_items: list[dict[str, Any]] | None = None, discussion_history: str = "", pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, qa_callback: Optional[Callable[[str], str]] = None, stream_callback: Optional[Callable[[str], None]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
|
||||
"""
|
||||
[C: src/ai_server.py:_handle_send]
|
||||
[C: src/ai_server.py:_handle_send]
|
||||
"""
|
||||
anthropic = _require_warmed("anthropic")
|
||||
types = _require_warmed("google.genai.types")
|
||||
monitor = performance_monitor.get_monitor()
|
||||
if monitor.enabled: monitor.start_component("ai_client._send_anthropic")
|
||||
try:
|
||||
@@ -1358,6 +1387,7 @@ def _list_gemini_cli_models() -> list[str]:
|
||||
|
||||
def _list_gemini_models(api_key: str) -> list[str]:
|
||||
try:
|
||||
genai = _require_warmed("google.genai")
|
||||
client = genai.Client(api_key=api_key)
|
||||
models: list[str] = []
|
||||
for m in client.models.list():
|
||||
@@ -1371,12 +1401,13 @@ def _list_gemini_models(api_key: str) -> list[str]:
|
||||
raise _classify_gemini_error(exc) from exc
|
||||
|
||||
def _ensure_gemini_client() -> None:
|
||||
"""
|
||||
[C: src/rag_engine.py:GeminiEmbeddingProvider.embed]
|
||||
"""
|
||||
global _gemini_client
|
||||
if _gemini_client is None:
|
||||
creds = _load_credentials()
|
||||
"""
|
||||
[C: src/rag_engine.py:GeminiEmbeddingProvider.embed]
|
||||
"""
|
||||
global _gemini_client
|
||||
genai = _require_warmed("google.genai")
|
||||
if _gemini_client is None:
|
||||
creds = _load_credentials()
|
||||
_gemini_client = genai.Client(api_key=creds["gemini"]["api_key"])
|
||||
|
||||
def _get_gemini_history_list(chat: Any | None) -> list[Any]:
|
||||
@@ -1401,6 +1432,7 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
||||
[C: src/ai_server.py:_handle_send, tests/test_tier4_interceptor.py:test_gemini_provider_passes_qa_callback_to_run_script]
|
||||
"""
|
||||
global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths
|
||||
types = _require_warmed("google.genai.types")
|
||||
monitor = performance_monitor.get_monitor()
|
||||
if monitor.enabled: monitor.start_component("ai_client._send_gemini")
|
||||
try:
|
||||
@@ -1782,6 +1814,7 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
|
||||
"""
|
||||
[C: src/ai_server.py:_handle_send]
|
||||
"""
|
||||
requests = _require_warmed("requests")
|
||||
monitor = performance_monitor.get_monitor()
|
||||
if monitor.enabled: monitor.start_component("ai_client._send_deepseek")
|
||||
try:
|
||||
@@ -2033,6 +2066,8 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
|
||||
|
||||
def _list_minimax_models(api_key: str) -> list[str]:
|
||||
try:
|
||||
openai = _require_warmed("openai")
|
||||
OpenAI = openai.OpenAI
|
||||
client = OpenAI(api_key=api_key, base_url="https://api.minimax.io/v1")
|
||||
models_list = client.models.list()
|
||||
found = [m.id for m in models_list]
|
||||
@@ -2093,6 +2128,7 @@ def _trim_minimax_history(system_blocks: list[dict[str, Any]], history: list[dic
|
||||
|
||||
def _ensure_minimax_client() -> None:
|
||||
global _minimax_client
|
||||
openai = _require_warmed("openai")
|
||||
if _minimax_client is None:
|
||||
creds = _load_credentials()
|
||||
api_key = creds.get("minimax", {}).get("api_key")
|
||||
@@ -2111,6 +2147,8 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str,
|
||||
"""
|
||||
[C: src/ai_server.py:_handle_send]
|
||||
"""
|
||||
openai = _require_warmed("openai")
|
||||
requests = _require_warmed("requests")
|
||||
try:
|
||||
mcp_client.configure(file_items or [], [base_dir])
|
||||
creds = _load_credentials()
|
||||
@@ -2332,6 +2370,7 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str,
|
||||
def run_tier4_analysis(stderr: str) -> str:
|
||||
"""
|
||||
"""
|
||||
types = _require_warmed("google.genai.types")
|
||||
if not stderr or not stderr.strip():
|
||||
return ""
|
||||
try:
|
||||
@@ -2381,6 +2420,7 @@ def run_tier4_patch_generation(error: str, file_context: str) -> str:
|
||||
"""
|
||||
[C: src/gui_2.py:App.request_patch_from_tier4, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_calls_ai, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_empty_error, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_returns_diff]
|
||||
"""
|
||||
types = _require_warmed("google.genai.types")
|
||||
if not error or not error.strip():
|
||||
return ""
|
||||
try:
|
||||
@@ -2537,6 +2577,8 @@ def run_subagent_summarization(file_path: str, content: str, is_code: bool, outl
|
||||
"""
|
||||
[C: src/summarize.py:summarise_file, tests/test_subagent_summarization.py:test_run_subagent_summarization_anthropic, tests/test_subagent_summarization.py:test_run_subagent_summarization_gemini]
|
||||
"""
|
||||
requests = _require_warmed("requests")
|
||||
types = _require_warmed("google.genai.types")
|
||||
prompt_tmpl = mma_prompts.TIER4_SUMMARIZE_CODE_PROMPT if is_code else mma_prompts.TIER4_SUMMARIZE_TEXT_PROMPT
|
||||
prompt = prompt_tmpl.format(file_path=file_path, outline=outline, content=content)
|
||||
if _provider == "gemini":
|
||||
@@ -2584,6 +2626,8 @@ def run_subagent_summarization(file_path: str, content: str, is_code: bool, outl
|
||||
return "ERROR: Unsupported provider for sub-agent summarization"
|
||||
|
||||
def run_discussion_compression(discussion_text: str) -> str:
|
||||
types = _require_warmed("google.genai.types")
|
||||
requests = _require_warmed("requests")
|
||||
# Robustly identify the provider string (handles case and whitespace)
|
||||
p = str(get_provider()).lower().strip()
|
||||
prompt = f"The following is a long conversation history.\n\nPlease provide a highly compact, dense summary of the key facts, decisions, bugs encountered, and outcomes that should be retained for context going forward. Categorize into User intent, Tool outputs, and AI reasoning. Omit pleasantries and redundant thoughts.\n\n[HISTORY]\n{discussion_text}"
|
||||
|
||||
@@ -36,31 +36,33 @@ def test_run_tier4_patch_generation_empty_error() -> None:
|
||||
|
||||
def test_run_tier4_patch_generation_calls_ai() -> None:
|
||||
"""Test that run_tier4_patch_generation calls the AI with the correct prompt."""
|
||||
mock_types = MagicMock()
|
||||
mock_types.GenerateContentConfig = MagicMock()
|
||||
with patch("src.ai_client._ensure_gemini_client"), \
|
||||
patch("src.ai_client._gemini_client", create=True) as mock_client, \
|
||||
patch("src.ai_client.types") as mock_types:
|
||||
patch("src.ai_client._require_warmed", return_value=mock_types):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.text = "--- a/test.py\n+++ b/test.py\n@@ -1 +1 @@\n-old\n+new"
|
||||
mock_client.models.generate_content.return_value = mock_resp
|
||||
mock_types.GenerateContentConfig = MagicMock()
|
||||
|
||||
|
||||
error = "TypeError: unsupported operand"
|
||||
file_context = "def foo():\n pass"
|
||||
result = ai_client.run_tier4_patch_generation(error, file_context)
|
||||
|
||||
|
||||
mock_client.models.generate_content.assert_called()
|
||||
|
||||
def test_run_tier4_patch_generation_returns_diff() -> None:
|
||||
"""Test that run_tier4_patch_generation returns diff text."""
|
||||
mock_types = MagicMock()
|
||||
mock_types.GenerateContentConfig = MagicMock()
|
||||
with patch("src.ai_client._ensure_gemini_client"), \
|
||||
patch("src.ai_client._gemini_client", create=True) as mock_client, \
|
||||
patch("src.ai_client.types") as mock_types:
|
||||
patch("src.ai_client._require_warmed", return_value=mock_types):
|
||||
expected_diff = "--- a/src/test.py\n+++ b/src/test.py\n@@ -10,5 +10,6 @@\n def test_func():\n- old_value = 1\n+ old_value = 1\n+ new_value = 2"
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.text = expected_diff
|
||||
mock_client.models.generate_content.return_value = mock_resp
|
||||
mock_types.GenerateContentConfig = MagicMock()
|
||||
|
||||
|
||||
result = ai_client.run_tier4_patch_generation("error", "context")
|
||||
assert "---" in result
|
||||
assert "+++" in result
|
||||
|
||||
Reference in New Issue
Block a user