Private
Public Access
0
0

feat(ai_client): Add Grok (xAI) OpenAI-compatible provider

This commit is contained in:
2026-06-11 01:56:21 -04:00
parent 06716252f1
commit 29a96cc9f5
+139
View File
@@ -136,6 +136,16 @@ _qwen_history: list[dict[str, Any]] = []
_qwen_history_lock: threading.Lock = threading.Lock()
_qwen_region: str = "china"
# --- Grok (xAI) provider state ---
# Client object created and cached by _ensure_grok_client(); None until then.
_grok_client: Any = None
# Running list of chat messages for Grok; _send_grok() appends to it while
# holding _grok_history_lock.
_grok_history: list[dict[str, Any]] = []
_grok_history_lock: threading.Lock = threading.Lock()
# --- Llama provider state ---
# Client object created and cached by _ensure_llama_client(); None until then.
_llama_client: Any = None
# Running list of chat messages for Llama; _send_llama() appends to it while
# holding _llama_history_lock.
_llama_history: list[dict[str, Any]] = []
_llama_history_lock: threading.Lock = threading.Lock()
# Default endpoint and API key. _ensure_llama_client() replaces them with the
# "base_url" / "api_key" values from the "llama" credentials section when set.
_llama_base_url: str = "http://localhost:11434/v1"
_llama_api_key: str = "ollama"
_send_lock: threading.Lock = threading.Lock()
_BIAS_ENGINE = ToolBiasEngine()
@@ -522,6 +532,14 @@ def reset_session() -> None:
_qwen_client = None
with _qwen_history_lock:
_qwen_history = []
_grok_client = None
with _grok_history_lock:
_grok_history = []
_llama_client = None
with _llama_history_lock:
_llama_history = []
_llama_base_url = "http://localhost:11434/v1"
_llama_api_key = "ollama"
_CACHED_ANTHROPIC_TOOLS = None
_CACHED_DEEPSEEK_TOOLS = None
file_cache.reset_client()
@@ -537,6 +555,8 @@ def list_models(provider: str) -> list[str]:
elif provider == "gemini_cli": return _list_gemini_cli_models()
elif provider == "minimax": return _list_minimax_models(creds["minimax"]["api_key"])
elif provider == "qwen": return _list_qwen_models()
elif provider == "grok": return _list_grok_models()
elif provider == "llama": return _list_llama_models()
return []
#endregion: Comms Log
@@ -2150,6 +2170,58 @@ def _ensure_minimax_client() -> None:
raise ValueError("MiniMax API key not found in credentials.toml")
_minimax_client = OpenAI(api_key=api_key, base_url="https://api.minimax.chat/v1")
def _ensure_grok_client() -> Any:
    """Return the cached Grok client, constructing it on the first call.

    The client is an ``openai.OpenAI`` instance pointed at the xAI endpoint
    (``https://api.x.ai/v1``) and is stored in the module-level
    ``_grok_client`` so later calls reuse it.

    Returns:
        The cached client object.

    Raises:
        ValueError: If the loaded credentials have no non-empty
            ``grok.api_key`` entry.
    """
    global _grok_client
    # Fast path: a client has already been built.
    if _grok_client is not None:
        return _grok_client
    openai_module = _require_warmed("openai")
    grok_section = _load_credentials().get("grok", {})
    key = grok_section.get("api_key")
    if not key:
        raise ValueError("Grok API key not found in credentials.toml")
    _grok_client = openai_module.OpenAI(api_key=key, base_url="https://api.x.ai/v1")
    return _grok_client
