From 29a96cc9f5742d83513f89338acc5f1bb724950a Mon Sep 17 00:00:00 2001 From: Ed_ Date: Thu, 11 Jun 2026 01:56:21 -0400 Subject: [PATCH] feat(ai_client): Add Grok (xAI) OpenAI-compatible provider --- src/ai_client.py | 139 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) diff --git a/src/ai_client.py b/src/ai_client.py index 08c5f9d5..b5b86e36 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -136,6 +136,16 @@ _qwen_history: list[dict[str, Any]] = [] _qwen_history_lock: threading.Lock = threading.Lock() _qwen_region: str = "china" +_grok_client: Any = None +_grok_history: list[dict[str, Any]] = [] +_grok_history_lock: threading.Lock = threading.Lock() + +_llama_client: Any = None +_llama_history: list[dict[str, Any]] = [] +_llama_history_lock: threading.Lock = threading.Lock() +_llama_base_url: str = "http://localhost:11434/v1" +_llama_api_key: str = "ollama" + _send_lock: threading.Lock = threading.Lock() _BIAS_ENGINE = ToolBiasEngine() @@ -522,6 +532,14 @@ def reset_session() -> None: _qwen_client = None with _qwen_history_lock: _qwen_history = [] + _grok_client = None + with _grok_history_lock: + _grok_history = [] + _llama_client = None + with _llama_history_lock: + _llama_history = [] + _llama_base_url = "http://localhost:11434/v1" + _llama_api_key = "ollama" _CACHED_ANTHROPIC_TOOLS = None _CACHED_DEEPSEEK_TOOLS = None file_cache.reset_client() @@ -537,6 +555,8 @@ def list_models(provider: str) -> list[str]: elif provider == "gemini_cli": return _list_gemini_cli_models() elif provider == "minimax": return _list_minimax_models(creds["minimax"]["api_key"]) elif provider == "qwen": return _list_qwen_models() + elif provider == "grok": return _list_grok_models() + elif provider == "llama": return _list_llama_models() return [] #endregion: Comms Log @@ -2150,6 +2170,58 @@ def _ensure_minimax_client() -> None: raise ValueError("MiniMax API key not found in credentials.toml") _minimax_client = OpenAI(api_key=api_key, base_url="https://api.minimax.chat/v1") +def _ensure_grok_client() -> Any: + global _grok_client + if _grok_client is None: + openai = _require_warmed("openai") + creds = _load_credentials() + api_key = creds.get("grok", {}).get("api_key") + if not api_key: + raise ValueError("Grok API key not found in credentials.toml") + _grok_client = openai.OpenAI(api_key=api_key, base_url="https://api.x.ai/v1") + return _grok_client + +def _send_grok(md_content: str, user_message: str, base_dir: str, + file_items: list[dict[str, Any]] | None = None, + discussion_history: str = "", + stream: bool = False, + pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, + qa_callback: Optional[Callable[[str], str]] = None, + stream_callback: Optional[Callable[[str], None]] = None, + patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: + client = _ensure_grok_client() + from src.openai_compatible import OpenAICompatibleRequest, send_openai_compatible + from src.vendor_capabilities import get_capabilities + with _grok_history_lock: + user_content = user_message + if file_items: + for fi in file_items: + if fi.get("is_image") and fi.get("base64_data"): + user_content = f"[IMAGE: {fi.get('path', 'attachment')}]\n{user_content}" + if discussion_history and not _grok_history: + _grok_history.append({"role": "user", "content": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"}) + else: + _grok_history.append({"role": "user", "content": user_content}) + messages = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n\n{md_content}\n"}] + messages.extend(_grok_history) + request = OpenAICompatibleRequest( + messages=messages, + model=_model, + temperature=_temperature, + top_p=_top_p, + max_tokens=_max_tokens, + stream=stream, + stream_callback=stream_callback, + ) + caps = get_capabilities("grok", _model) + response = send_openai_compatible(client, request, capabilities=caps) + _grok_history.append({"role": "assistant", "content": response.text}) + return response.text + +def _list_grok_models() -> list[str]: + from src.vendor_capabilities import list_models_for_vendor + return list_models_for_vendor("grok") + def _send_minimax(md_content: str, user_message: str, base_dir: str, file_items: list[dict[str, Any]] | None = None, discussion_history: str = "", @@ -2487,6 +2559,73 @@ def _send_qwen(md_content: str, user_message: str, base_dir: str, #endregion: Qwen Provider +def _ensure_llama_client() -> Any: + global _llama_client, _llama_base_url, _llama_api_key + if _llama_client is None: + openai = _require_warmed("openai") + creds = _load_credentials() + configured_url = creds.get("llama", {}).get("base_url") + configured_key = creds.get("llama", {}).get("api_key") + if configured_url: + _llama_base_url = configured_url + if configured_key is not None: + _llama_api_key = configured_key or "ollama" + _llama_client = openai.OpenAI(api_key=_llama_api_key, base_url=_llama_base_url) + return _llama_client + +def _send_llama(md_content: str, user_message: str, base_dir: str, + file_items: list[dict[str, Any]] | None = None, + discussion_history: str = "", + stream: bool = False, + pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, + qa_callback: Optional[Callable[[str], str]] = None, + stream_callback: Optional[Callable[[str], None]] = None, + patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: + client = _ensure_llama_client() + from src.openai_compatible import OpenAICompatibleRequest, send_openai_compatible + from src.vendor_capabilities import get_capabilities + with _llama_history_lock: + user_content = user_message + if file_items: + for fi in file_items: + if fi.get("is_image") and fi.get("base64_data"): + user_content = f"[IMAGE: {fi.get('path', 'attachment')}]\n{user_content}" + if discussion_history and not _llama_history: + _llama_history.append({"role": "user", "content": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"}) + else: + _llama_history.append({"role": "user", "content": user_content}) + messages = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n\n{md_content}\n"}] + messages.extend(_llama_history) + request = OpenAICompatibleRequest( + messages=messages, + model=_model, + temperature=_temperature, + top_p=_top_p, + max_tokens=_max_tokens, + stream=stream, + stream_callback=stream_callback, + ) + caps = get_capabilities("llama", _model) + response = send_openai_compatible(client, request, capabilities=caps) + _llama_history.append({"role": "assistant", "content": response.text}) + return response.text + +def _list_llama_models() -> list[str]: + from src.vendor_capabilities import list_models_for_vendor + return list_models_for_vendor("llama") + +def _get_llama_cost_tracking() -> bool: + if "localhost" in _llama_base_url or "127.0.0.1" in _llama_base_url: + return False + from src.vendor_capabilities import get_capabilities + try: + caps = get_capabilities("llama", _model) + return caps.cost_tracking + except KeyError: + return True + +#endregion: Llama Provider + #region: Tier 4 Analysis def run_tier4_analysis(stderr: str) -> str: