feat(ai_client): Add Grok (xAI) OpenAI-compatible provider
This commit is contained in:
@@ -136,6 +136,16 @@ _qwen_history: list[dict[str, Any]] = []
|
||||
_qwen_history_lock: threading.Lock = threading.Lock()
|
||||
_qwen_region: str = "china"
|
||||
|
||||
# Grok (xAI) client handle; stays None until _ensure_grok_client() builds it and is set back to None by reset_session().
_grok_client: Any = None
|
||||
# Grok conversation turns as {"role": ..., "content": ...} dicts, appended by _send_grok() and cleared by reset_session().
_grok_history: list[dict[str, Any]] = []
|
||||
# Held by _send_grok() and reset_session() while _grok_history is read, extended or replaced.
_grok_history_lock: threading.Lock = threading.Lock()
|
||||
|
||||
# Llama provider client handle; stays None until _ensure_llama_client() builds it and is set back to None by reset_session().
_llama_client: Any = None
|
||||
# Llama conversation turns as {"role": ..., "content": ...} dicts, appended by _send_llama() and cleared by reset_session().
_llama_history: list[dict[str, Any]] = []
|
||||
# Held by _send_llama() and reset_session() while _llama_history is read, extended or replaced.
_llama_history_lock: threading.Lock = threading.Lock()
|
||||
# Default endpoint for the llama provider; _ensure_llama_client() replaces it with the credentials' llama.base_url when one is set.
# NOTE(review): the port and the "ollama" default key suggest a local Ollama server - confirm.
_llama_base_url: str = "http://localhost:11434/v1"
|
||||
# Default API key for the llama provider; _ensure_llama_client() overrides it with the credentials' llama.api_key (an empty value falls back to "ollama").
_llama_api_key: str = "ollama"
|
||||
|
||||
_send_lock: threading.Lock = threading.Lock()
|
||||
|
||||
_BIAS_ENGINE = ToolBiasEngine()
|
||||
@@ -522,6 +532,14 @@ def reset_session() -> None:
|
||||
_qwen_client = None
|
||||
with _qwen_history_lock:
|
||||
_qwen_history = []
|
||||
_grok_client = None
|
||||
with _grok_history_lock:
|
||||
_grok_history = []
|
||||
_llama_client = None
|
||||
with _llama_history_lock:
|
||||
_llama_history = []
|
||||
_llama_base_url = "http://localhost:11434/v1"
|
||||
_llama_api_key = "ollama"
|
||||
_CACHED_ANTHROPIC_TOOLS = None
|
||||
_CACHED_DEEPSEEK_TOOLS = None
|
||||
file_cache.reset_client()
|
||||
@@ -537,6 +555,8 @@ def list_models(provider: str) -> list[str]:
|
||||
elif provider == "gemini_cli": return _list_gemini_cli_models()
|
||||
elif provider == "minimax": return _list_minimax_models(creds["minimax"]["api_key"])
|
||||
elif provider == "qwen": return _list_qwen_models()
|
||||
elif provider == "grok": return _list_grok_models()
|
||||
elif provider == "llama": return _list_llama_models()
|
||||
return []
|
||||
|
||||
#endregion: Comms Log
|
||||
@@ -2150,6 +2170,58 @@ def _ensure_minimax_client() -> None:
|
||||
raise ValueError("MiniMax API key not found in credentials.toml")
|
||||
_minimax_client = OpenAI(api_key=api_key, base_url="https://api.minimax.chat/v1")
|
||||
|
||||
def _ensure_grok_client() -> Any:
    """Return the shared Grok (xAI) client, building it on first use.

    The client is created through the warmed ``openai`` module and pointed at
    ``https://api.x.ai/v1``; once built it is kept in ``_grok_client`` and
    reused by later calls.

    Returns:
        The cached client object.

    Raises:
        ValueError: If the credentials have no non-empty ``grok.api_key``.
    """
    global _grok_client
    # Fast path: a client already exists, hand it back untouched.
    if _grok_client is not None:
        return _grok_client
    openai_mod = _require_warmed("openai")
    grok_section = _load_credentials().get("grok", {})
    key = grok_section.get("api_key")
    if not key:
        raise ValueError("Grok API key not found in credentials.toml")
    _grok_client = openai_mod.OpenAI(api_key=key, base_url="https://api.x.ai/v1")
    return _grok_client