def _send_grok(md_content: str, user_message: str, base_dir: str,
               file_items: list[dict[str, Any]] | None = None,
               discussion_history: str = "",
               stream: bool = False,
               pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
               qa_callback: Optional[Callable[[str], str]] = None,
               stream_callback: Optional[Callable[[str], None]] = None,
               patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Send one user turn to Grok and return the reply text.

    The user turn is appended to the module-level ``_grok_history``, the full
    history is sent behind a system message (combined system prompt plus
    ``md_content`` wrapped in ``<context>`` tags), and the reply is appended
    to the history as the assistant turn.

    Args:
        md_content: Context text embedded in the system message.
        user_message: The user's message for this turn.
        base_dir: Accepted for signature parity; not used by this function.
        file_items: Optional attachments. Each item with truthy ``is_image``
            and ``base64_data`` adds an ``[IMAGE: <path>]`` marker line in
            front of the user message (the image bytes are not sent).
        discussion_history: Prior discussion text. It is prepended to the
            user turn only when the Grok history is still empty.
        stream: Forwarded to the request.
        pre_tool_callback: Accepted for signature parity; not used here.
        qa_callback: Accepted for signature parity; not used here.
        stream_callback: Forwarded to the request.
        patch_callback: Accepted for signature parity; not used here.

    Returns:
        ``response.text`` from ``send_openai_compatible``.

    Raises:
        ValueError: Propagated from ``_ensure_grok_client`` when no API key
            is configured.
    """
    client = _ensure_grok_client()
    from src.openai_compatible import OpenAICompatibleRequest, send_openai_compatible
    from src.vendor_capabilities import get_capabilities
    with _grok_history_lock:
        user_content = user_message
        for fi in file_items or ():
            if fi.get("is_image") and fi.get("base64_data"):
                user_content = f"[IMAGE: {fi.get('path', 'attachment')}]\n{user_content}"
        if discussion_history and not _grok_history:
            # Use user_content (not the raw user_message) so image markers
            # are kept on the opening turn as well.
            _grok_history.append({"role": "user", "content": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_content}"})
        else:
            _grok_history.append({"role": "user", "content": user_content})
        messages = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}]
        messages.extend(_grok_history)
        request = OpenAICompatibleRequest(
            messages=messages,
            model=_model,
            temperature=_temperature,
            top_p=_top_p,
            max_tokens=_max_tokens,
            stream=stream,
            stream_callback=stream_callback,
        )
        caps = get_capabilities("grok", _model)
        # NOTE(review): if this call raises, the user turn appended above
        # stays in the history without a matching assistant turn.
        response = send_openai_compatible(client, request, capabilities=caps)
        _grok_history.append({"role": "assistant", "content": response.text})
        return response.text
def _list_grok_models() -> list[str]:
    """Return the model names registered for the ``"grok"`` vendor."""
    # Imported lazily, as in the original, so the module is only loaded on use.
    from src.vendor_capabilities import list_models_for_vendor
    grok_models = list_models_for_vendor("grok")
    return grok_models
def _send_minimax(md_content: str, user_message: str, base_dir: str,
file_items: list[dict[str, Any]] | None = None,
discussion_history: str = "",
@@ -2487,6 +2559,73 @@ def _send_qwen(md_content: str, user_message: str, base_dir: str,
#endregion: Qwen Provider
def _ensure_llama_client() -> Any:
    """Return the cached Llama client, constructing it on the first call.

    On first use the ``llama`` credentials section may override the
    module-level defaults: a truthy ``base_url`` replaces ``_llama_base_url``,
    and a present ``api_key`` replaces ``_llama_api_key`` (an empty key falls
    back to ``"ollama"``). The resulting ``openai.OpenAI`` instance is stored
    in ``_llama_client`` and reused afterwards.

    Returns:
        The cached client object.
    """
    global _llama_client, _llama_base_url, _llama_api_key
    # Fast path: a client has already been built.
    if _llama_client is not None:
        return _llama_client
    openai_module = _require_warmed("openai")
    llama_section = _load_credentials().get("llama", {})
    url_override = llama_section.get("base_url")
    key_override = llama_section.get("api_key")
    if url_override:
        _llama_base_url = url_override
    if key_override is not None:
        # A key that is present but empty still maps to the placeholder.
        _llama_api_key = key_override if key_override else "ollama"
    _llama_client = openai_module.OpenAI(api_key=_llama_api_key, base_url=_llama_base_url)
    return _llama_client
def _send_llama(md_content: str, user_message: str, base_dir: str,
                file_items: list[dict[str, Any]] | None = None,
                discussion_history: str = "",
                stream: bool = False,
                pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                qa_callback: Optional[Callable[[str], str]] = None,
                stream_callback: Optional[Callable[[str], None]] = None,
                patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Send one user turn to the Llama endpoint and return the reply text.

    The user turn is appended to the module-level ``_llama_history``, the
    full history is sent behind a system message (combined system prompt plus
    ``md_content`` wrapped in ``<context>`` tags), and the reply is appended
    to the history as the assistant turn.

    Args:
        md_content: Context text embedded in the system message.
        user_message: The user's message for this turn.
        base_dir: Accepted for signature parity; not used by this function.
        file_items: Optional attachments. Each item with truthy ``is_image``
            and ``base64_data`` adds an ``[IMAGE: <path>]`` marker line in
            front of the user message (the image bytes are not sent).
        discussion_history: Prior discussion text. It is prepended to the
            user turn only when the Llama history is still empty.
        stream: Forwarded to the request.
        pre_tool_callback: Accepted for signature parity; not used here.
        qa_callback: Accepted for signature parity; not used here.
        stream_callback: Forwarded to the request.
        patch_callback: Accepted for signature parity; not used here.

    Returns:
        ``response.text`` from ``send_openai_compatible``.
    """
    client = _ensure_llama_client()
    from src.openai_compatible import OpenAICompatibleRequest, send_openai_compatible
    from src.vendor_capabilities import get_capabilities
    with _llama_history_lock:
        user_content = user_message
        for fi in file_items or ():
            if fi.get("is_image") and fi.get("base64_data"):
                user_content = f"[IMAGE: {fi.get('path', 'attachment')}]\n{user_content}"
        if discussion_history and not _llama_history:
            # Use user_content (not the raw user_message) so image markers
            # are kept on the opening turn as well.
            _llama_history.append({"role": "user", "content": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_content}"})
        else:
            _llama_history.append({"role": "user", "content": user_content})
        messages = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}]
        messages.extend(_llama_history)
        request = OpenAICompatibleRequest(
            messages=messages,
            model=_model,
            temperature=_temperature,
            top_p=_top_p,
            max_tokens=_max_tokens,
            stream=stream,
            stream_callback=stream_callback,
        )
        caps = get_capabilities("llama", _model)
        # NOTE(review): if this call raises, the user turn appended above
        # stays in the history without a matching assistant turn.
        response = send_openai_compatible(client, request, capabilities=caps)
        _llama_history.append({"role": "assistant", "content": response.text})
        return response.text
def _list_llama_models() -> list[str]:
    """Return the model names registered for the ``"llama"`` vendor."""
    # Imported lazily, as in the original, so the module is only loaded on use.
    from src.vendor_capabilities import list_models_for_vendor
    llama_models = list_models_for_vendor("llama")
    return llama_models
def _get_llama_cost_tracking() -> bool:
    """Report whether cost tracking applies to the current Llama setup.

    Returns:
        ``False`` when ``_llama_base_url`` contains ``localhost`` or
        ``127.0.0.1``. Otherwise the ``cost_tracking`` attribute of the
        capabilities looked up for the current model, or ``True`` when that
        lookup raises ``KeyError``.
    """
    loopback_markers = ("localhost", "127.0.0.1")
    if any(marker in _llama_base_url for marker in loopback_markers):
        return False
    from src.vendor_capabilities import get_capabilities
    try:
        # The attribute read stays inside the try so a KeyError raised at
        # either step yields the same fallback.
        return get_capabilities("llama", _model).cost_tracking
    except KeyError:
        return True
#endregion: Llama Provider
#region: Tier 4 Analysis
def run_tier4_analysis(stderr: str) -> str: