37ece145fa
Per styleguide §7.6 Pattern 1: 'catch + convert + raise as different type'
requires 'raise X from e' to preserve the original exception in the
traceback.
Sites updated:
Site 1 (L277 _load_credentials):
except FileNotFoundError as e:
raise FileNotFoundError(f'...') from e
Sites 2+3 (L878+L879 _default_send, nested in run_with_tool_loop):
if not res.ok:
raise res.errors[0].original from None
raise RuntimeError(...) from None
The exceptions come from a Result, not a local except; 'from None'
suppresses the implicit context.
Site 5 (L2061 _send inside _send_gemini_cli):
raise cast(Exception, send_result.errors[0].original) from None
Site 6 (L2742 _dashscope_call):
raise classify_dashscope_error(_dashscope_exception_from_response(resp)) from None
KNOWN LIMITATION: the audit script does not have a heuristic for
'raise X from e' / 'from None' (Pattern 1). The sites remain
INTERNAL_RETHROW in the audit. INTERNAL_RETHROW is 'suspicious but
not violation' (strict mode accepts). Adding a heuristic requires
Tier 1 approval per the conventions.
Audit: ai_client RETHROW 6 -> 5 (site 4 migrated separately; these
4 sites stay as INTERNAL_RETHROW by audit classification but follow
Pattern 1 by styleguide).
3477 lines
148 KiB
Python
# ai_client.py
|
|
from __future__ import annotations
|
|
"""
|
|
Note(Gemini):
|
|
Acts as the unified interface for multiple LLM providers (Gemini, Anthropic, Gemini CLI, DeepSeek, MiniMax, Qwen, Grok, Llama).
|
|
Abstracts away the differences in how they handle tool schemas, history, and caching.
|
|
|
|
For Anthropic: aggressively manages the ~200k token limit by manually culling
|
|
stale [FILES UPDATED] entries and dropping the oldest message pairs.
|
|
|
|
For Gemini: injects the initial context directly into system_instruction
|
|
during chat creation to avoid massive history bloat.
|
|
|
|
HEAVY IMPORTS (startup_speedup_20260606): The heavy SDKs (anthropic,
|
|
google.genai, openai, google.genai.types, requests) are NOT imported
|
|
at module level. They are warmed on AppController's _io_pool at
|
|
startup and accessed via _require_warmed() below. This keeps the
|
|
main thread's import chain lean and the GUI responsive on startup.
|
|
"""
|
|
|
|
import importlib
|
|
import asyncio
|
|
import datetime
|
|
import difflib
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import sys
|
|
import threading
|
|
import time
|
|
import tomllib
|
|
|
|
# TODO(Ed): Eliminate These?
|
|
from collections import deque
|
|
from pathlib import Path as _P
|
|
from pathlib import Path
|
|
from typing import Optional, Callable, Any, List, Union, cast, Iterable
|
|
|
|
from src import project_manager
|
|
from src import file_cache
|
|
from src import mcp_client
|
|
from src import mma_prompts
|
|
from src import performance_monitor
|
|
from src import project_manager
|
|
from src.vendor_capabilities import VendorCapabilities, get_capabilities
|
|
|
|
# TODO(Ed): Eliminate these?
|
|
from src.events import EventEmitter
|
|
from src.gemini_cli_adapter import GeminiCliAdapter
|
|
from src.models import ToolPreset, BiasProfile, Tool
|
|
from src.paths import get_credentials_path
|
|
from src.tool_bias import ToolBiasEngine
|
|
from src.tool_presets import ToolPresetManager
|
|
from src.tool_presets import ToolPresetManager
|
|
|
|
# Every provider identifier this module can be switched to via set_provider().
PROVIDERS: List[str] = [
    "gemini",
    "anthropic",
    "gemini_cli",
    "deepseek",
    "minimax",
    "qwen",
    "grok",
    "llama",
]
|
|
|
|
# _require_warmed lives in src/module_loader.py to avoid duplicating the
# lookup logic across files that need heavy modules. Re-exported here so
# existing call sites and the T3.1 test (which asserts
# hasattr(src.ai_client, '_require_warmed')) continue to work.
|
|
from src.module_loader import _require_warmed # noqa: E402,F401
|
|
from src.result_types import ErrorInfo, ErrorKind, Result # noqa: E402,F401
|
|
|
|
# Active provider/model selection; overwritten by set_provider().
_provider: str = "gemini"
_model: str = "gemini-2.5-flash-lite"

# Generation parameters; overwritten by set_model_params().
_temperature: float = 0.0
_top_p: float = 1.0
_max_tokens: int = 8192

# History truncation limit; set_model_params(trunc_limit=...) overrides it.
_history_trunc_limit: int = 8000
|
|
|
|
# Module-wide EventEmitter singleton for API lifecycle events.
events: EventEmitter = EventEmitter()
|
|
|
|
#region: Provider Configuration
|
|
|
|
def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000, top_p: float = 1.0) -> None:
    """Store the module-wide generation settings.

    Args:
        temp: Sampling temperature (kept in ``_temperature``).
        max_tok: Maximum output tokens (kept in ``_max_tokens``).
        trunc_limit: History truncation limit (kept in ``_history_trunc_limit``).
        top_p: Nucleus-sampling cutoff (kept in ``_top_p``).
    """
    global _temperature, _max_tokens, _history_trunc_limit, _top_p
    _temperature, _max_tokens, _history_trunc_limit, _top_p = temp, max_tok, trunc_limit, top_p
|
|
|
|
# Gemini SDK session state: the client, the live chat, and the server-side
# context cache together with its bookkeeping (content hash, creation time,
# and the file paths it covers). reset_session() clears all of these.
_gemini_client: Optional[genai.Client] = None
_gemini_chat: Any = None
_gemini_cache: Any = None
_gemini_cache_md_hash: Optional[str] = None
_gemini_cache_created_at: Optional[float] = None
_gemini_cached_file_paths: list[str] = []

# Gemini cache TTL in seconds. Caches are created with this TTL and
# proactively rebuilt at 90% of this value to avoid stale-reference errors.
_GEMINI_CACHE_TTL: int = 3600
|
|
|
|
# Per-provider client handles and conversation histories. Each history list
# is paired with its own lock; reset_session() rebinds the lists while
# holding the matching lock.
_anthropic_client: Optional[anthropic.Anthropic] = None
_anthropic_history: list[dict[str, Any]] = []
_anthropic_history_lock: threading.Lock = threading.Lock()

_deepseek_client: Any = None
_deepseek_history: list[dict[str, Any]] = []
_deepseek_history_lock: threading.Lock = threading.Lock()

_minimax_client: Any = None
_minimax_history: list[dict[str, Any]] = []
_minimax_history_lock: threading.Lock = threading.Lock()

_qwen_client: Any = None
_qwen_history: list[dict[str, Any]] = []
_qwen_history_lock: threading.Lock = threading.Lock()
_qwen_region: str = "china"

_grok_client: Any = None
_grok_history: list[dict[str, Any]] = []
_grok_history_lock: threading.Lock = threading.Lock()

_llama_client: Any = None
_llama_history: list[dict[str, Any]] = []
_llama_history_lock: threading.Lock = threading.Lock()
# Endpoint defaults: http://localhost:11434/v1 with the placeholder key "ollama".
_llama_base_url: str = "http://localhost:11434/v1"
_llama_api_key: str = "ollama"

# Module-wide send lock.
_send_lock: threading.Lock = threading.Lock()
|
|
|
|
# Engine that turns a tool preset + bias profile into a tooling-strategy
# prompt section and semantic nudges on tool descriptions.
_BIAS_ENGINE = ToolBiasEngine()

# Currently selected tool preset and bias profile (None = nothing selected).
_active_tool_preset: Optional[ToolPreset] = None
_active_bias_profile: Optional[BiasProfile] = None

# Gemini CLI adapter; (re)created by reset_session().
_gemini_cli_adapter: Optional[GeminiCliAdapter] = None
|
|
|
|
# --- Hooks injected by gui.py -------------------------------------------
# Invoked when the AI wants to run a command.
confirm_and_run_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]], Optional[Callable[[str, str], Optional[str]]]], Optional[str]]] = None

# Invoked whenever a comms entry is appended. Read/write it through
# get_comms_log_callback()/set_comms_log_callback(), which also maintain a
# per-thread override.
comms_log_callback: Optional[Callable[[dict[str, Any]], None]] = None

# Invoked whenever a tool call completes.
tool_log_callback: Optional[Callable[[str, str], None]] = None

# Per-thread storage (current tier, thread-local comms callback).
_local_storage = threading.local()

# Tool name -> approval mode, populated from the active tool preset.
_tool_approval_modes: dict[str, str] = {}
|
|
|
|
def get_current_tier() -> Optional[str]:
    """Return the tier recorded for the calling thread, or None if none was set."""
    try:
        return _local_storage.current_tier
    except AttributeError:
        return None
|
|
|
|
def set_current_tier(tier: Optional[str]) -> None:
    """Record *tier* for the calling thread only (read back by get_current_tier)."""
    setattr(_local_storage, "current_tier", tier)
|
|
|
|
# Cap on tool-call rounds; set generously so the model can explore code
# thoroughly before being forced to summarize.
MAX_TOOL_ROUNDS: int = 10

# Cumulative tool-output budget (bytes) for a single send() call.
_MAX_TOOL_OUTPUT_BYTES: int = 500_000

# Largest text chunk (characters) sent to Anthropic in one piece.
_ANTHROPIC_CHUNK_SIZE: int = 120_000
|
|
|
|
# Built-in base system prompt. _get_combined_system_prompt() uses it unless
# the default base prompt has been switched off.
_SYSTEM_PROMPT: str = (
    "You are a helpful coding assistant with access to a PowerShell tool (run_powershell) and MCP tools (file access: read_file, list_directory, search_files, get_file_summary, web access: web_search, fetch_url). "
    "When calling file/directory tools, always use the 'path' parameter for the target path. "
    "When asked to create or edit files, prefer targeted edits over full rewrites. "
    "Always explain what you are doing before invoking the tool.\n\n"
    "When writing or rewriting large files (especially those containing quotes, backticks, or special characters), "
    "avoid python -c with inline strings. Instead: (1) write a .py helper script to disk using a PS here-string "
    "(@'...'@ for literal content), (2) run it with `python <script>`, (3) delete the helper. "
    "For small targeted edits, use PowerShell's (Get-Content) / .Replace() / Set-Content or Add-Content directly.\n\n"
    # The JSON guidance is a self-contained instruction. It previously ended
    # in a dangling "For example:" with no example before the next topic.
    "When making function calls using tools that accept array or object parameters, "
    "ensure those are structured using JSON.\n\n"
    "When you need to verify a change, rely on the exit code and stdout/stderr from the tool \u2014 "
    "the user's context files are automatically refreshed after every tool call, so you do NOT "
    "need to re-read files that are already provided in the <context> block."
)
|
|
|
|
# Prompt-composition settings read by _get_combined_system_prompt():
# a user prompt appended after the base prompt (ignored while blank), an
# alternative base prompt used when the default base is switched off, and
# the custom marker returned by _get_context_marker() (blank = default).
_custom_system_prompt: str = ""
_base_system_prompt_override: str = ""
_use_default_base_system_prompt: bool = True
_project_context_marker: str = ""
|
|
|
|
#endregion: Provider Configuration
|
|
|
|
#region: System Prompt Management
|
|
|
|
def set_custom_system_prompt(prompt: str) -> None:
    """Store a user-supplied prompt that is appended after the base system prompt.

    A blank (all-whitespace) value is stored as-is but skipped when the
    combined prompt is assembled.
    """
    global _custom_system_prompt
    _custom_system_prompt = prompt
|
|
|
|
def set_base_system_prompt(prompt: str) -> None:
    """Store the base prompt used in place of the built-in one.

    It only takes effect while the default base prompt is disabled via
    set_use_default_base_prompt(False).
    """
    global _base_system_prompt_override
    _base_system_prompt_override = prompt
|
|
|
|
def set_use_default_base_prompt(use_default: bool) -> None:
    """Select the built-in base prompt (True) or the stored override (False)."""
    global _use_default_base_system_prompt
    _use_default_base_system_prompt = use_default
|
|
|
|
def set_project_context_marker(marker: str) -> None:
    """Set a custom context marker; a blank value makes the default marker apply again."""
    global _project_context_marker
    _project_context_marker = marker
|
|
|
|
def _get_context_marker() -> str:
    """Return the configured context marker string.

    Falls back to "[SYSTEM: FILES UPDATED]" while the marker set through
    set_project_context_marker() is empty or whitespace-only.
    """
    marker = _project_context_marker
    if not marker.strip():
        return "[SYSTEM: FILES UPDATED]"
    return marker
|
|
|
|
def _get_combined_system_prompt(preset: Optional[ToolPreset] = None, bias: Optional[BiasProfile] = None) -> str:
    """Assemble the full system prompt from its layers.

    Layers, in order:
      1. the built-in ``_SYSTEM_PROMPT``, or ``_base_system_prompt_override``
         when the default base prompt is disabled;
      2. the custom user prompt under a "[USER SYSTEM PROMPT]" header, when it
         is not blank;
      3. the bias engine's tooling strategy, when both a preset and a bias
         profile are available and the engine returns a non-empty strategy.

    ``preset``/``bias`` default to the active tool preset / bias profile.
    """
    active_preset = _active_tool_preset if preset is None else preset
    active_bias = _active_bias_profile if bias is None else bias

    prompt = _SYSTEM_PROMPT if _use_default_base_system_prompt else _base_system_prompt_override
    if _custom_system_prompt.strip():
        prompt = f"{prompt}\n\n[USER SYSTEM PROMPT]\n{_custom_system_prompt}"

    if active_preset and active_bias:
        strategy = _BIAS_ENGINE.generate_tooling_strategy(active_preset, active_bias)
        if strategy:
            prompt += f"\n\n{strategy}"
    return prompt
|
|
|
|
def get_combined_system_prompt(preset: Optional[ToolPreset] = None, bias: Optional[BiasProfile] = None) -> str:
    """Public accessor for the assembled system prompt.

    Thin wrapper over _get_combined_system_prompt(); omitted arguments fall
    back to the active tool preset / bias profile.
    """
    return _get_combined_system_prompt(preset=preset, bias=bias)
|
|
|
|
# Bounded in-memory comms log: once it holds 1000 entries, each append
# drops the oldest one.
_comms_log: deque[dict[str, Any]] = deque(maxlen=1000)

COMMS_CLAMP_CHARS: int = 300
|
|
|
|
#endregion: System Prompt Management
|
|
|
|
#region: Comms Log
|
|
|
|
def get_comms_log_callback() -> Optional[Callable[[dict[str, Any]], None]]:
    """Return the comms-log callback that applies to the calling thread.

    A truthy thread-local callback takes precedence; otherwise the
    module-level ``comms_log_callback`` is returned (which may be None).
    """
    return getattr(_local_storage, "comms_log_callback", None) or comms_log_callback
|
|
|
|
def set_comms_log_callback(cb: Optional[Callable[[dict[str, Any]], None]]) -> None:
    """Install *cb* as the comms-log callback.

    *cb* becomes both the module-wide default and the calling thread's
    override; passing None clears both.
    """
    global comms_log_callback
    comms_log_callback = cb
    _local_storage.comms_log_callback = cb
|
|
|
|
def _append_comms(direction: str, kind: str, payload: dict[str, Any]) -> None:
    """Record one comms event and forward it to the active callback.

    The entry is stamped with the local wall-clock time ("HH:MM:SS") and the
    epoch time, plus the active provider/model and the calling thread's tier.
    It is appended to the bounded in-memory log, then handed to the callback
    from get_comms_log_callback() when one is set.
    """
    entry: dict[str, Any] = dict(
        ts=datetime.datetime.now().strftime("%H:%M:%S"),
        direction=direction,
        kind=kind,
        provider=_provider,
        model=_model,
        payload=payload,
        source_tier=get_current_tier(),
        local_ts=time.time(),
    )
    _comms_log.append(entry)

    callback = get_comms_log_callback()
    if callback is not None:
        callback(entry)
|
|
|
|
def get_comms_log() -> list[dict[str, Any]]:
    """Return a shallow snapshot of the comms log, oldest entry first."""
    return [*_comms_log]
|
|
|
|
def clear_comms_log() -> None:
    """Drop every entry from the in-memory comms log (no callback is invoked)."""
    _comms_log.clear()
|
|
|
|
def get_credentials_path() -> Path:
    """Return the location of the credentials TOML file.

    The SLOP_CREDENTIALS environment variable wins whenever it is set (even
    to an empty string). Otherwise the path is ``credentials.toml`` in the
    parent of this module's directory.

    NOTE(review): this definition shadows the ``get_credentials_path``
    imported from ``src.paths`` at the top of the module, so that import is
    unused here. Confirm which implementation is intended.
    """
    override = os.environ.get("SLOP_CREDENTIALS")
    if override is not None:
        return Path(override)
    return Path(__file__).parent.parent / "credentials.toml"
|
|
|
|
def _load_credentials() -> dict[str, Any]:
    """Parse and return the credentials TOML file.

    Raises:
        FileNotFoundError: the file does not exist. The message explains how
            to create it, and the original error is chained as ``__cause__``.
        tomllib.TOMLDecodeError: the file is not valid TOML (propagated).
    """
    path = get_credentials_path()
    try:
        with open(path, "rb") as fh:
            return tomllib.load(fh)
    except FileNotFoundError as missing:
        hint = (
            f"Credentials file not found: {path}\n"
            "Create a credentials.toml with:\n"
            " [gemini]\n api_key = \"your-key\"\n"
            " [anthropic]\n api_key = \"your-key\"\n"
            " [deepseek]\n api_key = \"your-key\"\n"
            " [minimax]\n api_key = \"your-key\"\n"
            "Or set SLOP_CREDENTIALS env var to a custom path."
        )
        raise FileNotFoundError(hint) from missing
|
|
|
|
def _try_warm_sdk_result(name: str) -> Result[Any]:
    """Fetch an already-warmed SDK module without raising.

    Returns ``Result(data=<module>)`` on success. If _require_warmed raises
    ImportError or AttributeError, returns ``Result(data=None,
    errors=[ErrorInfo(kind=INTERNAL, ...)])`` with the exception attached as
    ``original``. Callers branch on ``result.ok`` (never on a None sentinel,
    per the Phase 11 anti-sliming protocol) and fall back to body-string
    matching when the SDK is unavailable.
    """
    try:
        return Result(data=_require_warmed(name))
    except (ImportError, AttributeError) as exc:
        failure = ErrorInfo(
            kind=ErrorKind.INTERNAL,
            message=f"SDK module '{name}' unavailable: {exc}",
            source="ai_client._try_warm_sdk_result",
            original=exc,
        )
        return Result(data=None, errors=[failure])
|
|
|
|
def _classify_anthropic_error(exc: Exception, source: str = "ai_client.anthropic") -> ErrorInfo:
    """Map an Anthropic SDK exception onto an ErrorInfo category.

    Classification needs the warmed ``anthropic`` SDK; without it the error
    is reported as UNKNOWN. Typed SDK exceptions are tried first. A generic
    APIStatusError is then classified by HTTP status code and, failing that,
    by billing/quota keywords in its text.
    """
    def info(kind):
        return ErrorInfo(kind=kind, message=str(exc), source=source, original=exc)

    sdk_result = _try_warm_sdk_result("anthropic")
    if not sdk_result.ok:
        return info(ErrorKind.UNKNOWN)
    sdk = sdk_result.data

    # SDK classes are looked up one at a time, in priority order.
    for class_name, kind in (
        ("RateLimitError", ErrorKind.RATE_LIMIT),
        ("AuthenticationError", ErrorKind.AUTH),
        ("PermissionDeniedError", ErrorKind.AUTH),
        ("APIConnectionError", ErrorKind.NETWORK),
    ):
        if isinstance(exc, getattr(sdk, class_name)):
            return info(kind)

    if not isinstance(exc, sdk.APIStatusError):
        return info(ErrorKind.UNKNOWN)

    status = getattr(exc, "status_code", 0)
    text = str(exc).lower()
    if status == 429:
        return info(ErrorKind.RATE_LIMIT)
    if status in (401, 403):
        return info(ErrorKind.AUTH)
    if status == 402 or any(word in text for word in ("credit", "balance", "billing")):
        return info(ErrorKind.BALANCE)
    if any(word in text for word in ("quota", "limit", "exceeded")):
        return info(ErrorKind.QUOTA)
    return info(ErrorKind.UNKNOWN)
|
|
|
|
def _classify_gemini_error(exc: Exception, source: str = "ai_client.gemini") -> ErrorInfo:
    """Map a Gemini failure onto an ErrorInfo category.

    When google.api_core.exceptions is warmed, its typed exceptions are
    checked first. Otherwise, or when none of them match, the lower-cased
    error text is matched against keyword rules in priority order. Anything
    left over is UNKNOWN.
    """
    text = str(exc).lower()

    def info(kind):
        return ErrorInfo(kind=kind, message=str(exc), source=source, original=exc)

    sdk_result = _try_warm_sdk_result("google.api_core.exceptions")
    if sdk_result.ok:
        gac = sdk_result.data
        if isinstance(exc, gac.ResourceExhausted):
            return info(ErrorKind.QUOTA)
        if isinstance(exc, gac.TooManyRequests):
            return info(ErrorKind.RATE_LIMIT)
        if isinstance(exc, (gac.Unauthenticated, gac.PermissionDenied)):
            return info(ErrorKind.AUTH)
        if isinstance(exc, gac.ServiceUnavailable):
            return info(ErrorKind.NETWORK)

    if any(key in text for key in ("429", "quota", "resource exhausted")):
        return info(ErrorKind.QUOTA)
    if "rate" in text and "limit" in text:
        return info(ErrorKind.RATE_LIMIT)
    if any(key in text for key in ("401", "403", "api key", "unauthenticated")):
        return info(ErrorKind.AUTH)
    if any(key in text for key in ("402", "billing", "balance", "payment")):
        return info(ErrorKind.BALANCE)
    if any(key in text for key in ("connection", "timeout", "unreachable")):
        return info(ErrorKind.NETWORK)
    return info(ErrorKind.UNKNOWN)
|
|
|
|
def _classify_deepseek_error(exc: Exception, source: str = "ai_client.deepseek") -> ErrorInfo:
    """Classify a DeepSeek request failure into an ErrorInfo.

    The text that gets classified is:
      * for a ``requests`` HTTPError with a response attached: the
        ``error.message`` field of the JSON body when present, otherwise the
        raw response text. The raw text is also used when the body is not
        JSON or has an unexpected shape (null, a bare string, ...);
      * for anything else: ``str(exc)``.
    That text is matched case-insensitively, in priority order, against rate
    limit, auth, balance, quota, network and bad-request indicators.

    The returned ErrorInfo always carries ``exc`` itself as ``original``. A
    caller that re-raises ``errors[0].original`` therefore surfaces the real
    HTTP failure, not an internal JSON-decoding error.
    """
    requests = _require_warmed("requests")
    if isinstance(exc, requests.exceptions.HTTPError) and exc.response is not None:
        try:
            # Prefer DeepSeek's structured JSON error message when available.
            err_data = exc.response.json()
            if "error" in err_data:
                body = str(err_data["error"].get("message", exc.response.text))
            else:
                body = exc.response.text
        except (ValueError, TypeError, AttributeError):
            # Non-JSON body, or JSON of an unexpected shape (e.g. null, or
            # "error" being a plain string). Classify the raw text instead.
            body = exc.response.text
    else:
        body = str(exc)

    body_l = body.lower()
    if "429" in body_l or "rate" in body_l: return ErrorInfo(kind=ErrorKind.RATE_LIMIT, message=body, source=source, original=exc)
    if "401" in body_l or "403" in body_l or "auth" in body_l or "api key" in body_l: return ErrorInfo(kind=ErrorKind.AUTH, message=body, source=source, original=exc)
    if "402" in body_l or "balance" in body_l or "billing" in body_l: return ErrorInfo(kind=ErrorKind.BALANCE, message=body, source=source, original=exc)
    if "quota" in body_l or "limit exceeded" in body_l: return ErrorInfo(kind=ErrorKind.QUOTA, message=body, source=source, original=exc)
    if "connection" in body_l or "timeout" in body_l or "network" in body_l: return ErrorInfo(kind=ErrorKind.NETWORK, message=body, source=source, original=exc)
    # A 400 keeps UNKNOWN but gets a clearer, provider-tagged message.
    if "400" in body_l or "bad request" in body_l: return ErrorInfo(kind=ErrorKind.UNKNOWN, message=f"DeepSeek Bad Request: {body}", source=source, original=exc)
    return ErrorInfo(kind=ErrorKind.UNKNOWN, message=body, source=source, original=exc)
|
|
|
|
def _classify_minimax_error(exc: Exception, source: str = "ai_client.minimax") -> ErrorInfo:
    """Classify a MiniMax request failure into an ErrorInfo.

    The text that gets classified is:
      * for a ``requests`` HTTPError with a response attached: the
        ``error.message`` field of the JSON body when present, otherwise the
        raw response text. The raw text is also used when the body is not
        JSON or has an unexpected shape (null, a bare string, ...);
      * for anything else: ``str(exc)``.
    That text is matched case-insensitively, in priority order, against rate
    limit, auth, balance, quota, network and bad-request indicators.

    The returned ErrorInfo always carries ``exc`` itself as ``original``. A
    caller that re-raises ``errors[0].original`` therefore surfaces the real
    HTTP failure, not an internal JSON-decoding error.
    """
    requests = _require_warmed("requests")
    if isinstance(exc, requests.exceptions.HTTPError) and exc.response is not None:
        try:
            err_data = exc.response.json()
            if "error" in err_data:
                body = str(err_data["error"].get("message", exc.response.text))
            else:
                body = exc.response.text
        except (ValueError, TypeError, AttributeError):
            # Non-JSON body, or JSON of an unexpected shape. Classify the raw
            # text and keep the original HTTP exception.
            body = exc.response.text
    else:
        body = str(exc)

    body_l = body.lower()
    if "429" in body_l or "rate" in body_l: return ErrorInfo(kind=ErrorKind.RATE_LIMIT, message=body, source=source, original=exc)
    if "401" in body_l or "403" in body_l or "auth" in body_l or "api key" in body_l: return ErrorInfo(kind=ErrorKind.AUTH, message=body, source=source, original=exc)
    if "402" in body_l or "balance" in body_l or "billing" in body_l: return ErrorInfo(kind=ErrorKind.BALANCE, message=body, source=source, original=exc)
    if "quota" in body_l or "limit exceeded" in body_l: return ErrorInfo(kind=ErrorKind.QUOTA, message=body, source=source, original=exc)
    if "connection" in body_l or "timeout" in body_l or "network" in body_l: return ErrorInfo(kind=ErrorKind.NETWORK, message=body, source=source, original=exc)
    # A 400 keeps UNKNOWN but gets a clearer, provider-tagged message.
    if "400" in body_l or "bad request" in body_l: return ErrorInfo(kind=ErrorKind.UNKNOWN, message=f"MiniMax Bad Request: {body}", source=source, original=exc)
    return ErrorInfo(kind=ErrorKind.UNKNOWN, message=body, source=source, original=exc)
|
|
|
|
def _set_minimax_provider_result(model: str) -> Result[list[str]]:
    """Fetch the MiniMax model list using the API key from credentials.toml.

    ``model`` is accepted for call-site symmetry and is not used here.

    Returns ``Result(data=<model names>)`` on success. If loading the
    credentials or listing the models raises OSError/ValueError, returns
    ``Result(data=[], errors=[ErrorInfo(kind=INTERNAL, ...)])`` with the
    original exception attached. The legacy caller (set_provider) checks
    ``result.ok`` and otherwise falls back to ``_list_minimax_models("")``.
    """
    try:
        minimax_section = _load_credentials().get("minimax", {})
        return Result(data=_list_minimax_models(minimax_section.get("api_key", "")))
    except (OSError, ValueError) as exc:
        failure = ErrorInfo(
            kind=ErrorKind.INTERNAL,
            message=f"failed to load minimax credentials: {exc}",
            source="ai_client._set_minimax_provider_result",
            original=exc,
        )
        return Result(data=[], errors=[failure])
|
|
|
|
|
|
def set_provider(provider: str, model: str, validate: bool = True) -> None:
|
|
"""Updates the active LLM provider and model name.
|
|
|
|
When validate is True (default), the model is checked against the provider's
|
|
LIVE model list, which for gemini_cli/minimax means a blocking subprocess /
|
|
network call (and importing the provider SDK). Pass validate=False during
|
|
startup so the GUI's first frame is not blocked — AppController._fetch_models
|
|
corrects the model against the live list shortly after, off the main thread.
|
|
"""
|
|
global _provider, _model
|
|
_provider = provider
|
|
if not validate:
|
|
_model = model
|
|
return
|
|
if provider == "gemini_cli":
|
|
valid_models = _list_gemini_cli_models()
|
|
if model != "mock" and (model not in valid_models or model.startswith("deepseek")):
|
|
_model = "gemini-3-flash-preview"
|
|
else:
|
|
_model = model
|
|
elif provider == "minimax":
|
|
result = _set_minimax_provider_result(model)
|
|
valid_models = result.data if result.ok else _list_minimax_models("")
|
|
if model not in valid_models:
|
|
_model = "MiniMax-M2.5"
|
|
else:
|
|
_model = model
|
|
else:
|
|
_model = model
|
|
|
|
def get_provider() -> str:
    """Return the identifier of the active provider (one of PROVIDERS, e.g. "gemini")."""
    active = _provider
    return active
|
|
|
|
def cleanup() -> None:
    """Release server-side Gemini resources held by this module.

    When both a Gemini client and a server-side cache exist, the cache is
    deleted through _delete_gemini_cache_result(). The list of cached file
    paths is cleared unconditionally.
    """
    # Only the cached-path list is rebound here; client/cache are just read.
    global _gemini_cached_file_paths
    if _gemini_client and _gemini_cache:
        _delete_gemini_cache_result()
    _gemini_cached_file_paths = []
|
|
|
|
def reset_session() -> None:
    """Clear conversation history and reset provider-specific session state.

    - Deletes the server-side Gemini cache (when a client and cache exist)
      and forgets the Gemini client, chat and cache bookkeeping.
    - Recreates the Gemini CLI adapter, keeping its configured binary path.
    - Drops every provider client and empties each provider's history while
      holding that provider's history lock.
    - Invalidates the cached Anthropic/DeepSeek tool schemas and resets the
      file_cache client.

    Endpoint configuration (``_llama_base_url``, ``_llama_api_key``,
    ``_qwen_region``) is configuration rather than session state and is
    left untouched, like the Gemini CLI binary path.
    """
    global _gemini_client, _gemini_chat, _gemini_cache
    global _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths
    global _anthropic_client, _anthropic_history
    global _deepseek_client, _deepseek_history
    global _minimax_client, _minimax_history
    global _qwen_client, _qwen_history
    # Grok/Llama must be declared global as well. Without these declarations
    # the assignments below would only bind locals, and their clients and
    # histories would survive a reset.
    global _grok_client, _grok_history
    global _llama_client, _llama_history
    global _CACHED_ANTHROPIC_TOOLS, _CACHED_DEEPSEEK_TOOLS
    global _gemini_cli_adapter

    if _gemini_client and _gemini_cache:
        _delete_gemini_cache_result()
    _gemini_client = None
    _gemini_chat = None
    _gemini_cache = None
    _gemini_cache_md_hash = None
    _gemini_cache_created_at = None
    _gemini_cached_file_paths = []

    # Preserve binary_path if adapter exists
    old_path = _gemini_cli_adapter.binary_path if _gemini_cli_adapter else "gemini"
    _gemini_cli_adapter = GeminiCliAdapter(binary_path=old_path)

    _anthropic_client = None
    with _anthropic_history_lock:
        _anthropic_history = []
    _deepseek_client = None
    with _deepseek_history_lock:
        _deepseek_history = []
    _minimax_client = None
    with _minimax_history_lock:
        _minimax_history = []
    _qwen_client = None
    with _qwen_history_lock:
        _qwen_history = []
    _grok_client = None
    with _grok_history_lock:
        _grok_history = []
    _llama_client = None
    with _llama_history_lock:
        _llama_history = []

    _CACHED_ANTHROPIC_TOOLS = None
    _CACHED_DEEPSEEK_TOOLS = None
    file_cache.reset_client()
|
|
|
|
def list_models(provider: str) -> list[str]:
    """Return the model names available for *provider*.

    credentials.toml is read only for providers whose listing needs an API
    key from it (gemini, deepseek, minimax). The other providers never touch
    the credentials file, so e.g. a local llama setup works without one.
    Unknown providers yield an empty list.

    Raises:
        FileNotFoundError: the credentials file is missing (key-based providers only).
        KeyError: the provider's section or its ``api_key`` is missing.
    """
    if provider in ("gemini", "deepseek", "minimax"):
        api_key = _load_credentials()[provider]["api_key"]
        if provider == "gemini":
            return _list_gemini_models(api_key)
        if provider == "deepseek":
            return _list_deepseek_models(api_key)
        return _list_minimax_models(api_key)
    if provider == "anthropic":
        return _list_anthropic_models()
    if provider == "gemini_cli":
        return _list_gemini_cli_models()
    if provider == "qwen":
        return _list_qwen_models()
    if provider == "grok":
        return _list_grok_models()
    if provider == "llama":
        return _list_llama_models()
    return []
|
|
|
|
#endregion: Comms Log
|
|
|
|
# Name of the built-in PowerShell tool offered alongside the MCP tools.
TOOL_NAME: str = "run_powershell"

# Tool name -> enabled flag. The schema builders treat tools missing from
# this map as enabled.
_agent_tools: dict[str, bool] = {}
|
|
|
|
#region: Tool Configuration
|
|
|
|
def set_agent_tools(tools: dict[str, bool]) -> None:
    """Replace the tool enable/disable map and invalidate cached tool schemas.

    The mapping is stored by reference, not copied. The cached Anthropic and
    DeepSeek tool lists are cleared so they are rebuilt on next use.
    """
    global _agent_tools, _CACHED_ANTHROPIC_TOOLS, _CACHED_DEEPSEEK_TOOLS
    _agent_tools = tools
    _CACHED_ANTHROPIC_TOOLS = _CACHED_DEEPSEEK_TOOLS = None
|
|
|
|
def _set_tool_preset_result(preset_name: Optional[str]) -> Result[None]:
    """Load the tool preset *preset_name* and make it active. Returns Result[None].

    On success:
      * the preset becomes ``_active_tool_preset``;
      * ``_agent_tools`` is rebuilt with every MCP tool and the PowerShell
        tool disabled, except the tools listed in the preset;
      * each preset tool's approval mode is recorded in ``_tool_approval_modes``.
    An empty/"None" name, or a name not among the saved presets, changes
    nothing and still returns an ok Result.

    On I/O or parsing failure, returns ``Result(data=None, errors=[ErrorInfo])``
    with the original exception attached, and nothing is applied. The legacy
    caller (set_tool_preset) still completes and invalidates the tool caches.
    """
    # Required: without this declaration the assignments below would bind
    # locals and the preset would never actually take effect.
    global _active_tool_preset, _agent_tools
    if not preset_name or preset_name == "None":
        return Result(data=None)
    try:
        presets = ToolPresetManager().load_all()
        if preset_name in presets:
            preset = presets[preset_name]
            # Start from "everything disabled", then enable the preset's tools.
            new_tools = {name: False for name in mcp_client.TOOL_NAMES}
            new_tools[TOOL_NAME] = False
            approvals: dict[str, str] = {}
            for cat in preset.categories.values():
                for tool in cat:
                    name = tool.name
                    new_tools[name] = True
                    approvals[name] = tool.approval
            # Commit only after the whole preset has been read, so a malformed
            # tool entry cannot leave the module half-configured.
            _active_tool_preset = preset
            _tool_approval_modes.update(approvals)
            _agent_tools = new_tools
        return Result(data=None)
    except (OSError, ValueError, AttributeError) as e:
        return Result(
            data=None,
            errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to set tool preset '{preset_name}': {e}", source="ai_client._set_tool_preset_result", original=e)],
        )
|
|
|
|
|
|
def set_tool_preset(preset_name: Optional[str]) -> None:
    """Apply a tool preset by name, or enable every tool when none is given.

    ``None``, ``""`` and the string ``"None"`` enable all MCP tools plus the
    PowerShell tool and clear the active preset. Any other name is handed to
    _set_tool_preset_result(); its Result is not inspected here. Approval
    modes are reset first, and the cached Anthropic/DeepSeek tool schemas
    are always invalidated.
    """
    global _agent_tools, _CACHED_ANTHROPIC_TOOLS, _CACHED_DEEPSEEK_TOOLS, _tool_approval_modes, _active_tool_preset
    _tool_approval_modes = {}
    if preset_name and preset_name != "None":
        _set_tool_preset_result(preset_name)
    else:
        everything_on = dict.fromkeys(mcp_client.TOOL_NAMES, True)
        everything_on[TOOL_NAME] = True
        _agent_tools = everything_on
        _active_tool_preset = None
    _CACHED_ANTHROPIC_TOOLS = None
    _CACHED_DEEPSEEK_TOOLS = None
|
|
|
|
def _set_bias_profile_result(profile_name: Optional[str]) -> Result[None]:
    """Load the bias profile *profile_name* and make it active. Returns Result[None].

    An empty/"None" name, or a name not among the saved profiles, leaves the
    active profile unchanged and returns an ok Result. On I/O or parsing
    failure, returns ``Result(data=None, errors=[ErrorInfo])`` with the
    original exception attached. The legacy caller (set_bias_profile)
    delegates to this helper.
    """
    # Required: without this declaration the assignment below binds a local
    # and the selected profile is silently discarded.
    global _active_bias_profile
    if not profile_name or profile_name == "None":
        return Result(data=None)
    try:
        profiles = ToolPresetManager().load_all_bias_profiles()
        if profile_name in profiles:
            _active_bias_profile = profiles[profile_name]
        return Result(data=None)
    except (OSError, ValueError, AttributeError) as e:
        return Result(
            data=None,
            errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to set bias profile '{profile_name}': {e}", source="ai_client._set_bias_profile_result", original=e)],
        )
|
|
|
|
|
|
def set_bias_profile(profile_name: Optional[str]) -> None:
    """Select the tool bias profile used to tune model behavior.

    ``None``, ``""`` or ``"None"`` clears the active profile. Any other name
    is loaded via _set_bias_profile_result(); its Result is not inspected.
    """
    global _active_bias_profile
    if profile_name and profile_name != "None":
        _set_bias_profile_result(profile_name)
        return
    _active_bias_profile = None
|
|
|
|
def get_bias_profile() -> Optional[str]:
    """Return the active bias profile's name, or None when no profile is active."""
    profile = _active_bias_profile
    if not profile:
        return None
    return profile.name
|
|
|
|
def _build_anthropic_tools() -> list[dict[str, Any]]:
    """Build the Anthropic ``tools`` payload from the enabled tool set.

    Includes every MCP tool schema plus the PowerShell tool, skipping tools
    disabled in ``_agent_tools`` (tools missing from that map count as
    enabled). When a tool preset is active, the bias engine's semantic
    nudges are applied. The last tool is tagged with an ephemeral
    ``cache_control`` marker.

    [C: tests/test_agent_tools_wiring.py:test_build_anthropic_tools_conversion, tests/test_tool_access_exclusion.py:test_build_anthropic_tools_excludes_disabled]
    """
    def enabled(tool_name):
        return _agent_tools.get(tool_name, True)

    tools: list[dict[str, Any]] = [
        {"name": spec["name"], "description": spec["description"], "input_schema": spec["parameters"]}
        for spec in mcp_client.get_tool_schemas()
        if enabled(spec["name"])
    ]

    if enabled(TOOL_NAME):
        powershell_description = (
            "Run a PowerShell script within the project base_dir. "
            "Use this to create, edit, rename, or delete files and directories. "
            "The working directory is set to base_dir automatically. "
            "Always prefer targeted edits over full rewrites where possible. "
            "stdout and stderr are returned to you as the result."
        )
        powershell_schema = {
            "type": "object",
            "properties": {
                "script": {
                    "type": "string",
                    "description": "The PowerShell script to execute.",
                },
            },
            "required": ["script"],
        }
        tools.append({
            "name": TOOL_NAME,
            "description": powershell_description,
            "input_schema": powershell_schema,
        })

    if _active_tool_preset:
        _BIAS_ENGINE.apply_semantic_nudges(tools, _active_tool_preset)
    if tools:
        tools[-1]["cache_control"] = {"type": "ephemeral"}
    return tools
|
|
|
|
# Memoized output of _build_anthropic_tools(). None means "rebuild on the
# next _get_anthropic_tools() call"; the tool-configuration setters and
# reset_session() clear it.
_CACHED_ANTHROPIC_TOOLS: Optional[list[dict[str, Any]]] = None
|
|
|
|
def _get_anthropic_tools() -> list[dict[str, Any]]:
    """Return the Anthropic tool list, building and memoizing it on first use.

    [C: tests/test_bias_efficacy.py:test_bias_efficacy_prompt_generation, tests/test_bias_efficacy.py:test_bias_parameter_nudging, tests/test_bias_integration.py:test_tool_declaration_biasing_anthropic]
    """
    global _CACHED_ANTHROPIC_TOOLS
    cached = _CACHED_ANTHROPIC_TOOLS
    if cached is None:
        cached = _CACHED_ANTHROPIC_TOOLS = _build_anthropic_tools()
    return cached
|
|
|
|
def _gemini_schema_from_json(types: Any, spec: dict[str, Any]) -> Any:
    """Convert one JSON-Schema property definition into a ``types.Schema``.

    Scalar properties become ``Schema(type=..., description=...)``. Compound
    properties keep their structure:
      * ``array`` gets an ``items`` schema, defaulting to string items when
        the JSON schema gives none, since Gemini rejects ARRAY schemas
        without ``items``;
      * ``object`` keeps its non-empty nested ``properties`` and its
        ``required`` list.
    A missing, unknown or non-string ``type`` falls back to ``STRING``.

    ``types`` is the ``google.genai.types`` module. It is passed in because
    the SDK is only reachable once warmed.
    """
    raw_type = spec.get("type", "string")
    type_name = raw_type.upper() if isinstance(raw_type, str) else "STRING"
    kwargs: dict[str, Any] = {
        "type": getattr(types.Type, type_name, types.Type.STRING),
        "description": spec.get("description", ""),
    }
    if type_name == "ARRAY":
        items = spec.get("items")
        kwargs["items"] = _gemini_schema_from_json(types, items if isinstance(items, dict) else {})
    elif type_name == "OBJECT":
        nested = spec.get("properties")
        if isinstance(nested, dict) and nested:
            kwargs["properties"] = {
                name: _gemini_schema_from_json(types, sub if isinstance(sub, dict) else {})
                for name, sub in nested.items()
            }
            if spec.get("required"):
                kwargs["required"] = list(spec["required"])
    return types.Schema(**kwargs)


def _gemini_tool_declaration() -> Optional[types.Tool]:
    """Build the Gemini ``types.Tool`` declaring every enabled tool.

    Covers the MCP tool schemas plus the PowerShell tool, skipping tools
    disabled in ``_agent_tools`` (missing entries count as enabled). The
    bias engine's semantic nudges are applied when a preset is active. Each
    parameter schema is converted recursively by _gemini_schema_from_json,
    so array/object parameters keep their ``items``/nested properties.
    Returns None when no tool is enabled.

    [C: tests/test_tool_access_exclusion.py:test_gemini_tool_declaration_excludes_disabled]
    """
    # Note: We look up the PARENT package `google.genai` and access `.types`
    # as an attribute, not `_require_warmed("google.genai.types")` directly.
    # The latter triggers a latent circular-import bug in google-genai's
    # __init__.py chain in fresh pytest processes. Using the parent
    # completes the chain once, then `.types` is just an attribute access.
    genai = _require_warmed("google.genai")
    types = genai.types

    raw_tools: list[dict[str, Any]] = []
    for spec in mcp_client.get_tool_schemas():
        if _agent_tools.get(spec["name"], True):
            raw_tools.append({
                "name": spec["name"],
                "description": spec["description"],
                "parameters": spec["parameters"],
            })
    if _agent_tools.get(TOOL_NAME, True):
        raw_tools.append({
            "name": TOOL_NAME,
            "description": (
                "Run a PowerShell script within the project base_dir. "
                "Use this to create, edit, rename, or delete files and directories. "
                "The working directory is set to base_dir automatically. "
                "Always prefer targeted edits over full rewrites where possible. "
                "stdout and stderr are returned to you as the result."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "script": {
                        "type": "string",
                        "description": "The PowerShell script to execute."
                    }
                },
                "required": ["script"]
            }
        })
    if _active_tool_preset:
        _BIAS_ENGINE.apply_semantic_nudges(raw_tools, _active_tool_preset)

    declarations: list[types.FunctionDeclaration] = []
    for tool_def in raw_tools:
        params = tool_def.get("parameters", {})
        props = {
            pname: _gemini_schema_from_json(types, pdef)
            for pname, pdef in params.get("properties", {}).items()
        }
        declarations.append(types.FunctionDeclaration(
            name=tool_def["name"],
            description=tool_def["description"],
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties=props,
                required=params.get("required", []),
            ),
        ))
    return types.Tool(function_declarations=declarations) if declarations else None
|
|
|
|
#endregion: Tool Configuration
|
|
|
|
#region: Tool Execution
|
|
|
|
def _parse_tool_args_result(tool_args_str: str) -> Result[dict[str, Any]]:
    """Parse a provider's JSON-encoded tool-call arguments into a dict.

    Returns:
        Result(data=<parsed dict>) on success. On failure returns
        Result(data={}, errors=[ErrorInfo(kind=ErrorKind.INTERNAL, ...)]) so the
        caller can fall back to empty args while still observing the problem.
        Failures are: invalid JSON (ValueError), a non-string payload such as
        None (TypeError), and valid JSON that is not an object (a list, number,
        string or null). Tool arguments must be a mapping because they are used
        with ``.get`` / ``.items`` by the tool executor.

    Per TIER1_REVIEW 2026-06-20: empty-default is NOT a drain -- the caller
    must observe the errors.
    """
    source = "ai_client._parse_tool_args_result"
    try:
        parsed = json.loads(tool_args_str)
    except (ValueError, TypeError) as e:
        return Result(
            data={},
            errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to parse tool args: {e}", source=source, original=e)],
        )
    if not isinstance(parsed, dict):
        # Valid JSON but the wrong shape (e.g. "[1, 2]" or "null"): reject it
        # here instead of letting a non-dict reach the tool executor.
        err = TypeError(f"tool args must decode to a JSON object, got {type(parsed).__name__}")
        return Result(
            data={},
            errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to parse tool args: {err}", source=source, original=err)],
        )
    return Result(data=parsed)
|
|
|
|
|
|
async def _execute_tool_calls_concurrently(
    calls: list[Any],
    base_dir: str,
    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
    qa_callback: Optional[Callable[[str], str]],
    r_idx: int,
    provider: str,
    patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
) -> list[tuple[str, str, str, str]]:  # tool_name, call_id, output, original_name
    """
    Executes tool calls concurrently using asyncio.gather.

    Functional Purpose:
        Normalizes each provider-specific tool-call object into (name, args, call_id)
        and concurrently dispatches them to _execute_single_tool_call_async.
        Calls for providers other than gemini / gemini_cli / anthropic / deepseek /
        minimax are skipped.

    Parameters & Inputs:
        calls (list[Any]): List of tool calls (shape depends on provider).
        base_dir (str): Workspace path.
        pre_tool_callback (Optional[Callable]): HITL/approval callback.
        qa_callback (Optional[Callable]): QA verification callback.
        r_idx (int): Round index.
        provider (str): LLM provider.
        patch_callback (Optional[Callable]): Patch verification callback.

    Returns:
        list[tuple[str, str, str, str]]: List of (tool_name, call_id, output, original_name).

    Error handling:
        DeepSeek / MiniMax argument strings that fail to parse fall back to {}
        (see _parse_tool_args_result). Each such ErrorInfo is written to stderr as
        an operator-visible warning instead of being discarded. Exceptions raised
        by a tool propagate out of asyncio.gather; the performance-monitor
        component is closed in either case.

    Immediate-Mode DAG / Thread Context:
        Called by: run_with_tool_loop
        Calls: _parse_tool_args_result, _execute_single_tool_call_async

    SSDL:
        `[I:gather] => o-> [I:_execute_single_tool_call_async] -> [M] -> [T:tool_results]`

    Thread Boundaries:
        Runs in the active asyncio event loop thread.

    [C: tests/test_async_tools.py:test_execute_tool_calls_concurrently_exception_handling, tests/test_async_tools.py:test_execute_tool_calls_concurrently_timing]
    """
    monitor = performance_monitor.get_monitor()
    if monitor.enabled: monitor.start_component("ai_client._execute_tool_calls_concurrently")
    try:
        tier = get_current_tier()
        file_errors: list[ErrorInfo] = []
        tasks = []
        for fc in calls:
            if provider == "gemini":
                # FunctionCall.args can be None for argument-less calls; dict(None) would raise.
                # Gemini 1.0.0 doesn't have call IDs in types.Part, so the name doubles as the id.
                name, args, call_id = fc.name, dict(fc.args or {}), fc.name
            elif provider == "gemini_cli":
                name = cast(str, fc.get("name"))
                args = cast(dict[str, Any], fc.get("args") or {})
                call_id = cast(str, fc.get("id"))
            elif provider == "anthropic":
                name = cast(str, getattr(fc, "name"))
                args = cast(dict[str, Any], getattr(fc, "input"))
                call_id = cast(str, getattr(fc, "id"))
            elif provider in ("deepseek", "minimax"):
                # OpenAI-style tool calls: arguments arrive as a JSON string.
                tool_info = fc.get("function", {})
                name = cast(str, tool_info.get("name"))
                tool_args_str = cast(str, tool_info.get("arguments", "{}"))
                call_id = cast(str, fc.get("id"))
                parsed = _parse_tool_args_result(tool_args_str)
                if parsed.errors:
                    file_errors.extend(parsed.errors)
                args = parsed.data
            else:
                continue
            tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx, tier, patch_callback))

        # Parse failures already fell back to {}; make them observable rather than dropping them.
        for err in file_errors:
            sys.stderr.write(f"[AI_CLIENT] {err.ui_message()}\n")
        if file_errors:
            sys.stderr.flush()

        results = await asyncio.gather(*tasks)
    finally:
        # Close the monitor component even when a tool raises out of gather().
        if monitor.enabled: monitor.end_component("ai_client._execute_tool_calls_concurrently")
    return results
|
|
|
|
def _run_coroutine_blocking(coro: Any) -> Any:
    """Run *coro* to completion from synchronous code and return its result.

    asyncio.run() refuses to start on a thread that already runs an event loop,
    and blocking that loop's own thread on run_coroutine_threadsafe(...).result()
    deadlocks (the loop can never run the scheduled coroutine while its thread
    waits). So when a loop is running on this thread, the coroutine is driven by
    asyncio.run() on a short-lived worker thread instead. Exceptions raised by
    the coroutine propagate to the caller in both cases.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop on this thread (the usual case for worker threads).
        return asyncio.run(coro)
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=1, thread_name_prefix="ai_client_tools") as pool:
        return pool.submit(asyncio.run, coro).result()


def run_with_tool_loop(
    client: Any,
    request: Union[OpenAICompatibleRequest, Callable[[int], OpenAICompatibleRequest]],
    *,
    capabilities: Optional[VendorCapabilities] = None,
    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
    qa_callback: Optional[Callable[[str], str]] = None,
    stream_callback: Optional[Callable[[str], None]] = None,
    patch_callback: Optional[Callable[[str, str], Optional[str]]] = None,
    base_dir: str,
    vendor_name: str,
    history_lock: Optional[threading.Lock] = None,
    history: Optional[list[dict[str, Any]]] = None,
    trim_func: Optional[Callable[[list[dict[str, Any]]], None]] = None,
    reasoning_extractor: Optional[Callable[[Any], str]] = None,
    send_func: Optional[Callable[[int], NormalizedResponse]] = None,
    on_pre_dispatch: Optional[Callable[[int, list[dict[str, Any]]], list[dict[str, Any]]]] = None,
    wrap_reasoning_in_text: bool = False,
) -> str:
    """
    Orchestrates the LLM conversation loop, executing tool calls and updating history.

    Functional Purpose:
        Runs a multi-round tool loop (up to MAX_TOOL_ROUNDS + 2). Each round dispatches a
        request, records the assistant turn, executes any tool calls concurrently, records
        their outputs, and repeats until a response carries no tool calls.

    Parameters & Inputs:
        client (Any): Active client instance (used by the default sender).
        request (Union[OpenAICompatibleRequest, Callable]): Initial request or per-round builder.
        capabilities (Optional[VendorCapabilities]): Required when send_func is not given.
        pre_tool_callback (Optional[Callable]): Human-in-the-loop validation callback.
        qa_callback (Optional[Callable]): QA verification callback.
        stream_callback (Optional[Callable]): Accepted for signature parity; not used here.
        patch_callback (Optional[Callable]): Verification callback for code patches.
        base_dir (str): Base workspace directory.
        vendor_name (str): Vendor name, passed to the tool executor as the provider.
        history_lock (Optional[threading.Lock]): Guards history; history is only
            updated when both history_lock and history are given.
        history (Optional[list[dict[str, Any]]]): Conversation history (appended in place).
        trim_func (Optional[Callable]): Called on history after each round's tool results.
        reasoning_extractor (Optional[Callable]): Extracts reasoning from the raw response.
        send_func (Optional[Callable]): Per-round sender; defaults to send_openai_compatible.
        on_pre_dispatch (Optional[Callable]): May rewrite a round's tool calls before execution.
        wrap_reasoning_in_text (bool): When True and reasoning_content is non-empty, the
            returned text is prepended with `<thinking>...</thinking>` wrapping the
            reasoning. This lets thinking_parser.parse_thinking_trace extract a
            ThinkingSegment for the discussion entry. Default False (callers that
            already wrap inline, e.g. DeepSeek, pass False).

    Returns:
        str: The text of the last response received.

    Raises:
        The default sender re-raises the original error of a failed request
        (or a RuntimeError when none is attached); tool exceptions propagate.

    Immediate-Mode DAG / Thread Context:
        Called by: _send_anthropic, _send_deepseek, _send_minimax, _send_qwen, _send_llama,
            _send_grok, _send_llama_native
        Calls: dispatch_send, _execute_tool_calls_concurrently, _run_coroutine_blocking

    SSDL:
        `o-> [I:dispatch_send] -> [B:tool_calls?] => [I:_execute_tool_calls_concurrently] -> [T:response_text]`

    Thread Boundaries:
        Runs synchronously in the caller thread; history modifications hold history_lock.
        Tool calls run via _run_coroutine_blocking: asyncio.run() on this thread when no
        event loop is running here, otherwise on a short-lived worker thread while this
        thread blocks.
    """
    def _default_send(_round_idx: int) -> NormalizedResponse:
        from src.openai_compatible import send_openai_compatible as _send_oc
        assert capabilities is not None, "capabilities required when send_func is not provided"
        res = _send_oc(client, request_builder(_round_idx), capabilities=capabilities)
        if not res.ok:
            # The errors come from a Result, not a local except: suppress implicit context.
            if res.errors and res.errors[0].original:
                raise res.errors[0].original from None
            raise RuntimeError(res.errors[0].message if res.errors else "Unknown OpenAI error") from None
        return res.data

    request_builder: Callable[[int], OpenAICompatibleRequest] = (request if callable(request) else (lambda _i: request))
    dispatch_send: Callable[[int], NormalizedResponse] = send_func or _default_send
    response_text: str = ""
    for _round_idx in range(MAX_TOOL_ROUNDS + 2):
        response = dispatch_send(_round_idx)
        reasoning_content: str = reasoning_extractor(response.raw_response) if reasoning_extractor else ""
        response_text = response.text or ""
        if history_lock is not None and history is not None:
            with history_lock:
                msg: dict[str, Any] = {"role": "assistant", "content": response.text or None}
                if reasoning_content: msg["reasoning_content"] = reasoning_content
                if response.tool_calls: msg["tool_calls"] = response.tool_calls
                history.append(msg)
        if not response.tool_calls: break
        if on_pre_dispatch is not None: _adjusted_calls = on_pre_dispatch(_round_idx, response.tool_calls)
        else: _adjusted_calls = response.tool_calls
        results = _run_coroutine_blocking(_execute_tool_calls_concurrently(
            _adjusted_calls, base_dir, pre_tool_callback, qa_callback, _round_idx, vendor_name, patch_callback,
        ))
        if history_lock is not None and history is not None:
            with history_lock:
                for _tool_name, call_id, out, _original_name in results:
                    history.append({
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": str(out) if out else "",
                    })
                if trim_func is not None: trim_func(history)
    if wrap_reasoning_in_text and reasoning_content:
        response_text = f"<thinking>\n{reasoning_content}\n</thinking>\n\n{response_text}"
    return response_text
|
|
|
|
async def _execute_single_tool_call_async(
    name: str,
    args: dict[str, Any],
    call_id: str,
    base_dir: str,
    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
    qa_callback: Optional[Callable[[str], str]],
    r_idx: int,
    tier: str | None = None,
    patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
) -> tuple[str, str, str, str]:
    """
    Executes a single tool call asynchronously, checking the approval clutch.

    Functional Purpose:
        Executes a tool call (either the PowerShell tool TOOL_NAME or an MCP tool) based on
        the per-tool approval mode in _tool_approval_modes (default "ask").
        - TOOL_NAME with mode "auto": runs _run_script directly.
        - TOOL_NAME otherwise, with pre_tool_callback: the callback's return value becomes
          the output (None -> "USER REJECTED: tool execution cancelled").
        - Native MCP tools (mcp_client.TOOL_NAMES) and external MCP tools: dispatched via
          mcp_client.async_dispatch; mutating native tools and all external tools first go
          through pre_tool_callback unless the mode is "auto" or no callback is set.
        - Unknown names yield "ERROR: unknown tool '<name>'".

    Parameters & Inputs:
        name (str): The name of the tool to execute.
        args (dict[str, Any]): Arguments passed to the tool.
        call_id (str): Call identifier, echoed back in the result.
        base_dir (str): Workspace root directory.
        pre_tool_callback (Optional[Callable]): Hook for HITL validation.
        qa_callback (Optional[Callable]): QA verification callback.
        r_idx (int): Current tool loop round index (reported in the "tool_execution" event).
        tier (str | None): Active MMA orchestration tier, set via set_current_tier.
        patch_callback (Optional[Callable]): Verification callback for code patches.

    Returns:
        tuple[str, str, str, str]: (tool_name, call_id, output, original_name). The last
        element is currently always equal to the first.

    Raises:
        Exceptions from mcp_client.async_dispatch, _run_script or pre_tool_callback are not
        caught here; they propagate to the caller.

    Immediate-Mode DAG / Thread Context:
        Called by: _execute_tool_calls_concurrently
        Calls: set_current_tier, events.emit, _append_comms, _run_script,
            pre_tool_callback, mcp_client.async_dispatch

    SSDL:
        `[I:CheckClutch] -> [B:Approved?] -> [I:run_powershell] -> [T:output]`

    Thread Boundaries:
        Runs in the active asyncio event loop thread; offloads blocking synchronous calls
        (pre_tool_callback and _run_script) to worker threads using asyncio.to_thread.
    """
    set_current_tier(tier)
    out = ""
    # Set once the PowerShell fast path below has produced an output.
    tool_executed = False
    # NOTE(review): only a "started" event is emitted in this function; no matching
    # completion event is emitted here — confirm whether consumers expect one.
    events.emit("tool_execution", payload = {"status": "started", "tool": name, "args": args, "round": r_idx})

    # Check for auto approval mode
    approval_mode = _tool_approval_modes.get(name, "ask")

    # Check for run_powershell
    if name == TOOL_NAME:
        scr = cast(str, args.get("script", ""))
        _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
        if approval_mode == "auto":
            out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback, patch_callback)
            tool_executed = True
        elif pre_tool_callback:
            # pre_tool_callback is synchronous and might block for HITL
            res = await asyncio.to_thread(pre_tool_callback, scr, base_dir, qa_callback)
            if res is None: out = "USER REJECTED: tool execution cancelled"
            else: out = res
            tool_executed = True

    # Reached for every non-PowerShell tool, and for TOOL_NAME when the mode is not
    # "auto" and no pre_tool_callback is registered.
    if not tool_executed:
        is_native = name in mcp_client.TOOL_NAMES
        ext_tools = mcp_client.get_external_mcp_manager().get_all_tools()
        is_external = name in ext_tools
        if name and (is_native or is_external):
            _append_comms("OUT", "tool_call", {"name": name, "id": call_id, "args": args})
            # Ask for approval for mutating native tools and all external tools, unless the
            # tool is in "auto" mode or there is no callback to ask.
            should_approve = (name in mcp_client.MUTATING_TOOLS or is_external) and approval_mode != "auto" and pre_tool_callback
            if should_approve:
                label = "MCP MUTATING" if is_native else "EXTERNAL MCP"
                desc = f"# {label} TOOL: {name}\n" + "\n".join(f"# {k}: {repr(v)}" for k, v in args.items())
                _res = await asyncio.to_thread(pre_tool_callback, desc, base_dir, qa_callback)
                # Only None counts as rejection; any other return value means "proceed".
                out = "USER REJECTED: tool execution cancelled" if _res is None else await mcp_client.async_dispatch(name, args)
            else:
                out = await mcp_client.async_dispatch(name, args)
            if tool_log_callback:
                tool_log_callback(f"# MCP TOOL: {name}\n{json.dumps(args, indent=1)}", out)
        elif name == TOOL_NAME:
            # PowerShell without a pre_tool_callback: _run_script routes through
            # confirm_and_run_callback for confirmation.
            # NOTE(review): the tool_call was already logged via _append_comms in the
            # TOOL_NAME branch above, so this path logs it twice.
            scr = cast(str, args.get("script", ""))
            _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
            out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback, patch_callback)
        else:
            out = f"ERROR: unknown tool '{name}'"
            if tool_log_callback:
                tool_log_callback(f"ERROR: {name}", out)

    return (name, call_id, out, name)
|
|
|
|
def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Run a PowerShell script through the registered confirmation handler.

    The module-level ``confirm_and_run_callback`` receives the script, base_dir
    and both callbacks. A ``None`` result from it is reported as a user
    rejection. The resulting output is also passed to ``tool_log_callback``
    when one is registered.

    Returns an error string, without logging, when no confirmation handler
    is registered.
    """
    handler = confirm_and_run_callback
    if handler is None:
        return "ERROR: no confirmation handler registered"
    outcome = handler(script, base_dir, qa_callback, patch_callback)
    output = "USER REJECTED: command was not executed" if outcome is None else outcome
    log_hook = tool_log_callback
    if log_hook is not None:
        log_hook(script, output)
    return output
|
|
|
|
def _truncate_tool_output(output: str) -> str:
    """Clip *output* to ``_history_trunc_limit`` characters.

    A limit of 0 or less disables truncation. When the text is clipped, a
    marker is appended so the model can tell the output was cut.
    """
    limit = _history_trunc_limit
    if limit <= 0 or len(output) <= limit:
        return output
    return f"{output[:limit]}\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS.]"