|
||||
|
||||
def _send_grok(md_content: str, user_message: str, base_dir: str,
               file_items: list[dict[str, Any]] | None = None,
               discussion_history: str = "",
               stream: bool = False,
               pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
               qa_callback: Optional[Callable[[str], str]] = None,
               stream_callback: Optional[Callable[[str], None]] = None,
               patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Send one user turn to Grok (xAI) via the OpenAI-compatible transport.

    The turn is appended to ``_grok_history``, the full history is sent after
    a system message made of the combined system prompt plus ``md_content``
    wrapped in ``<context>`` tags, and the reply text is appended to the
    history as the assistant turn.

    Args:
        md_content: Context text embedded in the system message.
        user_message: The user's text for this turn.
        base_dir: Accepted for signature parity; not used by this function.
        file_items: Optional attachments. Items with truthy ``is_image`` and
            ``base64_data`` add an ``[IMAGE: <path>]`` placeholder line in
            front of the user text; the image payload itself is not forwarded
            here.
        discussion_history: Prior discussion text, prepended to the first turn
            only (i.e. when ``_grok_history`` is still empty).
        stream: Passed through to the request.
        pre_tool_callback: Accepted for signature parity; not used here.
        qa_callback: Accepted for signature parity; not used here.
        stream_callback: Passed through to the request.
        patch_callback: Accepted for signature parity; not used here.

    Returns:
        The ``text`` of the response returned by ``send_openai_compatible``.

    Raises:
        ValueError: Propagated from ``_ensure_grok_client`` when no Grok API
            key is configured.
    """
    client = _ensure_grok_client()
    from src.openai_compatible import OpenAICompatibleRequest, send_openai_compatible
    from src.vendor_capabilities import get_capabilities
    with _grok_history_lock:
        user_content = user_message
        for fi in file_items or []:
            if fi.get("is_image") and fi.get("base64_data"):
                user_content = f"[IMAGE: {fi.get('path', 'attachment')}]\n{user_content}"
        if discussion_history and not _grok_history:
            # Build the first turn from user_content (not the raw user_message)
            # so image placeholders are kept when discussion history is present.
            user_content = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_content}"
        _grok_history.append({"role": "user", "content": user_content})
        messages = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}]
        messages.extend(_grok_history)
        request = OpenAICompatibleRequest(
            messages=messages,
            model=_model,
            temperature=_temperature,
            top_p=_top_p,
            max_tokens=_max_tokens,
            stream=stream,
            stream_callback=stream_callback,
        )
        caps = get_capabilities("grok", _model)
        response = send_openai_compatible(client, request, capabilities=caps)
        _grok_history.append({"role": "assistant", "content": response.text})
    return response.text
|
||||
|
||||
def _list_grok_models() -> list[str]:
    """Return the model names that ``src.vendor_capabilities`` lists for the "grok" vendor."""
    from src.vendor_capabilities import list_models_for_vendor
    grok_models = list_models_for_vendor("grok")
    return grok_models
|
||||
|
||||
def _send_minimax(md_content: str, user_message: str, base_dir: str,
|
||||
file_items: list[dict[str, Any]] | None = None,
|
||||
discussion_history: str = "",
|
||||
@@ -2487,6 +2559,73 @@ def _send_qwen(md_content: str, user_message: str, base_dir: str,
|
||||
|
||||
#endregion: Qwen Provider
|
||||
|
||||
def _ensure_llama_client() -> Any:
    """Return the shared llama-provider client, building it on first use.

    On the first call the ``llama`` section of the credentials may override
    the module-level endpoint and key: a truthy ``base_url`` replaces
    ``_llama_base_url``, and an ``api_key`` that is present (not ``None``)
    replaces ``_llama_api_key``, with an empty value falling back to
    ``"ollama"``. The client is then created through the warmed ``openai``
    module and cached in ``_llama_client``.

    Returns:
        The cached client object.
    """
    global _llama_client, _llama_base_url, _llama_api_key
    # Fast path: a client already exists, hand it back untouched.
    if _llama_client is not None:
        return _llama_client
    openai_mod = _require_warmed("openai")
    llama_section = _load_credentials().get("llama", {})
    url_override = llama_section.get("base_url")
    key_override = llama_section.get("api_key")
    if url_override:
        _llama_base_url = url_override
    if key_override is not None:
        # A configured-but-empty key still resolves to the "ollama" placeholder.
        _llama_api_key = key_override or "ollama"
    _llama_client = openai_mod.OpenAI(api_key=_llama_api_key, base_url=_llama_base_url)
    return _llama_client
|
||||
|
||||
def _send_llama(md_content: str, user_message: str, base_dir: str,
                file_items: list[dict[str, Any]] | None = None,
                discussion_history: str = "",
                stream: bool = False,
                pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
                qa_callback: Optional[Callable[[str], str]] = None,
                stream_callback: Optional[Callable[[str], None]] = None,
                patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
    """Send one user turn to the llama provider via the OpenAI-compatible transport.

    The turn is appended to ``_llama_history``, the full history is sent after
    a system message made of the combined system prompt plus ``md_content``
    wrapped in ``<context>`` tags, and the reply text is appended to the
    history as the assistant turn.

    Args:
        md_content: Context text embedded in the system message.
        user_message: The user's text for this turn.
        base_dir: Accepted for signature parity; not used by this function.
        file_items: Optional attachments. Items with truthy ``is_image`` and
            ``base64_data`` add an ``[IMAGE: <path>]`` placeholder line in
            front of the user text; the image payload itself is not forwarded
            here.
        discussion_history: Prior discussion text, prepended to the first turn
            only (i.e. when ``_llama_history`` is still empty).
        stream: Passed through to the request.
        pre_tool_callback: Accepted for signature parity; not used here.
        qa_callback: Accepted for signature parity; not used here.
        stream_callback: Passed through to the request.
        patch_callback: Accepted for signature parity; not used here.

    Returns:
        The ``text`` of the response returned by ``send_openai_compatible``.
    """
    client = _ensure_llama_client()
    from src.openai_compatible import OpenAICompatibleRequest, send_openai_compatible
    from src.vendor_capabilities import get_capabilities
    with _llama_history_lock:
        user_content = user_message
        for fi in file_items or []:
            if fi.get("is_image") and fi.get("base64_data"):
                user_content = f"[IMAGE: {fi.get('path', 'attachment')}]\n{user_content}"
        if discussion_history and not _llama_history:
            # Build the first turn from user_content (not the raw user_message)
            # so image placeholders are kept when discussion history is present.
            user_content = f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_content}"
        _llama_history.append({"role": "user", "content": user_content})
        messages = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}]
        messages.extend(_llama_history)
        request = OpenAICompatibleRequest(
            messages=messages,
            model=_model,
            temperature=_temperature,
            top_p=_top_p,
            max_tokens=_max_tokens,
            stream=stream,
            stream_callback=stream_callback,
        )
        caps = get_capabilities("llama", _model)
        response = send_openai_compatible(client, request, capabilities=caps)
        _llama_history.append({"role": "assistant", "content": response.text})
    return response.text
|
||||
|
||||
def _list_llama_models() -> list[str]:
    """Return the model names that ``src.vendor_capabilities`` lists for the "llama" vendor."""
    from src.vendor_capabilities import list_models_for_vendor
    llama_models = list_models_for_vendor("llama")
    return llama_models
|
||||
|
||||
def _get_llama_cost_tracking() -> bool:
    """Report whether cost tracking applies to the current llama endpoint.

    Returns:
        ``False`` when ``_llama_base_url`` contains ``"localhost"`` or
        ``"127.0.0.1"``. Otherwise the ``cost_tracking`` flag of the
        capabilities looked up for the "llama" vendor and the current model,
        or ``True`` when that lookup raises ``KeyError``.
    """
    local_markers = ("localhost", "127.0.0.1")
    if any(marker in _llama_base_url for marker in local_markers):
        return False
    from src.vendor_capabilities import get_capabilities
    try:
        return get_capabilities("llama", _model).cost_tracking
    except KeyError:
        # No capability entry for this model: fall back to tracking cost.
        return True
|
||||
|
||||
#endregion: Llama Provider
|
||||
|
||||
#region: Tier 4 Analysis
|
||||
|
||||
def run_tier4_analysis(stderr: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user