|
|
|
|
#endregion: Tool Execution
|
|
|
|
#region: File Context Building
|
|
|
|
def _reread_file_items_result(file_items: list[dict[str, Any]]) -> Result[tuple[list[dict[str, Any]], list[dict[str, Any]]]]:
    """Re-read file items whose mtime changed; returns (refreshed, changed).

    ``refreshed`` holds one entry per input item, in order: the original item
    when it has no "path" or its mtime is unchanged, otherwise a new dict.
    ``changed`` holds only the new dicts.

    For a successful re-read the new dict has "content" (the new text),
    "old_content" (the previous text, or "" when the previous content was an
    error placeholder), "error": False and the current "mtime".

    Per-file read errors (OSError / UnicodeDecodeError) produce an item with an
    "ERROR re-reading ..." content, "error": True and "mtime": 0.0 (so the next
    call retries it). They are also accumulated into Result.errors as
    structured ErrorInfo with the original exception preserved. Callers should
    check result.errors to detect re-read failures.
    """
    refreshed: list[dict[str, Any]] = []
    changed: list[dict[str, Any]] = []
    errors: list[ErrorInfo] = []
    for item in file_items:
        path = item.get("path")
        if path is None:
            refreshed.append(item)
            continue
        p = path if isinstance(path, _P) else _P(path)
        try:
            current_mtime = p.stat().st_mtime
            prev_mtime = cast(float, item.get("mtime", 0.0))
            if current_mtime == prev_mtime:
                refreshed.append(item)
                continue
            content = p.read_text(encoding="utf-8")
            # After a failed read, item["content"] is an "ERROR re-reading ..." placeholder,
            # not file text. Using it as old_content would yield a bogus diff. An empty
            # old_content makes the diff builder send the full file instead.
            old_content = "" if item.get("error") else item.get("content", "")
            new_item = {**item, "old_content": old_content, "content": content, "error": False, "mtime": current_mtime}
            refreshed.append(new_item)
            changed.append(new_item)
        except (OSError, UnicodeDecodeError) as e:
            err_item = {**item, "content": f"ERROR re-reading {p}: {e}", "error": True, "mtime": 0.0}
            refreshed.append(err_item)
            changed.append(err_item)
            errors.append(ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to re-read {p}: {e}", source="ai_client._reread_file_items_result", original=e))
    return Result(data=(refreshed, changed), errors=errors)
|
|
|
|
|
|
def _reread_file_items(file_items: list[dict[str, Any]]) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """
    Re-read context files whose on-disk modification time has changed.

    Functional Purpose:
        Legacy tuple-returning wrapper around _reread_file_items_result. The mtime
        comparison and file reads happen there; this wrapper surfaces any per-file
        errors and unpacks the (refreshed, changed) pair.

    Parameters & Inputs:
        file_items (list[dict[str, Any]]): File dictionaries with a "path" key and
            optionally "mtime" and "content".

    Returns:
        tuple[list[dict[str, Any]], list[dict[str, Any]]]: (refreshed_items, changed_items).

    Error handling:
        Each ErrorInfo returned by _reread_file_items_result is written to stderr as an
        operator-visible warning ("[AI_CLIENT] <message>"). The affected item also carries
        "error": True for in-band checks.

    Immediate-Mode DAG / Thread Context:
        Called by: _send_gemini
        Calls: _reread_file_items_result

    SSDL: `o-> [I:_reread_file_items_result] -> [B:errors?] -> [I:stderr] -> [T:(refreshed, changed)]`

    Thread Boundaries: Runs synchronously in the caller thread (blocking file-system I/O).
    """
    outcome = _reread_file_items_result(file_items)
    for problem in outcome.errors or ():
        sys.stderr.write(f"[AI_CLIENT] {problem.ui_message()}\n")
        sys.stderr.flush()
    refreshed_items, changed_items = outcome.data
    return refreshed_items, changed_items
|
|
|
|
def _build_file_context_text(file_items: list[dict[str, Any]]) -> str:
    """Render context files as Markdown sections with fenced code blocks.

    Each item becomes "### `<path>`" followed by a fence whose info string is
    the file's extension, or "text" when the file name has no dot. Sections are
    separated by "---" rules. ``path`` falls back to ``entry`` and then to
    "unknown". Returns "" for an empty list.
    """
    if not file_items:
        return ""
    parts: list[str] = []
    for item in file_items:
        path = item.get("path") or item.get("entry", "unknown")
        # Take the extension from the final path component only, so a dot in a
        # directory name (e.g. "my.pkg/Makefile") is not mistaken for one. Both
        # separators are handled since paths may be Windows-style.
        file_name = str(path).replace("\\", "/").rsplit("/", 1)[-1]
        suffix = file_name.rsplit(".", 1)[-1] if "." in file_name else "text"
        content = item.get("content", "")
        parts.append(f"### `{path}`\n\n```{suffix}\n{content}\n```")
    return "\n\n---\n\n".join(parts)
|
|
|
|
# Changed files with at most this many lines, or without previous content, are
# sent in full by _build_file_diff_text; larger files are sent as a unified diff.
_DIFF_LINE_THRESHOLD: int = 200


def _build_file_diff_text(changed_items: list[dict[str, Any]]) -> str:
    """
    Generates unified diffs or full file dumps for changed files in the context.

    Functional Purpose:
        Formats file modifications for the LLM prompt. A file with at most
        _DIFF_LINE_THRESHOLD lines, or without "old_content", is dumped in full in a
        fence tagged with its extension. Otherwise a unified diff against "old_content"
        is emitted, or a "(no changes detected)" heading when the diff is empty.

    Parameters & Inputs:
        changed_items (list[dict[str, Any]]): File dictionaries with "path" (or "entry"),
            "content" and optionally "old_content".

    Returns:
        str: Markdown sections joined by "---" rules; "" for an empty list.

    Immediate-Mode DAG / Thread Context:
        Called by: _send_gemini
        Calls: difflib.unified_diff

    SSDL:
        `o-> [B:small or no old_content?] -> [I:full dump] | [I:difflib.unified_diff] -> [T:markdown]`

    Thread Boundaries:
        Runs synchronously in the caller thread.
    """
    if not changed_items:
        return ""
    parts: list[str] = []
    for item in changed_items:
        path = item.get("path") or item.get("entry", "unknown")
        content = cast(str, item.get("content", ""))
        old_content = cast(str, item.get("old_content", ""))
        new_lines = content.splitlines()
        if len(new_lines) <= _DIFF_LINE_THRESHOLD or not old_content:
            # Extension from the final path component only (dots in directory names are ignored).
            file_name = str(path).replace("\\", "/").rsplit("/", 1)[-1]
            suffix = file_name.rsplit(".", 1)[-1] if "." in file_name else "text"
            parts.append(f"### `{path}` (full)\n\n```{suffix}\n{content}\n```")
        else:
            # Lines without terminators + lineterm="" -> every diff line is newline-free,
            # so joining with "\n" gives exactly one line per entry. (keepends=True here
            # would double every newline in the diff body.)
            diff = difflib.unified_diff(old_content.splitlines(), new_lines, fromfile=str(path), tofile=str(path), lineterm="")
            diff_text = "\n".join(diff)
            if diff_text: parts.append(f"### `{path}` (diff)\n\n```diff\n{diff_text}\n```")
            else: parts.append(f"### `{path}` (no changes detected)")
    return "\n\n---\n\n".join(parts)
|
|
|
|
def _build_deepseek_tools() -> list[dict[str, Any]]:
    """Build the OpenAI-style ``tools`` payload used for DeepSeek requests.

    Collects every MCP tool schema plus the PowerShell tool (TOOL_NAME),
    skipping tools disabled in ``_agent_tools`` (tools default to enabled).
    When a tool preset is active, ``_BIAS_ENGINE.apply_semantic_nudges`` is
    applied to the raw definitions. Each definition is then wrapped as
    ``{"type": "function", "function": {...}}``.
    """
    def _is_enabled(tool_name: str) -> bool:
        return _agent_tools.get(tool_name, True)

    raw_tools: list[dict[str, Any]] = [
        {"name": spec["name"], "description": spec["description"], "parameters": spec["parameters"]}
        for spec in mcp_client.get_tool_schemas()
        if _is_enabled(spec["name"])
    ]

    if _is_enabled(TOOL_NAME):
        powershell_description = (
            "Run a PowerShell script within the project base_dir. "
            "Use this to create, edit, rename, or delete files and directories. "
            "The working directory is set to base_dir automatically. "
            "Always prefer targeted edits over full rewrites where possible. "
            "stdout and stderr are returned to you as the result."
        )
        powershell_parameters = {
            "type": "object",
            "properties": {
                "script": {
                    "type": "string",
                    "description": "The PowerShell script to execute.",
                },
            },
            "required": ["script"],
        }
        raw_tools.append({
            "name": TOOL_NAME,
            "description": powershell_description,
            "parameters": powershell_parameters,
        })

    if _active_tool_preset:
        _BIAS_ENGINE.apply_semantic_nudges(raw_tools, _active_tool_preset)

    return [
        {
            "type": "function",
            "function": {
                "name": definition["name"],
                "description": definition["description"],
                "parameters": definition["parameters"],
            },
        }
        for definition in raw_tools
    ]
|
|
|
|
# DeepSeek tool list, built lazily on the first _get_deepseek_tools() call.
_CACHED_DEEPSEEK_TOOLS: Optional[list[dict[str, Any]]] = None


def _get_deepseek_tools() -> list[dict[str, Any]]:
    """Return the DeepSeek tool list, building and caching it on first use.

    NOTE(review): the list is built once from the _agent_tools /
    _active_tool_preset state at that moment; nothing in this function rebuilds
    it when that state changes — confirm something resets _CACHED_DEEPSEEK_TOOLS.
    """
    global _CACHED_DEEPSEEK_TOOLS
    cached = _CACHED_DEEPSEEK_TOOLS
    if cached is None:
        cached = _CACHED_DEEPSEEK_TOOLS = _build_deepseek_tools()
    return cached
|
|
|
|
def _content_block_to_dict(block: Any) -> dict[str, Any]:
    """Convert a response content block into a plain dict.

    Tries, in order:
    - a dict is returned unchanged;
    - ``model_dump()``;
    - ``to_dict()``;
    - a manual conversion of ``text`` and ``tool_use`` blocks.

    Anything else becomes a text block holding ``str(block)``.
    """
    if isinstance(block, dict):
        return block
    for converter_name in ("model_dump", "to_dict"):
        if hasattr(block, converter_name):
            return cast(dict[str, Any], getattr(block, converter_name)())
    kind = getattr(block, "type", None)
    if kind == "text":
        return {"type": "text", "text": block.text}
    if kind == "tool_use":
        return {
            "type": "tool_use",
            "id": getattr(block, "id"),
            "name": getattr(block, "name"),
            "input": getattr(block, "input"),
        }
    return {"type": "text", "text": str(block)}
|
|
|
|
#endregion: File Context Building
|
|
|
|
#region: Token Estimation
|
|
|
|
# Rough characters-per-token ratio used by the token estimators
# (estimated tokens = character count / _CHARS_PER_TOKEN).
_CHARS_PER_TOKEN: float = 3.5
# Estimated-token budget for an Anthropic prompt: _trim_anthropic_history drops
# old messages while the estimate exceeds this value.
_ANTHROPIC_MAX_PROMPT_TOKENS: int = 180_000
# NOTE(review): not referenced in this section; presumably the Gemini input-token
# budget — confirm at its use site.
_GEMINI_MAX_INPUT_TOKENS: int = 900_000
# Prefix that identifies file-refresh text blocks in user messages
# (_strip_stale_file_refreshes removes them from all but the latest user message).
# Uses the configured _project_context_marker unless it is blank.
_FILE_REFRESH_MARKER: str = _project_context_marker if _project_context_marker.strip() else "[SYSTEM: FILES UPDATED]"
|
|
|
|
def _estimate_message_tokens(msg: dict[str, Any]) -> int:
    """Estimate, and memoize, the token count of a single chat message.

    Characters are counted from:
    - string content;
    - for list content: each dict block's "text" (falling back to "content")
      when it is a string, plus the JSON-serialized "input" dict (as carried by
      tool_use blocks), plus bare string blocks.

    The total is divided by _CHARS_PER_TOKEN, floored at 1, and cached on the
    message under "_est_tokens". Call _invalidate_token_estimate after editing
    a message.
    """
    memo = msg.get("_est_tokens")
    if memo is not None:
        return cast(int, memo)

    def _block_chars(part: Any) -> int:
        if isinstance(part, str):
            return len(part)
        if not isinstance(part, dict):
            return 0
        count = 0
        body = part.get("text", "") or part.get("content", "")
        if isinstance(body, str):
            count += len(body)
        tool_input = part.get("input")
        if isinstance(tool_input, dict):
            count += len(json.dumps(tool_input, ensure_ascii=False))
        return count

    content = msg.get("content", "")
    if isinstance(content, str):
        char_count = len(content)
    elif isinstance(content, list):
        char_count = sum(_block_chars(part) for part in content)
    else:
        char_count = 0
    estimate = max(1, int(char_count / _CHARS_PER_TOKEN))
    msg["_est_tokens"] = estimate
    return estimate
|
|
|
|
def _invalidate_token_estimate(msg: dict[str, Any]) -> None:
    """Forget the memoized "_est_tokens" value on *msg*, if any (in place)."""
    if "_est_tokens" in msg:
        del msg["_est_tokens"]
|
|
|
|
def _estimate_prompt_tokens(system_blocks: list[dict[str, Any]], history: list[dict[str, Any]]) -> int:
    """Estimate the total prompt size in tokens.

    Each system block contributes len(text) / _CHARS_PER_TOKEN (at least 1).
    A fixed 2500 tokens is added (NOTE(review): presumably overhead for tool
    schemas / request framing — confirm). Each history message contributes its
    _estimate_message_tokens value.
    """
    system_tokens = sum(
        max(1, int(len(cast(str, blk.get("text", ""))) / _CHARS_PER_TOKEN))
        for blk in system_blocks
    )
    history_tokens = sum(_estimate_message_tokens(message) for message in history)
    return system_tokens + 2500 + history_tokens
|
|
|
|
def _strip_stale_file_refreshes(history: list[dict[str, Any]]) -> None:
    """Remove outdated file-refresh blocks from *history* (in place).

    In every user message except the most recent one, text blocks whose text
    starts with _FILE_REFRESH_MARKER are dropped. Messages whose content is not
    a list are left alone. The cached token estimate of each modified message
    is invalidated. Histories with fewer than two messages are not touched.
    """
    if len(history) < 2:
        return
    last_user_idx = next(
        (idx for idx in range(len(history) - 1, -1, -1) if history[idx].get("role") == "user"),
        -1,
    )
    for idx, message in enumerate(history):
        if idx == last_user_idx or message.get("role") != "user":
            continue
        blocks = message.get("content")
        if not isinstance(blocks, list):
            continue
        kept = [
            blk for blk in blocks
            if not (
                isinstance(blk, dict)
                and blk.get("type") == "text"
                and cast(str, blk.get("text", "")).startswith(_FILE_REFRESH_MARKER)
            )
        ]
        if len(kept) < len(blocks):
            message["content"] = kept
            _invalidate_token_estimate(message)
|
|
|
|
def _chunk_text(text: str, chunk_size: int) -> list[str]:
    """Split *text* into consecutive pieces of at most *chunk_size* characters.

    The last piece may be shorter; an empty string yields an empty list.

    Raises:
        ValueError: if chunk_size is not positive. Without this check, a zero
            size fails inside range() with an unrelated message, and a negative
            size silently yields [] and drops the whole text.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
|
|
|
|
def _build_chunked_context_blocks(md_content: str) -> list[dict[str, Any]]:
    """Split *md_content* into text blocks of at most _ANTHROPIC_CHUNK_SIZE chars.

    Only the final block carries an ephemeral ``cache_control`` marker. An
    empty input yields an empty list.
    """
    blocks: list[dict[str, Any]] = [
        {"type": "text", "text": piece}
        for piece in _chunk_text(md_content, _ANTHROPIC_CHUNK_SIZE)
    ]
    if blocks:
        blocks[-1]["cache_control"] = {"type": "ephemeral"}
    return blocks
|
|
|
|
def _strip_cache_controls(history: list[dict[str, Any]]) -> None:
    """Remove ``cache_control`` from every dict block of list-content messages (in place)."""
    dict_blocks = (
        blk
        for message in history
        if isinstance(message.get("content"), list)
        for blk in message["content"]
        if isinstance(blk, dict)
    )
    for blk in dict_blocks:
        blk.pop("cache_control", None)
|
|
|
|
def _add_history_cache_breakpoint(history: list[dict[str, Any]]) -> None:
    """Put an ephemeral ``cache_control`` marker on the second-to-last user message (in place).

    With fewer than two user messages nothing happens. List content gets the
    marker on its last block, only if that block is a dict. String content is
    replaced by a single text block carrying the marker. Empty lists and other
    content shapes are left unchanged.
    """
    user_positions = [pos for pos, message in enumerate(history) if message.get("role") == "user"]
    if len(user_positions) < 2:
        return
    target = history[user_positions[-2]]
    content = target.get("content")
    if isinstance(content, str):
        target["content"] = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
    elif isinstance(content, list) and content and isinstance(content[-1], dict):
        content[-1]["cache_control"] = {"type": "ephemeral"}
|
|
|
|
#endregion: Token Estimation
|
|
|
|
#region: Anthropic Provider
|
|
|
|
def _list_anthropic_models_result() -> Result[list[str]]:
    """Fetch the sorted IDs of the models visible to the configured Anthropic key.

    Returns:
        Result(data=<sorted model ids>) on success, or Result(data=[],
        errors=[<ErrorInfo from _classify_anthropic_error>]) if loading the SDK,
        reading the credentials, or listing the models raises.

    History: an earlier version did ``raise _classify_anthropic_error(exc) from exc``,
    which raised an ErrorInfo as if it were an Exception (a runtime bug). This
    Result[T] form follows the Phase 9 redo precedent.
    """
    try:
        sdk = _require_warmed("anthropic")
        api_key = _load_credentials()["anthropic"]["api_key"]
        listing = sdk.Anthropic(api_key=api_key).models.list()
        return Result(data=sorted(entry.id for entry in listing))
    except Exception as exc:
        # Boundary catch: every failure is classified into a structured ErrorInfo.
        return Result(
            data=[],
            errors=[_classify_anthropic_error(exc, source="ai_client._list_anthropic_models_result")],
        )
|
|
|
|
|
|
def _list_anthropic_models() -> list[str]:
    """Return the sorted Anthropic model IDs, or [] if listing fails.

    Legacy list-returning wrapper over _list_anthropic_models_result. Failures
    are not silently dropped: each ErrorInfo is written to stderr as an
    operator-visible warning (the same drain used by _reread_file_items) before
    the empty-list fallback is returned.
    """
    result = _list_anthropic_models_result()
    for err in result.errors or ():
        sys.stderr.write(f"[AI_CLIENT] {err.ui_message()}\n")
    if result.errors:
        sys.stderr.flush()
    return result.data
|
|
|
|
def _ensure_anthropic_client() -> None:
    """Create the module-wide ``_anthropic_client`` if it does not exist yet.

    The warmed ``anthropic`` SDK is looked up on every call, even when a client
    already exists. A new client is only built while ``_anthropic_client`` is
    None. It uses the API key from _load_credentials() and the prompt-caching
    beta header.
    """
    global _anthropic_client
    sdk = _require_warmed("anthropic")
    if _anthropic_client is not None:
        return
    api_key = _load_credentials()["anthropic"]["api_key"]
    _anthropic_client = sdk.Anthropic(
        api_key=api_key,
        default_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
    )
|
|
|
|
def _trim_anthropic_history(system_blocks: list[dict[str, Any]], history: list[dict[str, Any]]) -> int:
    """Shrink *history* in place until the estimated prompt fits the Anthropic budget.

    First strips stale file-refresh blocks. Then, while the running estimate
    exceeds _ANTHROPIC_MAX_PROMPT_TOKENS and more than three messages remain,
    it removes messages starting at index 1. history[0] is never removed here.

    Returns:
        The number of messages removed (0 when the prompt already fits). Blocks
        removed by _strip_stale_file_refreshes are not counted.
    """
    _strip_stale_file_refreshes(history)
    est = _estimate_prompt_tokens(system_blocks, history)
    if est <= _ANTHROPIC_MAX_PROMPT_TOKENS: return 0
    dropped = 0
    while len(history) > 3 and est > _ANTHROPIC_MAX_PROMPT_TOKENS:
        # Preferred case: drop the oldest assistant -> user pair right after history[0].
        # (len(history) > 2 is already implied by the loop guard.)
        if history[1].get("role") == "assistant" and len(history) > 2 and history[2].get("role") == "user":
            removed_asst = history.pop(1)
            removed_user = history.pop(1)
            dropped += 2
            # The running estimate is reduced by per-message estimates rather than recomputed.
            est -= _estimate_message_tokens(removed_asst)
            est -= _estimate_message_tokens(removed_user)
            # Keep dropping assistant -> user pairs whose user side starts with a tool_result block.
            # NOTE(review): presumably so the kept history does not resume in the middle of a
            # tool-use exchange — confirm intent. This inner loop ignores the token budget and
            # the len > 3 floor (it only requires more than two messages).
            while len(history) > 2 and history[1].get("role") == "assistant" and history[2].get("role") == "user":
                content = history[2].get("content", [])
                if isinstance(content, list) and content and isinstance(content[0], dict) and content[0].get("type") == "tool_result":
                    r_a = history.pop(1)
                    r_u = history.pop(1)
                    dropped += 2
                    est -= _estimate_message_tokens(r_a)
                    est -= _estimate_message_tokens(r_u)
                else:
                    break
        else:
            # history[1] does not start an assistant -> user pair: drop it on its own.
            removed = history.pop(1)
            dropped += 1
            est -= _estimate_message_tokens(removed)
    return dropped
|
|
|
|
def _repair_anthropic_history(history: list[dict[str, Any]]) -> None:
    """Close out dangling tool calls at the end of *history* (in place).

    If the last message is from the assistant and contains ``tool_use`` blocks,
    a user message is appended with one ``tool_result`` per tool_use id,
    stating that the call was not completed. Otherwise nothing changes.
    """
    if not history or history[-1].get("role") != "assistant":
        return
    pending_ids = [
        cast(str, blk["id"])
        for blk in history[-1].get("content", [])
        if isinstance(blk, dict) and blk.get("type") == "tool_use"
    ]
    if not pending_ids:
        return
    history.append({
        "role": "user",
        "content": [
            {
                "type": "tool_result",
                "tool_use_id": tool_id,
                "content": "Tool call was not completed (session interrupted).",
            }
            for tool_id in pending_ids
        ],
    })
|
|
|
|
def _send_anthropic(
|
|
md_content: str,
|
|
user_message: str,
|
|
base_dir: str,
|
|
file_items: list[dict[str, Any]] | None = None,
|
|
discussion_history: str = "",
|
|
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
|
|
qa_callback: Optional[Callable[[str], str]] = None,
|
|
stream_callback: Optional[Callable[[str], None]] = None,
|
|
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
|
|
) -> Result[str]:
|
|
"""
|
|
Functional Purpose:
|
|
Sends requests to Anthropic models, managing conversation history, prompt caching, token limits, and executing tool loops.
|
|
Parameters & Inputs:
|
|
- md_content, user_message, base_dir, file_items, discussion_history: Context and input parameters.
|
|
- pre_tool_callback, qa_callback, stream_callback, patch_callback: Execution control callbacks.
|
|
Immediate-Mode DAG / Thread Context:
|
|
- Called by: send
|
|
- Calls: _ensure_anthropic_client, _trim_anthropic_history, client.messages.create, run_with_tool_loop
|
|
SSDL:
|
|
[I:_ensure_anthropic_client] -> [I:_trim_anthropic_history] -> [I:client.messages.create] -> [T:Result]
|
|
Thread Boundaries:
|
|
Runs on whichever thread calls send (typically an async worker thread).
|
|
"""
|
|
anthropic = _require_warmed("anthropic")
|
|
genai = _require_warmed("google.genai")
|
|
types = genai.types
|
|
monitor = performance_monitor.get_monitor()
|
|
if monitor.enabled: monitor.start_component("ai_client._send_anthropic")
|
|
try:
|
|
_ensure_anthropic_client()
|
|
mcp_client.configure(file_items or [], [base_dir])
|
|
stable_prompt = _get_combined_system_prompt()
|
|
stable_blocks: list[dict[str, Any]] = [{"type": "text", "text": stable_prompt, "cache_control": {"type": "ephemeral"}}]
|
|
context_text = f"\n\n<context>\n{md_content}\n</context>"
|
|
context_blocks = _build_chunked_context_blocks(context_text)
|
|
system_blocks = stable_blocks + context_blocks
|
|
if discussion_history and not _anthropic_history:
|
|
user_content: list[dict[str, Any]] = [{"type": "text", "text": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"}]
|
|
else:
|
|
user_content = [{"type": "text", "text": user_message}]
|
|
for msg in _anthropic_history:
|
|
if msg.get("role") == "user" and isinstance(msg.get("content"), list):
|
|
modified = False
|
|
for block in cast(List[dict[str, Any]], msg["content"]):
|
|
if isinstance(block, dict) and block.get("type") == "tool_result":
|
|
t_content = block.get("content", "")
|
|
if _history_trunc_limit > 0 and isinstance(t_content, str) and len(t_content) > _history_trunc_limit:
|
|
block["content"] = t_content[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS. Original output was too large.]"
|
|
modified = True
|
|
if modified: _invalidate_token_estimate(msg)
|
|
_strip_cache_controls(_anthropic_history)
|
|
_repair_anthropic_history(_anthropic_history)
|
|
_anthropic_history.append({"role": "user", "content": user_content})
|
|
_add_history_cache_breakpoint(_anthropic_history)
|
|
all_text_parts: list[str] = []
|
|
_cumulative_tool_bytes = 0
|
|
|
|
def _strip_private_keys(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
return [{k: v for k, v in m.items() if not k.startswith("_")} for m in history]
|
|
|
|
for round_idx in range(MAX_TOOL_ROUNDS + 2):
|
|
response: Any = None
|
|
dropped = _trim_anthropic_history(system_blocks, _anthropic_history)
|
|
if dropped > 0:
|
|
est_tokens = _estimate_prompt_tokens(system_blocks, _anthropic_history)
|
|
_append_comms("OUT", "request", {
|
|
"message": (
|
|
f"[HISTORY TRIMMED: dropped {dropped} old messages to fit token budget. "
|
|
f"Estimated {est_tokens} tokens remaining. {len(_anthropic_history)} messages in history.]"
|
|
),
|
|
})
|
|
|
|
events.emit("request_start", payload={"provider": "anthropic", "model": _model, "round": round_idx})
|
|
assert _anthropic_client is not None
|
|
if stream_callback:
|
|
with _anthropic_client.messages.stream(
|
|
model = _model,
|
|
max_tokens = _max_tokens,
|
|
temperature = _temperature,
|
|
top_p = _top_p,
|
|
system = cast(Iterable[anthropic.types.TextBlockParam], system_blocks),
|
|
tools = cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()),
|
|
messages = cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)),
|
|
) as stream:
|
|
for event in stream:
|
|
if isinstance(event, anthropic.types.ContentBlockDeltaEvent) and event.delta.type == "text_delta":
|
|
stream_callback(event.delta.text)
|
|
response = stream.get_final_message()
|
|
else:
|
|
response = _anthropic_client.messages.create(
|
|
model = _model,
|
|
max_tokens = _max_tokens,
|
|
temperature = _temperature,
|
|
top_p = _top_p,
|
|
system = cast(Iterable[anthropic.types.TextBlockParam], system_blocks),
|
|
tools = cast(Iterable[anthropic.types.ToolParam], _get_anthropic_tools()),
|
|
messages = cast(Iterable[anthropic.types.MessageParam], _strip_private_keys(_anthropic_history)),
|
|
)
|
|
serialised_content = [_content_block_to_dict(b) for b in response.content]
|
|
_anthropic_history.append({
|
|
"role": "assistant",
|
|
"content": serialised_content,
|
|
})
|
|
text_blocks = [b.text for b in response.content if hasattr(b, "text") and b.text]
|
|
if text_blocks:
|
|
all_text_parts.append("\n".join(text_blocks))
|
|
tool_use_blocks = [
|
|
{"id": getattr(b, "id"), "name": getattr(b, "name"), "input": getattr(b, "input")}
|
|
for b in response.content
|
|
if getattr(b, "type", None) == "tool_use"
|
|
]
|
|
usage_dict: dict[str, Any] = {}
|
|
if response.usage:
|
|
usage_dict["input_tokens"] = response.usage.input_tokens
|
|
usage_dict["output_tokens"] = response.usage.output_tokens
|
|
cache_creation = getattr(response.usage, "cache_creation_input_tokens", None)
|
|
cache_read = getattr(response.usage, "cache_read_input_tokens", None)
|
|
if cache_creation is not None: usage_dict["cache_creation_input_tokens"] = cache_creation
|
|
if cache_read is not None: usage_dict["cache_read_input_tokens"] = cache_read
|
|
events.emit("response_received", payload={"provider": "anthropic", "model": _model, "usage": usage_dict, "round": round_idx})
|
|
_append_comms("IN", "response", {
|
|
"round": round_idx,
|
|
"stop_reason": response.stop_reason,
|
|
"text": "\n".join(text_blocks),
|
|
"tool_calls": tool_use_blocks,
|
|
"usage": usage_dict,
|
|
})
|
|
if response.stop_reason != "tool_use" or not tool_use_blocks: break
|
|
if round_idx > MAX_TOOL_ROUNDS: break
|
|
|
|
# Execute tools concurrently
|
|
try:
|
|
loop = asyncio.get_running_loop()
|
|
results = asyncio.run_coroutine_threadsafe(
|
|
_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback),
|
|
loop
|
|
).result()
|
|
except RuntimeError:
|
|
results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback))
|
|
|
|
tool_results: list[dict[str, Any]] = []
|
|
for i, (name, call_id, out, _) in enumerate(results):
|
|
truncated = _truncate_tool_output(out)
|
|
_cumulative_tool_bytes += len(truncated)
|
|
tool_results.append({
|
|
"type": "tool_result",
|
|
"tool_use_id": call_id,
|
|
"content": truncated,
|
|
})
|
|
_append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
|
|
events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})
|
|
|
|
if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
|
|
tool_results.append({
|
|
"type": "text",
|
|
"text": f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
|
|
})
|
|
_append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
|
|
if file_items:
|
|
file_items, changed = _reread_file_items(file_items)
|
|
refreshed_ctx = _build_file_diff_text(changed)
|
|
if refreshed_ctx:
|
|
tool_results.append({
|
|
"type": "text",
|
|
"text": (
|
|
f"{_get_context_marker()}\n\n"
|
|
+ refreshed_ctx
|
|
),
|
|
})
|
|
if round_idx == MAX_TOOL_ROUNDS:
|
|
tool_results.append({
|
|
"type": "text",
|
|
"text": "SYSTEM WARNING: MAX TOOL ROUNDS REACHED. YOU MUST PROVIDE YOUR FINAL ANSWER NOW WITHOUT CALLING ANY MORE TOOLS."
|
|
})
|
|
_anthropic_history.append({
|
|
"role": "user",
|
|
"content": tool_results,
|
|
})
|
|
_append_comms("OUT", "tool_result_send", {
|
|
"results": [
|
|
{"tool_use_id": r["tool_use_id"], "content": r["content"]}
|
|
for r in tool_results if r.get("type") == "tool_result"
|
|
],
|
|
})
|
|
final_text = "\n\n".join(all_text_parts)
|
|
res = final_text if final_text.strip() else "(No text returned by the model)"
|
|
if monitor.enabled: monitor.end_component("ai_client._send_anthropic")
|
|
return Result(data=res)
|
|
except Exception as exc:
|
|
if monitor.enabled: monitor.end_component("ai_client._send_anthropic")
|
|
return Result(data="", errors=[_classify_anthropic_error(exc, source="ai_client.anthropic")])
|
|
|
|
#endregion: Anthropic Provider
|
|
|
|
#region: Gemini Provider
|
|
|
|
def get_gemini_cache_stats() -> dict[str, Any]:
    """Summarise the Gemini context caches listed by the configured client.

    Returns a dict with:
      * ``cache_count``      -- number of entries returned by ``client.caches.list()``;
      * ``total_size_bytes`` -- sum of each entry's ``size_bytes`` attribute, where a
        missing attribute or a ``None`` value counts as 0;
      * ``cached_files``     -- a copy of the file paths recorded when the current
        cache was created.

    When no client is available all counts are zero and the file list is empty.
    """
    _ensure_gemini_client()
    if not _gemini_client:
        return {"cache_count": 0, "total_size_bytes": 0, "cached_files": []}
    caches = list(_gemini_client.caches.list())
    # getattr's default only covers a *missing* attribute; `or 0` also covers an
    # attribute that is present but None, which would make sum() raise TypeError.
    total_size_bytes = sum(getattr(c, "size_bytes", 0) or 0 for c in caches)
    return {
        "cache_count": len(caches),
        "total_size_bytes": total_size_bytes,
        # Copy so callers cannot mutate the module-level list through the result.
        "cached_files": list(_gemini_cached_file_paths),
    }
|
|
|
|
def _list_gemini_cli_models() -> list[str]:
    """Return the hard-coded catalogue of model ids offered for the Gemini CLI provider.

    No network call is made. A fresh list is built on every call, so callers may
    mutate the result without affecting later calls.
    """
    catalogue = (
        "gemini-3-flash-preview",
        "gemini-3.1-pro-preview",
        "gemini-2.5-pro",
        "gemini-2.5-flash",
        "gemini-2.0-flash",
        "gemini-2.5-flash-lite",
    )
    return list(catalogue)
|
|
|
|
def _list_gemini_models_result(api_key: str) -> Result[list[str]]:
    """List the Gemini models visible to *api_key* through the google-genai SDK.

    Each model name has a leading ``models/`` prefix removed. Only names containing
    "gemini" (case-insensitive) are kept, and they are returned sorted as
    Result(data=[...]).

    Any SDK or network exception is converted by ``_classify_gemini_error`` into
    Result(data=[], errors=[ErrorInfo]). ``_list_gemini_models`` returns only
    ``.data``; callers that need the failure reason should call this helper directly.
    """
    try:
        genai_mod = _require_warmed("google.genai")
        sdk_client = genai_mod.Client(api_key=api_key)
        found: list[str] = []
        for model in sdk_client.models.list():
            short_name = model.name
            if short_name:
                short_name = short_name.removeprefix("models/")
            if short_name and "gemini" in short_name.lower():
                found.append(short_name)
        return Result(data=sorted(found))
    except Exception as exc:
        classified = _classify_gemini_error(exc, source="ai_client._list_gemini_models_result")
        return Result(data=[], errors=[classified])
|
|
|
|
|
|
def _list_gemini_models(api_key: str) -> list[str]:
    """Return the Gemini model names for *api_key*, or ``[]`` if listing fails.

    This is a data-only wrapper over ``_list_gemini_models_result``; any error
    information is discarded.
    """
    listing = _list_gemini_models_result(api_key)
    return listing.data
|
|
|
|
def _ensure_gemini_client() -> None:
    """Lazily create the module-wide google-genai client.

    ``google.genai`` is fetched via ``_require_warmed`` on every call, even when a
    client already exists. The first call builds ``_gemini_client`` from
    ``_load_credentials()["gemini"]["api_key"]``; later calls leave it untouched.
    """
    global _gemini_client
    genai_mod = _require_warmed("google.genai")
    if _gemini_client is not None:
        return
    credentials = _load_credentials()
    _gemini_client = genai_mod.Client(api_key=credentials["gemini"]["api_key"])
|
|
|
|
def _delete_gemini_cache_result() -> Result[None]:
    """Best-effort deletion of the active Gemini context cache.

    Succeeds as a no-op when there is no cache or no client. If the SDK call
    fails, a ``[CACHE DELETE WARN]`` line is written to comms and the failure is
    reported through ``errors`` rather than raised.

    This function does not reset the module-level cache globals; its callers in
    ``_send_gemini`` do that themselves.
    """
    if _gemini_cache is not None and _gemini_client is not None:
        try:
            _gemini_client.caches.delete(name=_gemini_cache.name)
        except Exception as e:
            _append_comms("OUT", "request", {"message": f"[CACHE DELETE WARN] {e}"})
            failure = ErrorInfo(
                kind=ErrorKind.INTERNAL,
                message=f"failed to delete gemini cache: {e}",
                source="ai_client._delete_gemini_cache_result",
                original=e,
            )
            return Result(data=None, errors=[failure])
    return Result(data=None)
|
|
|
|
# Minimum token count of the Gemini system instruction (system prompt + <context>)
# at which _should_cache_gemini_result opts into explicit context caching; smaller
# instructions are sent inline instead.
_GEMINI_CACHE_TOKEN_THRESHOLD: int = 2_048
|
|
|
|
def _should_cache_gemini_result(sys_instr: str) -> Result[bool]:
    """Decide whether *sys_instr* is large enough to justify a Gemini context cache.

    Tokens are counted with the SDK for the current model. The outcomes are:
      * ``Result(data=True)``: the count is >= ``_GEMINI_CACHE_TOKEN_THRESHOLD``.
      * ``Result(data=False)``: there is no client, or the count is below the
        threshold (or falsy). The below-threshold case also writes a
        ``[CACHING SKIPPED]`` comms note.
      * ``Result(data=False, errors=[ErrorInfo])``: counting raised. This also
        writes a ``[COUNT FAILED]`` comms note.

    ``_send_gemini`` only reads ``.data``, so any failure degrades to "do not
    cache" and the inline ``system_instruction`` is used instead.
    """
    if _gemini_client is None:
        return Result(data=False)
    try:
        token_total = _gemini_client.models.count_tokens(model=_model, contents=[sys_instr]).total_tokens
        if not (token_total and token_total >= _GEMINI_CACHE_TOKEN_THRESHOLD):
            _append_comms("OUT", "request", {"message": f"[CACHING SKIPPED] Context too small ({token_total} tokens < {_GEMINI_CACHE_TOKEN_THRESHOLD})"})
            return Result(data=False)
        return Result(data=True)
    except Exception as e:
        _append_comms("OUT", "request", {"message": f"[COUNT FAILED] {e}"})
        failure = ErrorInfo(
            kind=ErrorKind.INTERNAL,
            message=f"failed to count gemini tokens: {e}",
            source="ai_client._should_cache_gemini_result",
            original=e,
        )
        return Result(data=False, errors=[failure])
|
|
|
|
def _create_gemini_cache_result(sys_instr: str, tools_decl: Any, file_items: list[dict[str, Any]] | None) -> Result[Any]:
    """Create a Gemini context cache and the GenerateContentConfig that references it.

    Args:
        sys_instr: System instruction (system prompt + context) stored in the cache.
        tools_decl: Tool declarations stored in the cache (may be None).
        file_items: Items whose ``path`` values are recorded in
            ``_gemini_cached_file_paths`` on success.

    Returns:
        On success, ``Result(data=chat_config)``, where ``chat_config`` points at
        the new cache via ``cached_content``. On any exception,
        ``Result(data=None, errors=[ErrorInfo])``.

    Side effects:
        On success, sets ``_gemini_cache``, ``_gemini_cache_created_at`` and
        ``_gemini_cached_file_paths``. On failure, resets all three. A comms line
        is written in both cases.
    """
    global _gemini_cache, _gemini_cache_created_at, _gemini_cached_file_paths
    types = _require_warmed("google.genai").types
    try:
        _gemini_cache = _gemini_client.caches.create(
            model=_model,
            config=types.CreateCachedContentConfig(
                system_instruction=sys_instr,
                tools=cast(Any, tools_decl),
                ttl=f"{_GEMINI_CACHE_TTL}s",
            ),
        )
        _gemini_cache_created_at = time.time()
        _gemini_cached_file_paths = [str(item.get("path", "")) for item in (file_items or []) if item.get("path")]
        chat_config = types.GenerateContentConfig(
            cached_content=_gemini_cache.name,
            temperature=_temperature,
            # Same sampling settings as the inline (non-cached) chat config in
            # _send_gemini, so enabling the cache does not drop the configured top_p.
            top_p=_top_p,
            max_output_tokens=_max_tokens,
            safety_settings=[types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=types.HarmBlockThreshold.BLOCK_ONLY_HIGH)],
        )
        _append_comms("OUT", "request", {"message": f"[CACHE CREATED] {_gemini_cache.name}"})
        return Result(data=chat_config)
    except Exception as e:
        _gemini_cache = None
        _gemini_cache_created_at = None
        _gemini_cached_file_paths = []
        _append_comms("OUT", "request", {"message": f"[CACHE FAILED] {type(e).__name__}: {e} \u2014 falling back to inline system_instruction"})
        return Result(
            data=None,
            errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to create gemini cache: {type(e).__name__}: {e}", source="ai_client._create_gemini_cache_result", original=e)],
        )
|
|
|
|
def _send_cli_round_result(r_idx: int, adapter: Any, payload: Any, safety_settings: list[Any], sys_instr: str, stream_callback: Optional[Callable[[str], None]]) -> Result[dict[str, Any]]:
    """Run one Gemini CLI round through *adapter* and wrap the reply in a Result.

    Emits ``request_start`` for the round. For rounds after the first, it also
    writes a ``[CLI] [round N] [msg M]`` comms line, where M is ``len(payload)``.
    A list payload (tool results) is JSON-encoded before sending; anything else
    is passed through unchanged.

    Returns:
        On success, ``Result(data=<adapter reply>)``. If ``adapter.send`` raises,
        a ``response_received`` event carrying the error text is emitted and
        ``Result(data=None, errors=[ErrorInfo(INTERNAL, original=exc)])`` is
        returned. The caller (``_send`` in ``_send_gemini_cli``) re-raises
        ``original``.
    """
    events.emit("request_start", payload={"provider": "gemini_cli", "model": _model, "round": r_idx})
    if r_idx > 0:
        _append_comms("OUT", "request", {"message": f"[CLI] [round {r_idx}] [msg {len(payload)}]"})
    wire_payload: Any = payload if not isinstance(payload, list) else json.dumps(payload)
    try:
        reply = adapter.send(
            cast(str, wire_payload),
            safety_settings=safety_settings,
            system_instruction=sys_instr,
            model=_model,
            stream_callback=stream_callback,
        )
        return Result(data=reply)
    except Exception as exc:
        events.emit("response_received", payload={"provider": "gemini_cli", "model": _model, "usage": {}, "latency": 0, "round": r_idx, "error": str(exc)})
        failure = ErrorInfo(kind=ErrorKind.INTERNAL, message=str(exc), source="ai_client._send_cli_round_result", original=exc)
        return Result(data=None, errors=[failure])
|
|
|
|
def _extract_gemini_thoughts_result(resp: Any) -> Result[str]:
    """Collect the thinking text from a Gemini response object.

    Every candidate's content parts are scanned. The text of parts flagged
    ``thought`` (and with non-empty text) is concatenated in order and stripped,
    then returned as ``Result(data=text)``; this is "" when there are no thoughts.
    Attributes are read defensively with ``getattr``. If attribute access still
    raises, ``Result(data="", errors=[ErrorInfo])`` is returned.
    ``_extract_gemini_thoughts`` exposes only ``.data``.
    """
    thought_texts: list[str] = []
    try:
        for candidate in getattr(resp, "candidates", None) or []:
            content = getattr(candidate, "content", None)
            if content is None:
                continue
            thought_texts.extend(
                part.text
                for part in (getattr(content, "parts", None) or [])
                if getattr(part, "thought", False) and getattr(part, "text", None)
            )
        return Result(data="".join(thought_texts).strip())
    except Exception as e:
        failure = ErrorInfo(
            kind=ErrorKind.INTERNAL,
            message=f"failed to extract gemini thoughts: {e}",
            source="ai_client._extract_gemini_thoughts_result",
            original=e,
        )
        return Result(data="", errors=[failure])
|
|
|
|
|
|
def _extract_gemini_thoughts(resp: Any) -> str:
    """
    Return the concatenated thinking text of a Gemini response, or "" if there is none.

    Parts whose ``thought`` flag is set are thinking segments; all other parts are
    visible text. Per the original note, the google-genai SDK leaves thoughts out
    of ``resp.text``, so the parts are scanned directly (see
    ``_extract_gemini_thoughts_result``). Extraction errors also yield "".
    """
    extraction = _extract_gemini_thoughts_result(resp)
    return extraction.data
|
|
|
|
def _get_gemini_history_list(chat: Any | None) -> list[Any]:
|
|
if not chat: return []
|
|
if hasattr(chat, "_history"): return cast(list[Any], chat._history)
|
|
if hasattr(chat, "history"): return cast(list[Any], chat.history)
|
|
if hasattr(chat, "get_history"): return cast(list[Any], chat.get_history())
|
|
return []
|
|
|
|
def _send_gemini(md_content: str, user_message: str, base_dir: str,
                 file_items: list[dict[str, Any]] | None = None,
                 discussion_history: str = "",
                 pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                 qa_callback: Optional[Callable[[str], str]] = None,
                 enable_tools: bool = True,
                 stream_callback: Optional[Callable[[str], None]] = None,
                 patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
                 ) -> Result[str]:
    """
    Functional Purpose: Sends requests to Gemini via the google-genai SDK. Handles
    context caching, the persistent chat session, history trimming and the tool loop.

    Parameters & Inputs:
        md_content: markdown context placed inside <context> in the system instruction.
        user_message: the new user turn.
        base_dir: passed to mcp_client.configure and the tool executor.
        file_items: re-read after each tool round so that file diffs can be reported.
        discussion_history: injected once, into a freshly created chat that has no
            carried-over history.
        callbacks: pre_tool / qa / patch callbacks go to the tool executor;
            stream_callback receives text chunks and enables streaming.
        enable_tools: controls whether tool declarations are sent.

    Returns: Result[str] holding the model text of all rounds, joined by blank lines
    ("(No text returned)" if there is none). It is prefixed with a <thinking> block
    when the last response carried thought parts. Any exception is returned as
    Result(data="", errors=[...]) classified by _classify_gemini_error.

    Immediate-Mode DAG / Thread Context: Called by: send; Calls: _ensure_gemini_client,
    _create_gemini_cache_result, client.chats.create, chat.send_message[_stream],
    _execute_tool_calls_concurrently

    SSDL: [I:_ensure_gemini_client] -> [B:Cache Changed?] -> [I:client.caches.create] -> [I:client.chats.create] -> [T:Result]

    Thread Boundaries: Runs on caller thread (typically an async worker thread).

    Note: in google-genai, chat.send_message(..., config=...) uses the given config
    *instead of* the chat-level config for that request. Each per-round config built
    here therefore carries the context itself (cached_content, or system_instruction
    plus tools) rather than relying on the config passed to chats.create.
    """
    global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at, _gemini_cached_file_paths
    genai = _require_warmed("google.genai")
    types = genai.types
    monitor = performance_monitor.get_monitor()
    if monitor.enabled: monitor.start_component("ai_client._send_gemini")
    try:
        _ensure_gemini_client()
        mcp_client.configure(file_items or [], [base_dir])
        sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
        td = _gemini_tool_declaration() if enable_tools else None
        tools_decl = [td] if td else None
        current_md_hash = hashlib.md5(md_content.encode()).hexdigest()
        old_history = None
        assert _gemini_client is not None
        safety_settings = [types.SafetySetting(category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=types.HarmBlockThreshold.BLOCK_ONLY_HIGH)]

        # Context changed: drop the chat and its cache, keeping the chat history.
        if _gemini_chat and _gemini_cache_md_hash != current_md_hash:
            old_history = list(_get_gemini_history_list(_gemini_chat))
            if _gemini_cache:
                _delete_gemini_cache_result()
            _gemini_chat = None
            _gemini_cache = None
            _gemini_cache_created_at = None
            _gemini_cached_file_paths = []
            _append_comms("OUT", "request", {"message": "[CONTEXT CHANGED] Rebuilding cache and chat session..."})

        # Cache past 90% of its TTL: rebuild it proactively, keeping the chat history.
        if _gemini_chat and _gemini_cache and _gemini_cache_created_at:
            elapsed = time.time() - _gemini_cache_created_at
            if elapsed > _GEMINI_CACHE_TTL * 0.9:
                old_history = list(_get_gemini_history_list(_gemini_chat))
                _delete_gemini_cache_result()
                _gemini_chat = None
                _gemini_cache = None
                _gemini_cache_created_at = None
                _gemini_cached_file_paths = []
                _append_comms("OUT", "request", {"message": f"[CACHE TTL] Rebuilding cache (expired after {int(elapsed)}s)..."})

        if not _gemini_chat:
            if _gemini_cache is not None:
                # A cache with no live chat is orphaned (e.g. chats.create raised after the
                # cache was created). Drop it so it is neither leaked nor referenced by the
                # per-round config of the chat created below.
                _delete_gemini_cache_result()
                _gemini_cache = None
                _gemini_cache_created_at = None
                _gemini_cached_file_paths = []
            chat_config = types.GenerateContentConfig(
                system_instruction=sys_instr,
                tools=cast(Any, tools_decl),
                temperature=_temperature,
                top_p=_top_p,
                max_output_tokens=_max_tokens,
                safety_settings=safety_settings,
            )
            if _should_cache_gemini_result(sys_instr).data:
                cached_config_result = _create_gemini_cache_result(sys_instr, tools_decl, file_items)
                if cached_config_result.ok:
                    chat_config = cached_config_result.data
            kwargs: dict[str, Any] = {"model": _model, "config": chat_config}
            if old_history:
                kwargs["history"] = old_history
            _gemini_chat = _gemini_client.chats.create(**kwargs)
            _gemini_cache_md_hash = current_md_hash
            if discussion_history and not old_history:
                _gemini_chat.send_message(f"[DISCUSSION HISTORY]\n\n{discussion_history}")
                _append_comms("OUT", "request", {"message": f"[HISTORY INJECTED] {len(discussion_history)} chars"})

        payload: str | list[types.Part] = user_message
        all_text: list[str] = []
        _cumulative_tool_bytes = 0

        # Shrink tool outputs already stored in the chat history: drop the file-refresh
        # text appended after the context marker and truncate to _history_trunc_limit.
        for msg in _get_gemini_history_list(_gemini_chat):
            if msg.role != "user":
                continue
            for p in getattr(msg, "parts", None) or []:
                fr = getattr(p, "function_response", None)
                r = getattr(fr, "response", None) if fr else None
                if not isinstance(r, dict) or "output" not in r:
                    continue
                val = r["output"]
                if not isinstance(val, str):
                    continue
                marker = _get_context_marker()
                if marker in val:
                    val = val.split(marker)[0].strip()
                if _history_trunc_limit > 0 and len(val) > _history_trunc_limit:
                    val = val[:_history_trunc_limit] + "\n\n... [TRUNCATED BY SYSTEM TO SAVE TOKENS.]"
                r["output"] = val

        resp: Any = None
        for r_idx in range(MAX_TOOL_ROUNDS + 2):
            events.emit("request_start", payload={"provider": "gemini", "model": _model, "round": r_idx})

            # The round config replaces the chat config for this request, so it must carry
            # the context itself. While the cache is comfortably inside its TTL, reference
            # it: the cached content already holds system_instruction and tools, and the
            # API does not accept those fields alongside cached_content. Otherwise send
            # the system instruction and tools inline.
            cache_fresh = (
                _gemini_cache is not None
                and _gemini_cache_created_at is not None
                and time.time() - _gemini_cache_created_at < _GEMINI_CACHE_TTL * 0.95
            )
            if cache_fresh:
                config = types.GenerateContentConfig(
                    cached_content=_gemini_cache.name,
                    temperature=_temperature,
                    top_p=_top_p,
                    max_output_tokens=_max_tokens,
                    safety_settings=safety_settings,
                )
            else:
                round_td = _gemini_tool_declaration() if enable_tools else None
                config = types.GenerateContentConfig(
                    system_instruction=sys_instr,
                    tools=[round_td] if round_td else [],
                    temperature=_temperature,
                    top_p=_top_p,
                    max_output_tokens=_max_tokens,
                    safety_settings=safety_settings,
                )

            if stream_callback:
                txt_chunks: list[str] = []
                calls: list[Any] = []
                usage: dict[str, Any] = {}
                reason = "STOP"
                final_resp = None
                for chunk in _gemini_chat.send_message_stream(payload, config=config):
                    piece = chunk.text
                    if piece:
                        txt_chunks.append(piece)
                        stream_callback(piece)
                    if chunk.candidates:
                        c = chunk.candidates[0]
                        if c.content and c.content.parts:
                            calls.extend(p.function_call for p in c.content.parts if p.function_call)
                        if getattr(c, "finish_reason", None):
                            reason = c.finish_reason.name
                    if chunk.usage_metadata:
                        usage = {
                            "input_tokens": chunk.usage_metadata.prompt_token_count,
                            "output_tokens": chunk.usage_metadata.candidates_token_count,
                            "total_tokens": chunk.usage_metadata.total_token_count,
                            "cache_read_input_tokens": getattr(chunk.usage_metadata, "cached_content_token_count", 0),
                        }
                    final_resp = chunk
                txt = "".join(txt_chunks)
                if txt: all_text.append(txt)
                # The last chunk stands in for the response object (thought extraction below).
                resp = final_resp
                events.emit("response_received", payload={"provider": "gemini", "model": _model, "usage": usage, "round": r_idx})
            else:
                resp = _gemini_chat.send_message(payload, config=config)
                txt = resp.text or ""
                if txt: all_text.append(txt)
                # candidates / parts can be None (e.g. a blocked prompt); treat that as empty.
                candidates = resp.candidates or []
                calls = [
                    p.function_call
                    for c in candidates if getattr(c, "content", None)
                    for p in (c.content.parts or []) if p.function_call
                ]
                meta = resp.usage_metadata
                usage = {
                    "input_tokens": getattr(meta, "prompt_token_count", 0),
                    "output_tokens": getattr(meta, "candidates_token_count", 0),
                    "total_tokens": getattr(meta, "total_token_count", 0),
                    "cache_read_input_tokens": getattr(meta, "cached_content_token_count", 0),
                }
                finish = getattr(candidates[0], "finish_reason", None) if candidates else None
                reason = finish.name if finish else "STOP"
                events.emit("response_received", payload={"provider": "gemini", "model": _model, "usage": usage, "round": r_idx})

            _append_comms("IN", "response", {"round": r_idx, "stop_reason": reason, "text": txt, "tool_calls": [{"name": c.name, "args": dict(c.args or {})} for c in calls], "usage": usage})

            # Token counts may be None when the SDK omits them; treat that as 0 so the
            # comparisons below cannot raise TypeError.
            total_in = usage.get("input_tokens") or 0
            if total_in > _GEMINI_MAX_INPUT_TOKENS * 0.4:
                hist = _get_gemini_history_list(_gemini_chat)
                dropped = 0
                # Drop the oldest entries two at a time until the estimate is under 30%,
                # always keeping at least the 4 most recent entries.
                while len(hist) > 4 and total_in > _GEMINI_MAX_INPUT_TOKENS * 0.3:
                    saved = 0
                    for _ in range(2):
                        if not hist: break
                        for p in getattr(hist[0], "parts", None) or []:
                            if getattr(p, "text", None):
                                saved += int(len(p.text) / _CHARS_PER_TOKEN)
                            elif getattr(p, "function_response", None):
                                r = getattr(p.function_response, "response", {})
                                if isinstance(r, dict):
                                    saved += int(len(str(r.get("output", ""))) / _CHARS_PER_TOKEN)
                        hist.pop(0)
                        dropped += 1
                    total_in -= max(saved, 200)
                if dropped > 0:
                    _append_comms("OUT", "request", {"message": f"[GEMINI HISTORY TRIMMED: dropped {dropped} old entries to stay within token budget]"})

            if not calls or r_idx > MAX_TOOL_ROUNDS:
                break

            f_resps: list[types.Part] = []
            log: list[dict[str, Any]] = []

            # Execute tools concurrently.
            # NOTE(review): if an event loop were running on *this* thread, blocking on
            # .result() would deadlock. From a worker thread, get_running_loop() raises
            # RuntimeError and asyncio.run is used instead.
            try:
                loop = asyncio.get_running_loop()
                results = asyncio.run_coroutine_threadsafe(
                    _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback),
                    loop,
                ).result()
            except RuntimeError:
                results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback))

            last_idx = len(results) - 1
            for i, (name, call_id, out, _) in enumerate(results):
                # Only the last tool output carries the file-refresh diff and the max-rounds note.
                if i == last_idx:
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            out += f"\n\n{_get_context_marker()}\n\n{ctx}"
                    if r_idx == MAX_TOOL_ROUNDS:
                        out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                out = _truncate_tool_output(out)
                _cumulative_tool_bytes += len(out)
                f_resps.append(types.Part(function_response=types.FunctionResponse(name=cast(str, name), response={"output": out})))
                log.append({"tool_use_id": name, "content": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})

            if _cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                f_resps.append(types.Part(text=
                    f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now."
                ))
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {_cumulative_tool_bytes} bytes]"})
            _append_comms("OUT", "tool_result_send", {"results": log})
            payload = f_resps

        res = "\n\n".join(all_text) if all_text else "(No text returned)"
        # Thoughts come from the last round's response (its final chunk when streaming).
        thought_text = _extract_gemini_thoughts(resp)
        if thought_text:
            res = f"<thinking>\n{thought_text}\n</thinking>\n\n{res}"
        if monitor.enabled: monitor.end_component("ai_client._send_gemini")
        return Result(data=res)
    except Exception as e:
        if monitor.enabled: monitor.end_component("ai_client._send_gemini")
        return Result(data="", errors=[_classify_gemini_error(e, source="ai_client.gemini")])
|
|
|
|
def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
                     file_items: list[dict[str, Any]] | None = None,
                     discussion_history: str = "",
                     pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                     qa_callback: Optional[Callable[[str], str]] = None,
                     stream_callback: Optional[Callable[[str], None]] = None,
                     patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Result[str]:
    """
    [C: src/ai_server.py:_handle_send]
    Functional Purpose: Sends requests to Gemini via the headless Gemini CLI subprocess adapter.
    Parameters & Inputs: md_content, user_message, base_dir, file_items, discussion_history, callbacks.
    Returns: Result[str] whose data is the text of the LAST round that produced text
    ("(No text returned)" when none did). Any exception is returned as
    Result(data="", errors=[ErrorInfo(kind=INTERNAL, ...)]).
    Immediate-Mode DAG / Thread Context: Called by: send; Calls: run_with_tool_loop, GeminiCliAdapter.send
    SSDL:
    [I:run_with_tool_loop] -> [I:GeminiCliAdapter.send] -> [T:Result]
    Thread Boundaries: Runs on caller thread (typically an async worker thread).
    """
    global _gemini_cli_adapter
    # Function-scoped import, placed after the docstring so that the string above is
    # the function's real __doc__.
    from src.openai_compatible import OpenAICompatibleRequest, NormalizedResponse
    try:
        if _gemini_cli_adapter is None:
            _gemini_cli_adapter = GeminiCliAdapter(binary_path="gemini")
        adapter = _gemini_cli_adapter
        mcp_client.configure(file_items or [], [base_dir])
        sys_instr = f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"
        safety_settings = [{'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'threshold': 'BLOCK_ONLY_HIGH'}]
        payload: Union[str, list[dict[str, Any]]] = user_message
        # Discussion history is only prepended when the CLI session is new.
        if adapter.session_id is None and discussion_history:
            payload = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"
        all_text: list[str] = []
        cumulative_tool_bytes = 0

        def _send(r_idx: int) -> NormalizedResponse:
            """Send the current payload for round *r_idx* and normalise the CLI reply."""
            if adapter is None:
                return NormalizedResponse(text="(adapter unavailable)", tool_calls=[], usage_input_tokens=0, usage_output_tokens=0, usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=None)
            send_result = _send_cli_round_result(r_idx, adapter, payload, safety_settings, sys_instr, stream_callback)
            if not send_result.ok:
                # The error comes from a Result, not a local except; `from None` suppresses
                # the implicit context.
                raise cast(Exception, send_result.errors[0].original) from None
            resp_data = send_result.data
            cli_stderr = resp_data.get("stderr", "")
            if cli_stderr:
                sys.stderr.write(f"\n--- Gemini CLI stderr ---\n{cli_stderr}\n-------------------------\n")
                sys.stderr.flush()
            txt = cast(str, resp_data.get("text", ""))
            if txt: all_text.append(txt)
            calls = cast(List[dict[str, Any]], resp_data.get("tool_calls", []))
            usage = adapter.last_usage or {}
            latency = adapter.last_latency
            events.emit("response_received", payload={"provider": "gemini_cli", "model": _model, "usage": usage, "latency": latency, "round": r_idx})
            log_calls: list[dict[str, Any]] = [{"name": c.get("name"), "args": c.get("args"), "id": c.get("id")} for c in calls]
            _append_comms("IN", "response", {
                "round": r_idx,
                "stop_reason": "TOOL_USE" if calls else "STOP",
                "text": txt,
                "tool_calls": log_calls,
                "usage": usage
            })
            # Text that accompanies tool calls is also pushed to the discussion via the comms callback.
            if txt and calls:
                cb = get_comms_log_callback()
                if cb:
                    cb({
                        "ts": project_manager.now_ts(),
                        "direction": "IN",
                        "kind": "history_add",
                        "payload": {"role": "AI", "content": txt}
                    })
            return NormalizedResponse(text=txt, tool_calls=calls, usage_input_tokens=usage.get("prompt_tokens", 0), usage_output_tokens=usage.get("completion_tokens", 0), usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=resp_data)

        def _pre_dispatch(r_idx: int, calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
            """Run the requested tools and set `payload` to their results for the next round."""
            nonlocal payload, cumulative_tool_bytes, file_items
            tool_results_for_cli: list[dict[str, Any]] = []
            results_iter: list[tuple[str, str, str, str]] = []
            exec_args = (calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback)
            # NOTE(review): blocking on .result() would deadlock if a loop ran on this very
            # thread. From a worker thread, get_running_loop() raises RuntimeError and
            # asyncio.run is used instead.
            try:
                loop = asyncio.get_running_loop()
                results_iter = asyncio.run_coroutine_threadsafe(_execute_tool_calls_concurrently(*exec_args), loop).result()
            except RuntimeError:
                results_iter = asyncio.run(_execute_tool_calls_concurrently(*exec_args))
            last_idx = len(results_iter) - 1
            for i, (name, call_id, out, _) in enumerate(results_iter):
                # Only the last tool output carries the file-refresh diff and the max-rounds note.
                if i == last_idx:
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            out += f"\n\n{_get_context_marker()}\n\n{ctx}"
                    if r_idx == MAX_TOOL_ROUNDS:
                        out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"
                out = _truncate_tool_output(out)
                cumulative_tool_bytes += len(out)
                tool_results_for_cli.append({"role": "tool", "tool_call_id": call_id, "name": name, "content": out})
                _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": r_idx})
            payload = tool_results_for_cli
            # NOTE(review): unlike the SDK providers, exceeding the budget is only logged;
            # no warning is sent back to the model.
            if cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {cumulative_tool_bytes} bytes]"})
            return calls

        run_with_tool_loop(
            client=adapter, request=lambda _i: cast(OpenAICompatibleRequest, None),
            base_dir=base_dir, vendor_name="gemini_cli",
            pre_tool_callback=pre_tool_callback, qa_callback=qa_callback,
            stream_callback=stream_callback, patch_callback=patch_callback,
            send_func=_send, on_pre_dispatch=_pre_dispatch,
        )
        final_text = all_text[-1] if all_text else "(No text returned)"
        return Result(data=final_text)
    except Exception as e:
        return Result(data="", errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=str(e), source="ai_client.gemini_cli", original=e)])
|
|
|
|
#endregion: Gemini Provider
|
|
|
|
#region: DeepSeek Provider
|
|
|
|
def _list_deepseek_models(api_key: str) -> list[str]:
|
|
return ["deepseek-chat", "deepseek-reasoner"]
|
|
|
|
def _repair_deepseek_history(history: list[dict[str, Any]]) -> None:
|
|
if not history:
|
|
return
|
|
last = history[-1]
|
|
if last.get("role") != "assistant":
|
|
return
|
|
tool_calls = last.get("tool_calls", [])
|
|
if not tool_calls:
|
|
return
|
|
call_ids = [tc.get("id") for tc in tool_calls if tc.get("id")]
|
|
for cid in call_ids:
|
|
# Check if already present in tail (to be safe, though usually missing if we're here)
|
|
already_has = any(m.get("role") == "tool" and m.get("tool_call_id") == cid for m in history[-len(call_ids)-1:])
|
|
if not already_has:
|
|
history.append({
|
|
"role": "tool",
|
|
|
|
"tool_call_id": cid,
|
|
"content": "ERROR: Session was interrupted before tool result was recorded.",
|
|
})
|
|
|
|
def _ensure_deepseek_client() -> None:
    """Load credentials while no DeepSeek client object is cached.

    DeepSeek traffic goes through ``requests.post`` (see ``_send_deepseek``), and this
    function never assigns ``_deepseek_client``. While it is ``None`` the only effect is
    a call to ``_load_credentials()``, whose return value is discarded.
    """
    global _deepseek_client
    if _deepseek_client is not None:
        return
    _load_credentials()
|
|
|
|
def _deepseek_build_api_messages(md_content: str, is_reasoner: bool) -> list[dict[str, Any]]:
    """Build the ``messages`` array for one DeepSeek request from ``_deepseek_history``.

    Non-reasoner models get a leading ``system`` message carrying the combined system
    prompt and ``md_content``. Reasoner models get that text prepended to the first
    history message instead. Every history entry is copied into a fresh dict, so the
    stored history is not mutated.
    """
    system_prompt = _get_combined_system_prompt()
    api_messages: list[dict[str, Any]] = []
    # DeepSeek R1 (Reasoner) can be extremely strict about the 'system' role, so only
    # non-reasoner models receive one.
    if not is_reasoner:
        api_messages.append({"role": "system", "content": f"{system_prompt}\n\n<context>\n{md_content}\n</context>"})
    with _deepseek_history_lock:
        for i, msg in enumerate(_deepseek_history):
            role = msg.get("role")
            api_msg: dict[str, Any] = {"role": role}
            content = msg.get("content")
            if i == 0 and is_reasoner:
                # Prepend system instructions to the first message for R1
                content = f"System Instructions:\n{system_prompt}\n\nContext:\n{md_content}\n\n---\n\n{content}"
            if role == "assistant":
                # OpenAI/DeepSeek: content MUST be a string if tool_calls is absent;
                # if tool_calls is present, content can be null.
                if msg.get("tool_calls"):
                    api_msg["content"] = content or None
                    api_msg["tool_calls"] = msg["tool_calls"]
                else:
                    api_msg["content"] = content or ""
                if msg.get("reasoning_content"):
                    api_msg["reasoning_content"] = msg["reasoning_content"]
            elif role == "tool":
                api_msg["content"] = content or ""
                api_msg["tool_call_id"] = msg.get("tool_call_id")
            else:
                api_msg["content"] = content or ""
            api_messages.append(api_msg)
    return api_messages


def _deepseek_read_stream(response: Any, stream_callback: Optional[Callable[[str], None]]) -> tuple[str, list[dict[str, Any]], str, str, dict[str, Any]]:
    """Consume a DeepSeek SSE response and aggregate it into one assistant turn.

    Returns ``(text, tool_calls, reasoning_content, finish_reason, usage)``.
    - Text deltas are forwarded to ``stream_callback`` as they arrive.
    - Tool-call deltas are merged by their ``index``.
    - ``data:`` payloads that are not valid JSON objects are skipped.
    - JSON ``null`` fields are treated as absent.
    """
    text_parts: list[str] = []
    reasoning_parts: list[str] = []
    tool_calls: list[dict[str, Any]] = []
    usage: dict[str, Any] = {}
    finish_reason = "stop"
    for line in response.iter_lines():
        if not line:
            continue
        decoded = line.decode("utf-8")
        if not decoded.startswith("data: "):
            continue
        chunk_str = decoded[len("data: "):]
        if chunk_str.strip() == "[DONE]":
            continue
        try:
            chunk = json.loads(chunk_str)
        except json.JSONDecodeError:
            continue
        if not isinstance(chunk, dict):
            continue
        if chunk.get("usage"):
            # With stream_options.include_usage the final chunk carries usage and no choices.
            usage = chunk["usage"]
        choices = chunk.get("choices")
        if not choices:
            continue
        choice = choices[0] or {}
        delta = choice.get("delta") or {}
        content_chunk = delta.get("content")
        if content_chunk:
            text_parts.append(content_chunk)
            if stream_callback:
                stream_callback(content_chunk)
        if delta.get("reasoning_content"):
            reasoning_parts.append(delta["reasoning_content"])
        for tc_delta in delta.get("tool_calls") or []:
            idx = tc_delta.get("index") or 0
            while len(tool_calls) <= idx:
                tool_calls.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
            target = tool_calls[idx]
            if tc_delta.get("id"):
                target["id"] = tc_delta["id"]
            fn_delta = tc_delta.get("function") or {}
            if fn_delta.get("name"):
                target["function"]["name"] += fn_delta["name"]
            if fn_delta.get("arguments"):
                target["function"]["arguments"] += fn_delta["arguments"]
        if choice.get("finish_reason"):
            finish_reason = choice["finish_reason"]
    return "".join(text_parts), tool_calls, "".join(reasoning_parts), finish_reason, usage


def _send_deepseek(md_content: str, user_message: str, base_dir: str,
                   file_items: list[dict[str, Any]] | None = None,
                   discussion_history: str = "",
                   stream: bool = False,
                   pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                   qa_callback: Optional[Callable[[str], str]] = None,
                   stream_callback: Optional[Callable[[str], None]] = None,
                   patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Result[str]:
    """
    [C: src/ai_server.py:_handle_send]

    Functional Purpose: Sends requests to DeepSeek via requests.post. It repairs interrupted
    tool-call history, appends the user turn, and runs up to MAX_TOOL_ROUNDS tool rounds.

    Parameters & Inputs: md_content, user_message, base_dir, file_items, discussion_history, stream, callbacks.

    Returns: Result[str] holding the assistant text of every round joined by blank lines.
    Reasoning is wrapped in <thinking> tags; "(No text returned)" is used when there is no
    text. On failure it returns Result(data="", errors=[...]) built by _classify_deepseek_error.

    Immediate-Mode DAG / Thread Context: Called by: send; Calls: _repair_deepseek_history,
    _deepseek_build_api_messages, requests.post, _deepseek_read_stream, _execute_tool_calls_concurrently

    SSDL:
    [I:_repair_deepseek_history] -> [I:requests.post] -> [T:Result]

    Thread Boundaries: Runs on caller thread (typically an async worker thread).
    """
    requests = _require_warmed("requests")
    monitor = performance_monitor.get_monitor()
    if monitor.enabled:
        monitor.start_component("ai_client._send_deepseek")
    try:
        mcp_client.configure(file_items or [], [base_dir])
        creds = _load_credentials()
        api_key = creds.get("deepseek", {}).get("api_key")
        if not api_key:
            raise ValueError("DeepSeek API key not found in credentials.toml")
        api_url = "https://api.deepseek.com/chat/completions"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

        is_reasoner = _model in ("deepseek-reasoner", "deepseek-r1")

        with _deepseek_history_lock:
            _repair_deepseek_history(_deepseek_history)
            if discussion_history and not _deepseek_history:
                user_content = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"
            else:
                user_content = user_message
            _deepseek_history.append({"role": "user", "content": user_content})

        all_text_parts: list[str] = []
        cumulative_tool_bytes = 0

        for round_idx in range(MAX_TOOL_ROUNDS + 2):
            request_payload: dict[str, Any] = {
                "model": _model,
                "messages": _deepseek_build_api_messages(md_content, is_reasoner),
                "stream": stream,
            }
            if stream:
                request_payload["stream_options"] = {"include_usage": True}
            if not is_reasoner:
                # Sampling parameters are only sent to non-reasoner models.
                request_payload["temperature"] = _temperature
                request_payload["top_p"] = _top_p
            # DeepSeek max_tokens is for the output, clamp to 8192 which is their hard limit for V3/Chat
            request_payload["max_tokens"] = min(_max_tokens, 8192)
            tools = _get_deepseek_tools()
            if tools:
                request_payload["tools"] = tools

            events.emit("request_start", payload={"provider": "deepseek", "model": _model, "round": round_idx, "streaming": stream})

            try:
                response = requests.post(api_url, headers=headers, json=request_payload, timeout=120, stream=stream)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                return Result(data="", errors=[_classify_deepseek_error(e, source="ai_client.deepseek")])

            try:
                if stream:
                    assistant_text, tool_calls_raw, reasoning_content, finish_reason, usage = _deepseek_read_stream(response, stream_callback)
                else:
                    response_data = response.json()
                    choices = response_data.get("choices") or []
                    if not choices:
                        _append_comms("IN", "response", {"round": round_idx, "text": "(No choices returned)", "usage": response_data.get("usage", {})})
                        break
                    choice = choices[0]
                    message = choice.get("message") or {}
                    # `or` defaults: the API may send explicit JSON nulls (e.g. content
                    # alongside tool_calls), which .get(key, default) would pass through.
                    assistant_text = message.get("content") or ""
                    tool_calls_raw = message.get("tool_calls") or []
                    reasoning_content = message.get("reasoning_content") or ""
                    finish_reason = choice.get("finish_reason") or "stop"
                    usage = response_data.get("usage") or {}
            finally:
                # Release the connection; required for stream=True responses.
                response.close()

            thinking_tags = f"<thinking>\n{reasoning_content}\n</thinking>\n" if reasoning_content else ""
            full_assistant_text = thinking_tags + assistant_text

            with _deepseek_history_lock:
                # DeepSeek/OpenAI: If tool_calls are present, content can be null but should usually be present
                msg_to_store: dict[str, Any] = {"role": "assistant", "content": assistant_text or None}
                if reasoning_content:
                    msg_to_store["reasoning_content"] = reasoning_content
                if tool_calls_raw:
                    msg_to_store["tool_calls"] = tool_calls_raw
                _deepseek_history.append(msg_to_store)

            if full_assistant_text:
                all_text_parts.append(full_assistant_text)

            _append_comms("IN", "response", {
                "round": round_idx,
                "stop_reason": finish_reason,
                "text": full_assistant_text,
                "tool_calls": tool_calls_raw,
                "usage": usage,
                "streaming": stream,
            })

            # Nothing to execute means nothing new to send back; stop even if
            # finish_reason claims "tool_calls", to avoid re-sending an identical request.
            if not tool_calls_raw:
                break
            if round_idx > MAX_TOOL_ROUNDS:
                break

            # Execute tools concurrently.
            # NOTE(review): get_running_loop() only succeeds on a thread that is running an
            # event loop. Blocking on .result() from that same thread would wait on a
            # coroutine the blocked loop cannot run. Confirm this is never called from
            # inside a coroutine.
            try:
                loop = asyncio.get_running_loop()
                results = asyncio.run_coroutine_threadsafe(
                    _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback),
                    loop,
                ).result()
            except RuntimeError:
                results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback))

            tool_results_for_history: list[dict[str, Any]] = []
            for i, (name, call_id, out, _) in enumerate(results):
                if i == len(results) - 1:
                    # The file-change diff and the max-rounds notice ride on the last result only.
                    if file_items:
                        file_items, changed = _reread_file_items(file_items)
                        ctx = _build_file_diff_text(changed)
                        if ctx:
                            out += f"\n\n{_get_context_marker()}\n\n{ctx}"
                    if round_idx == MAX_TOOL_ROUNDS:
                        out += "\n\n[SYSTEM: MAX ROUNDS. PROVIDE FINAL ANSWER.]"

                truncated = _truncate_tool_output(out)
                cumulative_tool_bytes += len(truncated)
                tool_results_for_history.append({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": truncated,
                })
                _append_comms("IN", "tool_result", {"name": name, "id": call_id, "output": out})
                events.emit("tool_execution", payload={"status": "completed", "tool": name, "result": out, "round": round_idx})

            if cumulative_tool_bytes > _MAX_TOOL_OUTPUT_BYTES:
                tool_results_for_history.append({
                    "role": "user",
                    "content": f"SYSTEM WARNING: Cumulative tool output exceeded {_MAX_TOOL_OUTPUT_BYTES // 1000}KB budget. Provide your final answer now.",
                })
                _append_comms("OUT", "request", {"message": f"[TOOL OUTPUT BUDGET EXCEEDED: {cumulative_tool_bytes} bytes]"})

            with _deepseek_history_lock:
                _deepseek_history.extend(tool_results_for_history)

        res = "\n\n".join(all_text_parts) if all_text_parts else "(No text returned)"
        return Result(data=res)
    except Exception as e:
        return Result(data="", errors=[_classify_deepseek_error(e, source="ai_client.deepseek")])
    finally:
        # Single exit point for the perf span so it is ended exactly once on every path.
        if monitor.enabled:
            monitor.end_component("ai_client._send_deepseek")
|
|
|
|
#endregion: DeepSeek Provider
|
|
|
|
#region: MiniMax Provider
|
|
|
|
# Fallback MiniMax model ids. _list_minimax_models_result uses these when the models
# endpoint returns no ids or the listing call raises.
_MINIMAX_DEFAULT_MODELS: list[str] = ["MiniMax-M2.7", "MiniMax-M2.5", "MiniMax-M2.1", "MiniMax-M2"]
|
|
|
|
# TODO(Ed): This causes a pause on gui thread, this should be cached.
def _list_minimax_models_result(api_key: str) -> Result[list[str]]:
    """List available MiniMax models via the OpenAI-compatible SDK.

    Returns Result(data=sorted_models) on success. If the endpoint returns no ids, the
    result holds a copy of _MINIMAX_DEFAULT_MODELS. On SDK failure it returns
    Result(data=copy of _MINIMAX_DEFAULT_MODELS, errors=[ErrorInfo]).

    A fresh list is returned every time, so callers that mutate the result cannot
    corrupt the module-level default. The legacy caller (_list_minimax_models) returns
    result.data, preserving the original list[str] signature.

    The base URL comes from ``[minimax].base_url`` in credentials, defaulting to
    https://api.minimax.io/v1. A new client is built on every call (see TODO above).
    """
    try:
        openai = _require_warmed("openai")
        creds = _load_credentials()
        base_url = creds.get("minimax", {}).get("base_url") or "https://api.minimax.io/v1"
        client = openai.OpenAI(api_key=api_key, base_url=base_url)
        found = [m.id for m in client.models.list()]
        if found:
            return Result(data=sorted(found))
        return Result(data=list(_MINIMAX_DEFAULT_MODELS))
    except Exception as e:
        return Result(
            data=list(_MINIMAX_DEFAULT_MODELS),
            errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to list minimax models: {e}", source="ai_client._list_minimax_models_result", original=e)],
        )
|
|
|
|
|
|
def _list_minimax_models(api_key: str) -> list[str]:
    """Legacy list[str] wrapper: return the model ids from _list_minimax_models_result.

    Any errors recorded on the Result are dropped; on failure the data is the
    default model list.
    """
    listing = _list_minimax_models_result(api_key)
    return listing.data
|
|
|
|
def _repair_minimax_history(history: list[dict[str, Any]]) -> None:
|
|
if not history: return
|
|
last = history[-1]
|
|
if last.get("role") != "assistant": return
|
|
tool_calls = last.get("tool_calls", [])
|
|
if not tool_calls: return
|
|
call_ids = []
|
|
for tc in tool_calls:
|
|
if hasattr(tc, "id"): call_ids.append(tc.id)
|
|
elif isinstance(tc, dict) and tc.get("id"): call_ids.append(tc["id"])
|
|
|
|
for cid in call_ids:
|
|
already_has = any(m.get("role") == "tool" and m.get("tool_call_id") == cid for m in history[-len(call_ids)-1:])
|
|
if not already_has:
|
|
history.append({
|
|
"role": "tool",
|
|
"tool_call_id": cid,
|
|
"content": "ERROR: Session was interrupted before tool result was recorded.",
|
|
})
|
|
|
|
def _trim_minimax_history(system_blocks: list[dict[str, Any]], history: list[dict[str, Any]]) -> int:
    """Drop old messages from ``history`` in place until the prompt estimate fits.

    Args:
        system_blocks: messages passed to ``_estimate_prompt_tokens`` together with ``history``.
        history: conversation list. Only index 1 is ever popped, so ``history[0]`` is never removed.

    Returns:
        The number of messages removed. Returns 0 without modifying ``history`` when the
        initial estimate is already within the limit.
    """
    est = _estimate_prompt_tokens(system_blocks, history)
    limit = 180_000  # ceiling compared against the _estimate_prompt_tokens result
    if est <= limit:
        return 0
    dropped = 0
    # Stop once only three messages remain, even if still over the limit.
    while len(history) > 3 and est > limit:
        # An assistant message followed by a user message is dropped as a pair.
        if history[1].get("role") == "assistant" and len(history) > 2 and history[2].get("role") == "user":
            removed_asst = history.pop(1)
            removed_user = history.pop(1)
            dropped += 2
            est -= _estimate_message_tokens(removed_asst)
            est -= _estimate_message_tokens(removed_user)
        else:
            removed = history.pop(1)
            dropped += 1
            est -= _estimate_message_tokens(removed)

    # Ensure we don't leave dangling 'tool' messages if their parent 'assistant' was dropped.
    # MiniMax strictly requires 'tool' messages to immediately follow 'assistant' with tool_calls.
    while len(history) > 1 and history[1].get("role") == "tool":
        removed_tool = history.pop(1)
        dropped += 1
        est -= _estimate_message_tokens(removed_tool)  # est is not read after this point
    return dropped
|
|
|
|
def _ensure_minimax_client() -> None:
    """Lazily build the module-level MiniMax OpenAI-compatible client.

    Does nothing once ``_minimax_client`` is set. The warmed ``openai`` SDK is required
    only when a client actually has to be constructed, matching ``_ensure_grok_client``.
    The base URL comes from ``[minimax].base_url`` in credentials, defaulting to
    https://api.minimax.io/v1.

    Raises:
        ValueError: if the credentials have no ``[minimax].api_key``.
    """
    global _minimax_client
    if _minimax_client is not None:
        return
    openai = _require_warmed("openai")
    minimax_creds = _load_credentials().get("minimax", {})
    api_key = minimax_creds.get("api_key")
    if not api_key:
        raise ValueError("MiniMax API key not found in credentials.toml")
    base_url = minimax_creds.get("base_url") or "https://api.minimax.io/v1"
    _minimax_client = openai.OpenAI(api_key=api_key, base_url=base_url)
|
|
|
|
def _ensure_grok_client() -> Any:
    """Return the cached x.ai Grok client, constructing it on first use.

    The client is an ``openai.OpenAI`` instance pointed at https://api.x.ai/v1 with
    the key from ``[grok].api_key`` in credentials.

    Raises:
        ValueError: if the credentials have no ``[grok].api_key``.
    """
    global _grok_client
    if _grok_client is not None:
        return _grok_client
    openai = _require_warmed("openai")
    creds = _load_credentials()
    grok_key = creds.get("grok", {}).get("api_key")
    if not grok_key:
        raise ValueError("Grok API key not found in credentials.toml")
    _grok_client = openai.OpenAI(api_key=grok_key, base_url="https://api.x.ai/v1")
    return _grok_client
|
|
|
|
def _send_grok(md_content: str, user_message: str, base_dir: str,
               file_items: list[dict[str, Any]] | None = None,
               discussion_history: str = "",
               stream: bool = False,
               pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
               qa_callback: Optional[Callable[[str], str]] = None,
               stream_callback: Optional[Callable[[str], None]] = None,
               patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Result[str]:
    """
    Dispatches queries to Grok (x.ai) model endpoint using OpenAI compatible client.

    Functional Purpose:
        Ensures the Grok client and tool schema exist, then appends the user turn to
        _grok_history. Image attachments become "[IMAGE: path]" markers, which are kept
        even when the first turn is wrapped with the discussion history. It builds the
        Grok request each round (adding search_parameters when the model supports web/X
        search) and executes it via run_with_tool_loop.

    Parameters & Inputs:
        md_content (str): Markdown formatted context content.
        user_message (str): User prompt text.
        base_dir (str): Workspace root directory.
        file_items (Optional[list[dict[str, Any]]]): Media or file items for multimodal queries.
        discussion_history (str): Contextual discussion text.
        stream (bool): Whether to stream output.
        pre_tool_callback (Optional[Callable]): Hook for HITL tool confirmation.
        qa_callback (Optional[Callable]): Verification callback for QA checks.
        stream_callback (Optional[Callable]): Callback function for streaming chunks.
        patch_callback (Optional[Callable]): Validation callback for code edits.

    Returns:
        Result[str]: Wrap of string response and potential errors.

    Immediate-Mode DAG / Thread Context:
        Called by: send
        Calls: _ensure_grok_client, _get_deepseek_tools, get_capabilities, run_with_tool_loop

    SSDL:
        `[I:_ensure_grok_client] -> [I:run_with_tool_loop] -> [T:Result]`

    Thread Boundaries:
        Runs synchronously in the caller thread; synchronizes Grok history using _grok_history_lock.
    """
    from src.openai_compatible import OpenAICompatibleRequest, _classify_openai_compatible_error
    try:
        client = _ensure_grok_client()
        tools: list[dict[str, Any]] | None = _get_deepseek_tools() or None
        caps = get_capabilities("grok", _model)
        with _grok_history_lock:
            user_content = user_message
            for fi in file_items or []:
                if fi.get("is_image") and fi.get("base64_data"):
                    user_content = f"[IMAGE: {fi.get('path', 'attachment')}]\n{user_content}"
            if discussion_history and not _grok_history:
                # Wrap user_content (not the raw user_message) so image markers survive.
                user_content = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_content}"
            _grok_history.append({"role": "user", "content": user_content})

        def _build_grok_request(_round_idx: int) -> OpenAICompatibleRequest:
            with _grok_history_lock:
                messages: list[dict[str, Any]] = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}]
                messages.extend(_grok_history)
            extra_body: dict[str, Any] = {}
            if caps.web_search:
                extra_body["search_parameters"] = {"mode": "auto"}
            if caps.x_search:
                extra_body.setdefault("search_parameters", {})
                extra_body["search_parameters"]["sources"] = [{"type": "x"}]
            return OpenAICompatibleRequest(
                messages=messages, model=_model, temperature=_temperature, top_p=_top_p,
                max_tokens=_max_tokens, stream=stream, stream_callback=stream_callback,
                tools=tools, tool_choice="auto",
                extra_body=extra_body or None,
            )

        return Result(data=run_with_tool_loop(
            client, _build_grok_request, capabilities=caps,
            pre_tool_callback=pre_tool_callback, qa_callback=qa_callback, stream_callback=stream_callback,
            patch_callback=patch_callback, base_dir=base_dir, vendor_name="grok",
            history_lock=_grok_history_lock, history=_grok_history,
        ))
    except Exception as exc:
        return Result(data="", errors=[_classify_openai_compatible_error(exc, source="ai_client.grok")])
|
|
|
|
def _list_grok_models() -> list[str]:
    """Return the Grok model ids registered in src.vendor_capabilities."""
    from src.vendor_capabilities import list_models_for_vendor as models_for_vendor
    grok_models = models_for_vendor("grok")
    return grok_models
|
|
|
|
def _send_minimax(md_content: str, user_message: str, base_dir: str,
                  file_items: list[dict[str, Any]] | None = None,
                  discussion_history: str = "",
                  stream: bool = False,
                  pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                  qa_callback: Optional[Callable[[str], str]] = None,
                  stream_callback: Optional[Callable[[str], None]] = None,
                  patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Result[str]:
    """
    Dispatches queries to the MiniMax provider using OpenAI compatible client.

    Functional Purpose:
        Ensures client setup. Under _minimax_history_lock it repairs interrupted tool-call
        history and appends the new user turn. It builds the MiniMax request per round,
        extracts reasoning content, and executes it via the tool loop. History trimming
        estimates tokens from the system message plus the history, so each history message
        is counted once and trimming never re-acquires the history lock.

    Parameters & Inputs:
        md_content (str): Markdown formatted context content.
        user_message (str): User prompt text.
        base_dir (str): Workspace root directory.
        file_items (Optional[list[dict[str, Any]]]): Media or file items for multimodal queries.
        discussion_history (str): Contextual discussion text.
        stream (bool): Whether to stream output.
        pre_tool_callback (Optional[Callable]): Hook for HITL tool confirmation.
        qa_callback (Optional[Callable]): Verification callback for QA checks.
        stream_callback (Optional[Callable]): Callback function for streaming chunks.
        patch_callback (Optional[Callable]): Validation callback for code edits.

    Returns:
        Result[str]: Wrap of string response and potential errors.

    Immediate-Mode DAG / Thread Context:
        Called by: send
        Calls: _ensure_minimax_client, _repair_minimax_history, _get_deepseek_tools,
               get_capabilities, run_with_tool_loop

    SSDL:
        `[I:_ensure_minimax_client] -> [I:_repair_minimax_history] -> [I:run_with_tool_loop] -> [T:Result]`

    Thread Boundaries:
        Runs synchronously in the caller thread; synchronizes MiniMax history using _minimax_history_lock.
    """
    from src.openai_compatible import OpenAICompatibleRequest
    try:
        _ensure_minimax_client()
        tools: list[dict[str, Any]] | None = _get_deepseek_tools() or None
        with _minimax_history_lock:
            _repair_minimax_history(_minimax_history)
            if discussion_history and not _minimax_history:
                _minimax_history.append({"role": "user", "content": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"})
            else:
                _minimax_history.append({"role": "user", "content": user_message})

        def _system_message() -> dict[str, Any]:
            # Built without the history lock so trim_func can call it safely.
            return {"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}

        def _build_minimax_request(_round_idx: int) -> OpenAICompatibleRequest:
            with _minimax_history_lock:
                messages: list[dict[str, Any]] = [_system_message()]
                messages.extend(_minimax_history)
            return OpenAICompatibleRequest(
                messages=messages, model=_model, temperature=_temperature, top_p=_top_p,
                max_tokens=min(_max_tokens, 8192), stream=stream, stream_callback=stream_callback,
                tools=tools, tool_choice="auto",
            )

        def _extract_minimax_reasoning(raw_response: Any) -> str:
            # Reasoning lives in choices[0].message.reasoning_details[0]["text"]; items may
            # be dicts or SDK objects, and any missing piece yields "".
            choices = getattr(raw_response, "choices", None) if raw_response else None
            if not choices:
                return ""
            details = getattr(choices[0].message, "reasoning_details", None)
            if not details:
                return ""
            first = details[0]
            text = first.get("text", "") if isinstance(first, dict) else getattr(first, "text", "")
            return text or ""

        caps = get_capabilities("minimax", _model)
        return Result(data=run_with_tool_loop(
            _minimax_client, _build_minimax_request, capabilities=caps,
            pre_tool_callback=pre_tool_callback, qa_callback=qa_callback, stream_callback=stream_callback,
            patch_callback=patch_callback, base_dir=base_dir, vendor_name="minimax",
            history_lock=_minimax_history_lock, history=_minimax_history,
            # Only the system message goes in system_blocks; the history is passed separately.
            trim_func=lambda h: _trim_minimax_history([_system_message()], h),
            reasoning_extractor=_extract_minimax_reasoning if caps.reasoning else None,
            wrap_reasoning_in_text=bool(caps.reasoning),
        ))
    except Exception as exc:
        return Result(data="", errors=[_classify_minimax_error(exc, source="ai_client.minimax")])
|
|
|
|
#endregion: MiniMax Provider
|
|
|
|
#region: Qwen Provider
|
|
|
|
def _ensure_qwen_client() -> None:
    """Configure the DashScope SDK once and cache ``dashscope.Generation`` as the Qwen client.

    On first use it reads ``[qwen]`` from credentials:
    - sets ``dashscope.api_key``;
    - records ``_qwen_region`` (default "china");
    - points ``dashscope.base_http_api_url`` at the international endpoint when the
      region is "international", or the mainland endpoint otherwise.

    Raises:
        ValueError: if the credentials have no ``[qwen].api_key``.
    """
    global _qwen_client, _qwen_region
    if _qwen_client is not None:
        return
    import dashscope
    qwen_creds = _load_credentials().get("qwen", {})
    api_key = qwen_creds.get("api_key")
    if not api_key:
        raise ValueError("Qwen API key not found in credentials.toml")
    _qwen_region = qwen_creds.get("region", "china")
    dashscope.base_http_api_url = (
        "https://dashscope-intl.aliyuncs.com/api/v1"
        if _qwen_region == "international"
        else "https://dashscope.aliyuncs.com/api/v1"
    )
    dashscope.api_key = api_key
    _qwen_client = dashscope.Generation
|
|
|
|
def _dashscope_call(
    model: str,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
    *,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> dict[str, Any]:
    """Run one ``dashscope.Generation.call`` and normalise the reply into a plain dict.

    Returns:
        {"text": str, "tool_calls": list[dict], "usage": {"input_tokens": int, "output_tokens": int}}.
        The request uses ``result_format="message"``, so the reply text is read from
        ``output.choices[0].message.content``; ``output.text`` is only a fallback.

    Raises:
        RuntimeError: built by _dashscope_exception_from_response when the response
            status_code is not 200. It is raised as a real exception so callers can
            classify it (e.g. _send_qwen via classify_dashscope_error).
    """
    import dashscope
    from src.qwen_adapter import build_dashscope_tools

    def _get(obj: Any, key: str) -> Any:
        # DashScope reply objects support attribute access; plain dicts need key access.
        if obj is None:
            return None
        if isinstance(obj, dict) and key in obj:
            return obj[key]
        return getattr(obj, key, None)

    kwargs: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "result_format": "message",
    }
    if tools:
        kwargs["tools"] = build_dashscope_tools(tools)
    resp = dashscope.Generation.call(**kwargs)
    if getattr(resp, "status_code", 200) != 200:
        raise _dashscope_exception_from_response(resp) from None

    output = _get(resp, "output")
    text = ""
    if output:
        choices = _get(output, "choices")
        if choices:
            content = _get(_get(choices[0], "message"), "content")
            if isinstance(content, str):
                text = content
        if not text:
            text = _get(output, "text") or ""

    usage = _get(resp, "usage")
    return {
        "text": text,
        "tool_calls": _extract_dashscope_tool_calls(resp),
        "usage": {
            "input_tokens": (_get(usage, "input_tokens") or 0) if usage else 0,
            "output_tokens": (_get(usage, "output_tokens") or 0) if usage else 0,
        },
    }
|
|
|
|
def _dashscope_exception_from_response(resp: Any) -> Exception:
|
|
msg = getattr(resp, "message", "unknown dashscope error")
|
|
return RuntimeError(msg)
|
|
|
|
def _extract_dashscope_tool_calls(resp: Any) -> list[dict[str, Any]]:
|
|
out: list[dict[str, Any]] = []
|
|
if not (hasattr(resp, "output") and resp.output and getattr(resp.output, "tool_calls", None)):
|
|
return out
|
|
for tc in resp.output.tool_calls:
|
|
out.append({
|
|
"id": getattr(tc, "id", ""),
|
|
"type": "function",
|
|
"function": {
|
|
"name": getattr(tc.function, "name", "") if hasattr(tc, "function") else "",
|
|
"arguments": getattr(tc.function, "arguments", "{}") if hasattr(tc, "function") else "{}",
|
|
},
|
|
})
|
|
return out
|
|
|
|
def _list_qwen_models() -> list[str]:
    """Return the Qwen model ids registered in src.vendor_capabilities."""
    from src.vendor_capabilities import list_models_for_vendor as models_for_vendor
    qwen_models = models_for_vendor("qwen")
    return qwen_models
|
|
|
|
def _send_qwen(md_content: str, user_message: str, base_dir: str,
               file_items: list[dict[str, Any]] | None = None,
               discussion_history: str = "",
               stream: bool = False,
               pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
               qa_callback: Optional[Callable[[str], str]] = None,
               stream_callback: Optional[Callable[[str], None]] = None,
               patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Result[str]:
    """
    Dispatches queries to Alibaba's Qwen model via DashScope SDK.

    Functional Purpose:
        Initializes/ensures DashScope setup and appends the user turn to _qwen_history.
        Image attachments become "[IMAGE: path]" markers, kept even when the first turn
        is wrapped with the discussion history. It snapshots system prompt + history and
        makes a single _dashscope_call (no tools; stream is not used). A non-empty reply
        is recorded as an assistant turn so later requests keep the conversation.

    Parameters & Inputs:
        md_content (str): Markdown formatted context content.
        user_message (str): User prompt text.
        base_dir (str): Workspace root directory.
        file_items (Optional[list[dict[str, Any]]]): Media or file items for multimodal queries.
        discussion_history (str): Contextual discussion text.
        stream (bool): Whether to stream output.
        pre_tool_callback (Optional[Callable]): Hook for HITL tool confirmation.
        qa_callback (Optional[Callable]): Verification callback for QA checks.
        stream_callback (Optional[Callable]): Callback function for streaming chunks.
        patch_callback (Optional[Callable]): Validation callback for code edits.

    Returns:
        Result[str]: Wrap of string response and potential errors.

    Immediate-Mode DAG / Thread Context:
        Called by: send
        Calls: _ensure_qwen_client, _dashscope_call

    SSDL:
        `[I:_ensure_qwen_client] -> [I:dashscope.Generation.call] -> [T:Result]`

    Thread Boundaries:
        Runs synchronously in the caller thread; synchronizes history using _qwen_history_lock.
    """
    from src.qwen_adapter import classify_dashscope_error
    try:
        _ensure_qwen_client()
        with _qwen_history_lock:
            user_content = user_message
            for fi in file_items or []:
                if fi.get("is_image") and fi.get("base64_data"):
                    user_content = f"[IMAGE: {fi.get('path', 'attachment')}]\n{user_content}"
            if discussion_history and not _qwen_history:
                # Wrap user_content (not the raw user_message) so image markers survive.
                user_content = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_content}"
            _qwen_history.append({"role": "user", "content": user_content})
            messages: list[dict[str, Any]] = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}]
            messages.extend(_qwen_history)
        resp = _dashscope_call(
            model=_model,
            messages=messages,
            tools=None,
            max_tokens=_max_tokens,
            temperature=_temperature,
            top_p=_top_p,
        )
        text = resp.get("text") or ""
        if text:
            # Record the reply; without it the history would hold only user turns.
            with _qwen_history_lock:
                _qwen_history.append({"role": "assistant", "content": text})
        return Result(data=text)
    except Exception as exc:
        return Result(data="", errors=[classify_dashscope_error(exc, source="ai_client.qwen")])
|
|
|
|
#endregion: Qwen Provider
|
|
|
|
#region: Llama Provider
|
|
|
|
def _ensure_llama_client() -> Any:
    """Return the cached Llama OpenAI-compatible client, building it on first call.

    On first call, ``[llama]`` credential overrides are applied to the module-level
    ``_llama_base_url`` and ``_llama_api_key``:
    - a truthy base_url replaces the URL;
    - an api_key that is present replaces the key, with an empty value falling back
      to "ollama".
    """
    global _llama_client, _llama_base_url, _llama_api_key
    if _llama_client is not None:
        return _llama_client
    openai = _require_warmed("openai")
    llama_creds = _load_credentials().get("llama", {})
    url_override = llama_creds.get("base_url")
    key_override = llama_creds.get("api_key")
    if url_override:
        _llama_base_url = url_override
    if key_override is not None:
        _llama_api_key = key_override or "ollama"
    _llama_client = openai.OpenAI(api_key=_llama_api_key, base_url=_llama_base_url)
    return _llama_client
|
|
|
|
def _send_llama(md_content: str, user_message: str, base_dir: str,
                file_items: list[dict[str, Any]] | None = None,
                discussion_history: str = "",
                stream: bool = False,
                pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                qa_callback: Optional[Callable[[str], str]] = None,
                stream_callback: Optional[Callable[[str], None]] = None,
                patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Result[str]:
    """
    Dispatches queries to Llama-based models using OpenAI compatible client or native Ollama backend.

    Functional Purpose:
        Routes execution either to _send_llama_native (when _llama_base_url contains
        "localhost" or "127.0.0.1") or to the OpenAI compatible client setup with history
        management and tool loop execution.
        This path sends no binary image data. Each image file item becomes an
        "[IMAGE: <path>]" text marker prepended to the user turn, in attachment order.
        That includes the first turn seeded with discussion history.

    Parameters & Inputs:
        md_content (str): Markdown formatted context content (embedded in the system prompt).
        user_message (str): User prompt text.
        base_dir (str): Workspace root directory.
        file_items (Optional[list[dict[str, Any]]]): Media or file items for multimodal queries.
        discussion_history (str): Contextual discussion text (only used when history is empty).
        stream (bool): Whether to stream output.
        pre_tool_callback (Optional[Callable]): Hook for HITL tool confirmation.
        qa_callback (Optional[Callable]): Verification callback for QA checks.
        stream_callback (Optional[Callable]): Callback function for streaming chunks.
        patch_callback (Optional[Callable]): Validation callback for code edits.

    Returns:
        Result[str]: Response text on success. Any exception yields Result(data="") whose
        error is built by _classify_openai_compatible_error.

    Immediate-Mode DAG / Thread Context:
        Called by: send
        Calls: _send_llama_native, _ensure_llama_client, _get_deepseek_tools,
        get_capabilities, run_with_tool_loop

    SSDL:
        `[I:_ensure_llama_client] -> [I:run_with_tool_loop] -> [T:Result]`

    Thread Boundaries:
        Runs synchronously in the caller thread; synchronizes history using _llama_history_lock.
    """
    from src.openai_compatible import OpenAICompatibleRequest, _classify_openai_compatible_error

    try:
        # Local endpoints are assumed to be Ollama and use its native REST API instead.
        if "localhost" in _llama_base_url or "127.0.0.1" in _llama_base_url:
            return _send_llama_native(md_content, user_message, base_dir, file_items, discussion_history, stream, pre_tool_callback, qa_callback, stream_callback, patch_callback)
        client = _ensure_llama_client()
        tools: list[dict[str, Any]] | None = _get_deepseek_tools() or None

        # No multimodal payload on this path: surface image attachments as text markers,
        # preserving attachment order.
        image_markers = [
            f"[IMAGE: {fi.get('path', 'attachment')}]\n"
            for fi in (file_items or [])
            if fi.get("is_image") and fi.get("base64_data")
        ]
        user_content = "".join(image_markers) + user_message

        with _llama_history_lock:
            if discussion_history and not _llama_history:
                # First turn of a session: seed it with the prior discussion. Use
                # user_content (not user_message) so image markers are kept here too.
                _llama_history.append({"role": "user", "content": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_content}"})
            else:
                _llama_history.append({"role": "user", "content": user_content})

        def _build_llama_request(_round_idx: int) -> OpenAICompatibleRequest:
            # Rebuilt every tool-loop round so it sees the history appended by the loop.
            with _llama_history_lock:
                messages: list[dict[str, Any]] = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}]
                messages.extend(_llama_history)
            return OpenAICompatibleRequest(
                messages=messages, model=_model, temperature=_temperature, top_p=_top_p,
                max_tokens=_max_tokens, stream=stream, stream_callback=stream_callback,
                # NOTE(review): the original `"auto" if tools else "auto"` always yielded
                # "auto"; confirm whether None was intended when no tools are available.
                tools=tools, tool_choice="auto",
            )

        caps = get_capabilities("llama", _model)
        return Result(data=run_with_tool_loop(
            client, _build_llama_request, capabilities=caps,
            pre_tool_callback=pre_tool_callback, qa_callback=qa_callback, stream_callback=stream_callback,
            patch_callback=patch_callback, base_dir=base_dir, vendor_name="llama",
            history_lock=_llama_history_lock, history=_llama_history,
        ))
    except Exception as exc:
        return Result(data="", errors=[_classify_openai_compatible_error(exc, source="ai_client.llama")])
|
|
|
|
# Root URL of a locally running Ollama server (its native REST API, not the /v1 shim).
OLLAMA_DEFAULT_BASE_URL: str = "http://localhost:11434"


def ollama_chat(
    model: str,
    messages: list[dict[str, Any]],
    *,
    think: str = "low",
    images: list[str] | None = None,
    tools: list[dict[str, Any]] | None = None,
    base_url: str = OLLAMA_DEFAULT_BASE_URL,
) -> dict[str, Any]:
    """
    Send one non-streaming request to Ollama's native ``/api/chat`` endpoint.

    Parameters:
        model: Ollama model tag.
        messages: Chat messages ({"role": ..., "content": ...}). Never mutated.
        think: Value for the request's "think" field; omitted when falsy.
        images: Base64-encoded images.
            /api/chat only reads images from message objects. Unlike /api/generate, it
            has no top-level "images" field.
            So the images are attached to a copy of the most recent "user" message, or to
            a new empty user message if there is none.
        tools: Tool definitions; omitted when falsy.
        base_url: Server root, e.g. "http://localhost:11434" (no "/v1" suffix).

    Returns:
        The decoded JSON response body. The HTTP status is not checked here, so error
        bodies are returned as-is (Ollama documents them as {"error": "..."}).

    Raises:
        Whatever requests.post / Response.json raise: connection errors, a timeout after
        120 s, or invalid JSON.
    """
    requests = _require_warmed("requests")
    if images:
        # Copy the list and the touched dict: message dicts may be shared with the
        # caller's conversation history.
        messages = list(messages)
        for idx in range(len(messages) - 1, -1, -1):
            if messages[idx].get("role") == "user":
                existing = messages[idx].get("images") or []
                messages[idx] = {**messages[idx], "images": [*existing, *images]}
                break
        else:
            messages.append({"role": "user", "content": "", "images": list(images)})
    payload: dict[str, Any] = {"model": model, "messages": messages, "stream": False}
    if think:
        payload["think"] = think
    if tools:
        payload["tools"] = tools
    resp = requests.post(f"{base_url}/api/chat", json=payload, timeout=120)
    return resp.json()
|
|
|
|
def _send_llama_native(md_content: str, user_message: str, base_dir: str,
                       file_items: list[dict[str, Any]] | None = None,
                       discussion_history: str = "",
                       stream: bool = False,
                       pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                       qa_callback: Optional[Callable[[str], str]] = None,
                       stream_callback: Optional[Callable[[str], None]] = None,
                       patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Result[str]:
    """
    Dispatches queries natively to local Ollama endpoints using direct HTTP requests.

    Functional Purpose:
        Bypasses the OpenAI compatible wrapper to talk directly to the Ollama REST API
        (one non-streaming /api/chat round, no tool loop).
        Image file items are forwarded as base64 images. A "thinking" field in the reply,
        if present, is prepended to the returned text inside <thinking> tags.
        base_dir, stream and the callbacks are accepted for signature parity with the other
        providers but are not used on this path.

    Parameters & Inputs:
        md_content (str): Markdown formatted context content (embedded in the system prompt).
        user_message (str): User prompt text.
        base_dir (str): Workspace root directory (unused).
        file_items (Optional[list[dict[str, Any]]]): Media or file items for multimodal queries.
        discussion_history (str): Contextual discussion text (only used when history is empty).
        stream (bool): Unused (the request is always non-streaming).
        pre_tool_callback, qa_callback, stream_callback, patch_callback: Unused.

    Returns:
        Result[str]: The response text (possibly prefixed with a thinking block) on success.
        Result(data="", errors=[ErrorInfo]) when Ollama returns an {"error": ...} body or
        any exception occurs.

    Immediate-Mode DAG / Thread Context:
        Called by: _send_llama, send
        Calls: ollama_chat

    SSDL:
        `[I:ollama_chat] -> [T:Result]`

    Thread Boundaries:
        Runs synchronously in the caller thread; synchronizes history using _llama_history_lock.
    """
    try:
        # The native API is served from the server root, not the OpenAI-compatible /v1
        # prefix. Only strip a trailing "/v1" (tolerating a trailing slash) so other URL
        # parts are left alone.
        base_url = _llama_base_url.rstrip("/")
        if base_url.endswith("/v1"):
            base_url = base_url[:-len("/v1")]
        with _llama_history_lock:
            if discussion_history and not _llama_history:
                _llama_history.append({"role": "user", "content": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"})
            else:
                _llama_history.append({"role": "user", "content": user_message})
            messages: list[dict[str, Any]] = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}]
            messages.extend(_llama_history)
        images: list[str] = [
            fi["base64_data"]
            for fi in (file_items or [])
            if fi.get("is_image") and fi.get("base64_data")
        ]
        response = ollama_chat(_model, messages, images=images, base_url=base_url)
        # Ollama reports failures in the JSON body; do not treat them as an empty
        # success or record an empty assistant turn.
        error_text = response.get("error")
        if error_text:
            return Result(data="", errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"ollama error: {error_text}", source="ai_client.llama_native")])
        message = response.get("message") or {}
        text = message.get("content") or ""
        thinking = message.get("thinking") or ""
        with _llama_history_lock:
            msg: dict[str, Any] = {"role": "assistant", "content": text or None}
            if thinking:
                msg["thinking"] = thinking
            _llama_history.append(msg)
        return Result(data=(f"<thinking>\n{thinking}\n</thinking>\n" if thinking else "") + text)
    except Exception as exc:
        return Result(data="", errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=str(exc), source="ai_client.llama_native", original=exc)])
|
|
|
|
def _list_llama_models() -> list[str]:
    """Return the model names that src.vendor_capabilities lists for the "llama" vendor."""
    from src import vendor_capabilities
    models = vendor_capabilities.list_models_for_vendor("llama")
    return models
|
|
|
|
def _get_llama_cost_tracking() -> bool:
    """Report whether cost tracking applies to the active Llama model.

    Local endpoints ("localhost" / "127.0.0.1" in _llama_base_url) are never cost-tracked.
    Otherwise this defers to the capabilities entry for ("llama", _model). A model without
    an entry (KeyError) defaults to cost tracking enabled.
    """
    is_local = any(host in _llama_base_url for host in ("localhost", "127.0.0.1"))
    if is_local:
        return False
    from src.vendor_capabilities import get_capabilities
    try:
        return get_capabilities("llama", _model).cost_tracking
    except KeyError:
        return True
|
|
|
|
#endregion: Llama Provider
|
|
|
|
#region: Tier 4 Analysis
|
|
|
|
def _run_tier4_analysis_result(stderr: str) -> Result[str]:
    """Tier 4 QA agent: summarize a failing command's stderr and suggest a fix (~20 words).

    Empty or whitespace-only stderr, or a Gemini client that is still unavailable after
    _ensure_gemini_client(), yields Result(data="") with no errors.
    Any exception while preparing or making the request yields
    Result(data="", errors=[ErrorInfo(kind=INTERNAL, ...)]).
    The legacy wrapper run_tier4_analysis returns only .data, so failures reach it as an
    empty string (keeping the qa_callback contract).
    """
    if not (stderr and stderr.strip()):
        return Result(data="")
    try:
        _ensure_gemini_client()
        if not _gemini_client:
            return Result(data="")
        genai_types = _require_warmed("google.genai").types
        qa_prompt = (
            "You are a Tier 4 QA Agent specializing in error analysis.\n"
            "Analyze the following stderr output from a PowerShell command:\n\n"
            f"```\n{stderr}\n```\n\n"
            "Provide a concise summary of the failure and suggest a fix in approximately 20 words."
        )
        gen_config = genai_types.GenerateContentConfig(temperature=0.0, max_output_tokens=150)
        reply = _gemini_client.models.generate_content(
            model="gemini-2.5-flash-lite",
            contents=qa_prompt,
            config=gen_config,
        )
        summary = reply.text.strip() if reply.text else ""
        return Result(data=summary)
    except Exception as exc:
        failure = ErrorInfo(
            kind=ErrorKind.INTERNAL,
            message=f"[QA ANALYSIS FAILED] {exc}",
            source="ai_client._run_tier4_analysis_result",
            original=exc,
        )
        return Result(data="", errors=[failure])
|
|
|
|
|
|
def run_tier4_analysis(stderr: str) -> str:
    """Legacy str-returning wrapper around _run_tier4_analysis_result (errors become "")."""
    outcome = _run_tier4_analysis_result(stderr)
    return outcome.data
|
|
|
|
#endregion: Tier 4 Analysis
|
|
|
|
#region: Session & Public API
|
|
|
|
def _run_tier4_patch_callback_result(stderr: str, base_dir: str) -> Result[Optional[str]]:
    """Tier 4 QA agent: propose a unified-diff patch for the stderr.

    Context is the first 5 items from project_manager.get_current_file_items(), each with
    its content truncated to 2000 characters.
    The result from run_tier4_patch_generation is accepted only if it contains both "---"
    and "+++" markers:
      - valid diff: Result(data=patch)
      - no valid diff: Result(data=None)
      - any exception: Result(data=None, errors=[ErrorInfo])
    base_dir is not used.
    The legacy wrapper run_tier4_patch_callback returns only .data (Optional[str]).
    """
    try:
        snippets = [
            f"\n\nFile: {entry.get('path', '')}\n```\n{entry.get('content', '')[:2000]}\n```\n"
            for entry in project_manager.get_current_file_items()[:5]
        ]
        candidate = run_tier4_patch_generation(stderr, "".join(snippets))
        looks_like_diff = bool(candidate) and "---" in candidate and "+++" in candidate
        return Result(data=candidate) if looks_like_diff else Result(data=None)
    except Exception as exc:
        failure = ErrorInfo(
            kind=ErrorKind.INTERNAL,
            message=f"tier4 patch callback failed: {exc}",
            source="ai_client._run_tier4_patch_callback_result",
            original=exc,
        )
        return Result(data=None, errors=[failure])
|
|
|
|
|
|
def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]:
    """Legacy wrapper: the proposed diff, or None when none was produced or it failed."""
    outcome = _run_tier4_patch_callback_result(stderr, base_dir)
    return outcome.data
|
|
|
|
def _run_tier4_patch_generation_result(error: str, file_context: str) -> Result[str]:
    """Tier 4 QA agent: ask Gemini for a unified-diff patch fixing `error`.

    The prompt is mma_prompts.TIER4_PATCH_PROMPT followed by the error and the file context.
    Outcomes:
      - empty or whitespace-only error: Result(data="")
      - Gemini client unavailable: Result(data="")
      - success: Result(data=stripped model text)
      - any exception: Result(data="", errors=[ErrorInfo])
    The legacy wrapper run_tier4_patch_generation returns only .data, so failures reach it
    as an empty string.
    """
    if not (error and error.strip()):
        return Result(data="")
    try:
        _ensure_gemini_client()
        if not _gemini_client:
            return Result(data="")
        genai_types = _require_warmed("google.genai").types
        patch_prompt = "".join((
            f"{mma_prompts.TIER4_PATCH_PROMPT}\n\n",
            f"Error:\n```\n{error}\n```\n\n",
            f"File Context:\n```\n{file_context}\n```\n",
        ))
        gen_config = genai_types.GenerateContentConfig(temperature=0.0, max_output_tokens=2048)
        reply = _gemini_client.models.generate_content(
            model="gemini-2.5-flash-lite",
            contents=patch_prompt,
            config=gen_config,
        )
        diff_text = reply.text.strip() if reply.text else ""
        return Result(data=diff_text)
    except Exception as exc:
        failure = ErrorInfo(
            kind=ErrorKind.INTERNAL,
            message=f"[PATCH GENERATION FAILED] {exc}",
            source="ai_client._run_tier4_patch_generation_result",
            original=exc,
        )
        return Result(data="", errors=[failure])
|
|
|
|
|
|
def run_tier4_patch_generation(error: str, file_context: str) -> str:
    """
    Legacy str-returning wrapper around _run_tier4_patch_generation_result ("" on failure).

    [C: src/gui_2.py:App.request_patch_from_tier4, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_calls_ai, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_empty_error, tests/test_tier4_patch_generation.py:test_run_tier4_patch_generation_returns_diff]
    """
    outcome = _run_tier4_patch_generation_result(error, file_context)
    return outcome.data
|
|
|
|
def _count_gemini_tokens_for_stats_result(md_content: str) -> Result[int]:
    """Count tokens via Gemini SDK for the token-stats panel.

    Returns:
      - Result(data=token_count) on success (a missing/None count is reported as 0).
      - Result(data=0) when no Gemini client is available.
      - Result(data=0, errors=[ErrorInfo]) on SDK or client-initialization failure.
    The legacy caller (get_token_stats) treats 0 as "token count unavailable" and falls back
    to character-based estimation.
    """
    try:
        # Client initialization is inside the try so warmup failures become an error Result
        # (as documented) instead of escaping into get_token_stats.
        if _gemini_client is None:
            _ensure_gemini_client()
        if _gemini_client is None:
            return Result(data=0)
        resp = _gemini_client.models.count_tokens(model=_model, contents=md_content)
        # total_tokens is Optional in the SDK response; None would break the caller's
        # percentage math, so normalize it to 0 (which triggers the caller's fallback).
        return Result(data=int(resp.total_tokens or 0))
    except Exception as e:
        return Result(
            data=0,
            errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=f"failed to count gemini tokens for stats: {e}", source="ai_client._count_gemini_tokens_for_stats_result", original=e)],
        )
|
|
|
|
|
|
def get_token_stats(md_content: str) -> dict[str, Any]:
    """
    Estimate prompt-token usage of `md_content` against the active provider's limit.

    Gemini providers ("gemini", "gemini_cli") use the SDK token count. Everything else, or
    a Gemini count of 0, falls back to len(md_content) / _CHARS_PER_TOKEN (at least 1).
    The limit is:
      - 64000 for deepseek
      - _GEMINI_MAX_INPUT_TOKENS for Gemini providers
      - _ANTHROPIC_MAX_PROMPT_TOKENS otherwise
    The returned dict is augmented by _add_bleed_derived, with the whole estimate counted
    as system tokens.

    [C: src/app_controller.py:AppController._refresh_api_metrics]
    """
    provider = str(_provider).lower().strip()
    is_gemini = provider in ("gemini", "gemini_cli")
    token_count = _count_gemini_tokens_for_stats_result(md_content).data if is_gemini else 0
    if token_count == 0:
        token_count = max(1, int(len(md_content) / _CHARS_PER_TOKEN))
    if provider == "deepseek":
        budget = 64000
    elif is_gemini:
        budget = _GEMINI_MAX_INPUT_TOKENS
    else:
        budget = _ANTHROPIC_MAX_PROMPT_TOKENS
    usage_pct = (token_count / budget * 100) if budget > 0 else 0
    return _add_bleed_derived(
        {
            "total_tokens": token_count,
            "current": token_count,
            "limit": budget,
            "percentage": usage_pct,
        },
        sys_tok=token_count,
    )
|
|
|
|
def send(
    md_content: str,
    user_message: str,
    base_dir: str = ".",
    file_items: list[dict[str, Any]] | None = None,
    discussion_history: str = "",
    stream: bool = False,
    pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
    qa_callback: Optional[Callable[[str], str]] = None,
    enable_tools: bool = True,
    stream_callback: Optional[Callable[[str], None]] = None,
    patch_callback: Optional[Callable[[str, str], Optional[str]]] = None,
    rag_engine: Optional[Any] = None,
) -> Result[str]:
    """
    Sends a prompt to the currently configured AI provider, returning a comprehensive Result object.

    Functional Purpose:
        This is the primary public entry point for AI communication.
        When a RAG engine is supplied and enabled, it searches the vector index and prepends
        the retrieved chunks to the user message. This is skipped if the message already
        contains a "## Retrieved Context" section.
        It then logs the outgoing request to the communications logger, acquires the global
        _send_lock, and routes the request to the vendor-specific handler for the active
        provider.
        Exceptions raised while preparing (RAG search, comms logging) or dispatching the
        request are returned as ErrorInfo objects (kind INTERNAL, source "ai_client.send").
        The performance-monitor component is always closed, even on early return.

    Parameters & Inputs:
        md_content (str): System prompt template or markdown prompt structure.
        user_message (str): The primary user instruction.
        base_dir (str): Base workspace directory path (defaults to ".").
        file_items (list[dict[str, Any]] | None): Optional list of active context files.
        discussion_history (str): Contextual discussion history lines.
        stream (bool): Whether to stream the response chunks.
        pre_tool_callback (Optional[Callable]): Hook called before executing tool calls.
        qa_callback (Optional[Callable]): Hook for Tier 4 quality/validation checks.
        enable_tools (bool): Controls whether LLM tool usage is enabled (passed to the Gemini handler).
        stream_callback (Optional[Callable]): Hook to stream response chunks to.
        patch_callback (Optional[Callable]): Custom callback for interactive patch validation.
        rag_engine (Optional[Any]): RAG search engine instance to fetch vector context.

    Returns:
        Result[str]: Container holding the successful response string or error details.
        An unknown provider yields a CONFIG error.

    Immediate-Mode DAG / Thread Context:
        Called by: public API callers (see the [C:] tags below for the covering tests).
        Calls: performance_monitor, rag_engine.search, _append_comms, _send_gemini,
        _send_gemini_cli, _send_anthropic, _send_deepseek, _send_minimax,
        _send_qwen, _send_llama, _send_grok, _send_llama_native

    SSDL:
        `[Q:active_provider] -> [I:SetupTierTag] -> [I:DispatchProvider] -> [T:Result]`

    Thread Boundaries:
        Acquires the global _send_lock to synchronize provider calls. Safely called from any worker
        thread executing background tasks, preventing concurrent thread collisions on shared provider SDK states.

    [C: tests/test_ai_client_result.py:test_send_public_api_returns_result, tests/test_ai_client_result.py:test_send_preserves_errors, tests/test_deprecation_warnings.py:test_send_does_not_emit_deprecation]
    """
    monitor = performance_monitor.get_monitor()
    if monitor.enabled:
        monitor.start_component("ai_client.send")
    try:
        try:
            if rag_engine and getattr(rag_engine.config, "enabled", False) and "## Retrieved Context" not in user_message:
                chunks = rag_engine.search(user_message)
                if chunks:
                    parts = ["## Retrieved Context\n\n"]
                    for i, chunk in enumerate(chunks):
                        path = chunk.get("metadata", {}).get("path", "unknown")
                        parts.append(f"### Chunk {i+1} (Source: {path})\n{chunk.get('document', '')}\n\n")
                    user_message = "".join(parts) + user_message
            _append_comms("OUT", "request", {"message": user_message, "system": _get_combined_system_prompt(_active_tool_preset, _active_bias_profile)})
        except Exception as exc:
            # Honour the "errors are returned, not raised" contract for the preparation
            # stage too (previously these exceptions escaped to the caller).
            return Result(data="", errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=str(exc), source="ai_client.send", original=exc)])

        with _send_lock:
            p = str(_provider).lower().strip()
            try:
                if p == "gemini":
                    res = _send_gemini(
                        md_content, user_message, base_dir, file_items, discussion_history,
                        pre_tool_callback, qa_callback, enable_tools, stream_callback, patch_callback
                    )
                elif p == "gemini_cli":
                    res = _send_gemini_cli(
                        md_content, user_message, base_dir, file_items, discussion_history,
                        pre_tool_callback, qa_callback, stream_callback, patch_callback
                    )
                elif p == "anthropic":
                    res = _send_anthropic(
                        md_content, user_message, base_dir, file_items, discussion_history,
                        pre_tool_callback, qa_callback, stream_callback=stream_callback, patch_callback=patch_callback
                    )
                elif p == "deepseek":
                    res = _send_deepseek(
                        md_content, user_message, base_dir, file_items, discussion_history,
                        stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
                    )
                elif p == "minimax":
                    res = _send_minimax(
                        md_content, user_message, base_dir, file_items, discussion_history,
                        stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
                    )
                elif p == "qwen":
                    res = _send_qwen(
                        md_content, user_message, base_dir, file_items, discussion_history,
                        stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
                    )
                elif p == "llama":
                    res = _send_llama(
                        md_content, user_message, base_dir, file_items, discussion_history,
                        stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
                    )
                elif p == "grok":
                    res = _send_grok(
                        md_content, user_message, base_dir, file_items, discussion_history,
                        stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
                    )
                elif p == "llama_native":
                    res = _send_llama_native(
                        md_content, user_message, base_dir, file_items, discussion_history,
                        stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
                    )
                else:
                    res = Result(data="", errors=[ErrorInfo(kind=ErrorKind.CONFIG, message=f"unknown provider: {_provider}", source="ai_client.send")])
            except Exception as exc:
                res = Result(data="", errors=[ErrorInfo(kind=ErrorKind.INTERNAL, message=str(exc), source="ai_client.send", original=exc)])
        return res
    finally:
        # Always close the monitor component, including on early-return paths.
        if monitor.enabled:
            monitor.end_component("ai_client.send")
|
|
|
|
def _add_bleed_derived(d: dict[str, Any], sys_tok: int = 0, tool_tok: int = 0) -> dict[str, Any]:
|
|
"""
|
|
[C: tests/test_token_viz.py:test_add_bleed_derived_aliases, tests/test_token_viz.py:test_add_bleed_derived_breakdown, tests/test_token_viz.py:test_add_bleed_derived_headroom, tests/test_token_viz.py:test_add_bleed_derived_headroom_clamped_to_zero, tests/test_token_viz.py:test_add_bleed_derived_history_clamped_to_zero, tests/test_token_viz.py:test_add_bleed_derived_would_trim_false, tests/test_token_viz.py:test_add_bleed_derived_would_trim_true, tests/test_token_viz.py:test_would_trim_boundary_exact, tests/test_token_viz.py:test_would_trim_just_above_threshold, tests/test_token_viz.py:test_would_trim_just_below_threshold]
|
|
"""
|
|
cur = d.get("current", 0)
|
|
lim = d.get("limit", 0)
|
|
d["estimated_prompt_tokens"] = cur
|
|
d["max_prompt_tokens"] = lim
|
|
d["utilization_pct"] = d.get("percentage", 0.0)
|
|
d["headroom"] = max(0, lim - cur)
|
|
d["would_trim"] = cur >= lim
|
|
d["sys_tokens"] = sys_tok
|
|
d["tool_tokens"] = tool_tok
|
|
d["history_tokens"] = max(0, cur - sys_tok - tool_tok)
|
|
return d
|
|
|
|
# Check for tool preset in environment variable (headless mode)
|
|
if os.environ.get("SLOP_TOOL_PRESET"):
|
|
_set_tool_preset_result(os.environ["SLOP_TOOL_PRESET"])
|
|
|
|
#endregion: Session & Public API
|
|
|
|
#region: Subagent Summarization
|
|
|
|
def run_subagent_summarization(file_path: str, content: str, is_code: bool, outline: str) -> str:
    """
    Summarize one file with the active provider (Tier 4 sub-agent) and return the text.

    The prompt is mma_prompts.TIER4_SUMMARIZE_CODE_PROMPT (is_code=True) or
    TIER4_SUMMARIZE_TEXT_PROMPT, formatted with file_path / outline / content.
    Dispatch is on the active provider, normalized with str().lower().strip() as send() does:
      - gemini / anthropic: SDK call capped at 1024 output tokens; SDK exceptions propagate.
      - deepseek: direct HTTPS call; failures are returned as "ERROR: ..." strings.
      - gemini_cli: one-off GeminiCliAdapter call.
    Heavy modules (google.genai, requests) are required only inside the branch that uses
    them, so one provider's branch does not depend on another SDK being warmed.
    An unavailable SDK client, a missing API key, or an unsupported provider yields an
    "ERROR: ..." string.

    [C: src/summarize.py:summarise_file, tests/test_subagent_summarization.py:test_run_subagent_summarization_anthropic, tests/test_subagent_summarization.py:test_run_subagent_summarization_gemini]
    """
    prompt_tmpl = mma_prompts.TIER4_SUMMARIZE_CODE_PROMPT if is_code else mma_prompts.TIER4_SUMMARIZE_TEXT_PROMPT
    prompt = prompt_tmpl.format(file_path=file_path, outline=outline, content=content)
    p = str(_provider).lower().strip()

    if p == "gemini":
        types = _require_warmed("google.genai").types
        _ensure_gemini_client()
        if not _gemini_client:
            return "ERROR: Gemini client unavailable for sub-agent summarization"
        resp = _gemini_client.models.generate_content(
            model=_model,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=0.0,
                max_output_tokens=1024,
            )
        )
        return resp.text or ""

    if p == "anthropic":
        _ensure_anthropic_client()
        if not _anthropic_client:
            return "ERROR: Anthropic client unavailable for sub-agent summarization"
        resp = _anthropic_client.messages.create(
            model=_model,
            max_tokens=1024,
            messages=[{"role": "user", "content": prompt}]
        )
        return "".join(b.text for b in resp.content if hasattr(b, "text") and b.text)

    if p == "deepseek":
        creds = _load_credentials()
        api_key = creds.get("deepseek", {}).get("api_key")
        if not api_key:
            return "ERROR: DeepSeek API key missing"
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        payload = {
            "model": _model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.0,
        }
        try:
            requests = _require_warmed("requests")
            r = requests.post("https://api.deepseek.com/chat/completions", headers=headers, json=payload, timeout=60)
            r.raise_for_status()
            return r.json()["choices"][0]["message"]["content"]
        except Exception as e:
            return f"ERROR: DeepSeek summarization failed: {e}"

    if p == "gemini_cli":
        # Using the adapter for a one-off call
        adapter = GeminiCliAdapter(binary_path="gemini")
        resp_data = adapter.send(prompt, model=_model)
        return resp_data.get("text", "")

    return "ERROR: Unsupported provider for sub-agent summarization"
|
|
|
|
def run_discussion_compression(discussion_text: str) -> str:
    """
    Ask the active provider for a dense summary of a long discussion history.

    The provider comes from get_provider(), normalized with str().lower().strip().
    Per provider:
      - gemini / anthropic / minimax: SDK call capped at 2048 output tokens; SDK
        exceptions propagate.
      - deepseek: direct HTTPS call; failures are returned as "ERROR: ..." strings.
      - gemini_cli: one-off GeminiCliAdapter call.
    Heavy modules (google.genai, requests) are required only by the branch that uses them.
    An unavailable SDK client, a missing API key, or an unsupported provider yields an
    "ERROR: ..." string.
    """
    # Robustly identify the provider string (handles case and whitespace)
    p = str(get_provider()).lower().strip()
    prompt = f"The following is a long conversation history.\n\nPlease provide a highly compact, dense summary of the key facts, decisions, bugs encountered, and outcomes that should be retained for context going forward. Categorize into User intent, Tool outputs, and AI reasoning. Omit pleasantries and redundant thoughts.\n\n[HISTORY]\n{discussion_text}"

    if p == "gemini":
        types = _require_warmed("google.genai").types
        _ensure_gemini_client()
        if not _gemini_client:
            return "ERROR: Gemini client unavailable for discussion compression"
        resp = _gemini_client.models.generate_content(
            model=_model,
            contents=prompt,
            config=types.GenerateContentConfig(temperature=0.0, max_output_tokens=2048)
        )
        return resp.text or ""

    if p == "anthropic":
        _ensure_anthropic_client()
        if not _anthropic_client:
            return "ERROR: Anthropic client unavailable for discussion compression"
        resp = _anthropic_client.messages.create(
            model=_model, max_tokens=2048,
            messages=[{"role": "user", "content": prompt}]
        )
        return "".join(b.text for b in resp.content if hasattr(b, "text") and b.text)

    if p == "deepseek":
        creds = _load_credentials()
        api_key = creds.get("deepseek", {}).get("api_key")
        if not api_key:
            return "ERROR: DeepSeek API key missing"
        try:
            requests = _require_warmed("requests")
            r = requests.post("https://api.deepseek.com/chat/completions", headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}, json={"model": _model, "messages": [{"role": "user", "content": prompt}], "temperature": 0.0}, timeout=60)
            r.raise_for_status()
            return r.json()["choices"][0]["message"]["content"]
        except Exception as e:
            return f"ERROR: DeepSeek compression failed: {e}"

    if p == "minimax":
        _ensure_minimax_client()
        if not _minimax_client:
            return "ERROR: MiniMax client unavailable for discussion compression"
        resp = _minimax_client.chat.completions.create(
            model=_model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.0,
            max_tokens=2048
        )
        return resp.choices[0].message.content or ""

    if p == "gemini_cli":
        adapter = GeminiCliAdapter(binary_path="gemini")
        resp_data = adapter.send(prompt, model=_model)
        return resp_data.get("text", "")

    return f"ERROR: Unsupported provider for discussion compression: '{p}'"
|
|
|
|
#endregion: Subagent Summarization
|