feat(ai): Add Qwen provider support to ai_client
This commit is contained in:
@@ -131,6 +131,11 @@ _minimax_client: Any = None
|
||||
_minimax_history: list[dict[str, Any]] = []
|
||||
_minimax_history_lock: threading.Lock = threading.Lock()
|
||||
|
||||
_qwen_client: Any = None
|
||||
_qwen_history: list[dict[str, Any]] = []
|
||||
_qwen_history_lock: threading.Lock = threading.Lock()
|
||||
_qwen_region: str = "china"
|
||||
|
||||
_send_lock: threading.Lock = threading.Lock()
|
||||
|
||||
_BIAS_ENGINE = ToolBiasEngine()
|
||||
@@ -486,6 +491,7 @@ def reset_session() -> None:
|
||||
global _anthropic_client, _anthropic_history
|
||||
global _deepseek_client, _deepseek_history
|
||||
global _minimax_client, _minimax_history
|
||||
global _qwen_client, _qwen_history
|
||||
global _CACHED_ANTHROPIC_TOOLS, _CACHED_DEEPSEEK_TOOLS
|
||||
global _gemini_cli_adapter
|
||||
if _gemini_client and _gemini_cache:
|
||||
@@ -513,6 +519,9 @@ def reset_session() -> None:
|
||||
_minimax_client = None
|
||||
with _minimax_history_lock:
|
||||
_minimax_history = []
|
||||
_qwen_client = None
|
||||
with _qwen_history_lock:
|
||||
_qwen_history = []
|
||||
_CACHED_ANTHROPIC_TOOLS = None
|
||||
_CACHED_DEEPSEEK_TOOLS = None
|
||||
file_cache.reset_client()
|
||||
@@ -527,6 +536,7 @@ def list_models(provider: str) -> list[str]:
|
||||
elif provider == "deepseek": return _list_deepseek_models(creds["deepseek"]["api_key"])
|
||||
elif provider == "gemini_cli": return _list_gemini_cli_models()
|
||||
elif provider == "minimax": return _list_minimax_models(creds["minimax"]["api_key"])
|
||||
elif provider == "qwen": return _list_qwen_models()
|
||||
return []
|
||||
|
||||
#endregion: Comms Log
|
||||
@@ -2369,6 +2379,110 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str,
|
||||
|
||||
#endregion: MiniMax Provider
|
||||
|
||||
#region: Qwen Provider
|
||||
|
||||
def _ensure_qwen_client() -> None:
|
||||
global _qwen_client, _qwen_region
|
||||
if _qwen_client is None:
|
||||
import dashscope
|
||||
creds = _load_credentials()
|
||||
api_key = creds.get("qwen", {}).get("api_key")
|
||||
if not api_key:
|
||||
raise ValueError("Qwen API key not found in credentials.toml")
|
||||
_qwen_region = creds.get("qwen", {}).get("region", "china")
|
||||
dashscope.api_key = api_key
|
||||
_qwen_client = dashscope.Generation
|
||||
|
||||
def _dashscope_call(
|
||||
model: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
*,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
) -> dict[str, Any]:
|
||||
import dashscope
|
||||
from src.qwen_adapter import build_dashscope_tools
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"result_format": "message",
|
||||
}
|
||||
if tools:
|
||||
kwargs["tools"] = build_dashscope_tools(tools)
|
||||
resp = dashscope.Generation.call(**kwargs)
|
||||
if getattr(resp, "status_code", 200) != 200:
|
||||
from src.qwen_adapter import classify_dashscope_error
|
||||
raise classify_dashscope_error(_dashscope_exception_from_response(resp))
|
||||
return {
|
||||
"text": resp.output.text if hasattr(resp, "output") and resp.output else "",
|
||||
"tool_calls": _extract_dashscope_tool_calls(resp),
|
||||
"usage": {
|
||||
"input_tokens": getattr(resp.usage, "input_tokens", 0) if hasattr(resp, "usage") and resp.usage else 0,
|
||||
"output_tokens": getattr(resp.usage, "output_tokens", 0) if hasattr(resp, "usage") and resp.usage else 0,
|
||||
},
|
||||
}
|
||||
|
||||
def _dashscope_exception_from_response(resp: Any) -> Exception:
|
||||
msg = getattr(resp, "message", "unknown dashscope error")
|
||||
return RuntimeError(msg)
|
||||
|
||||
def _extract_dashscope_tool_calls(resp: Any) -> list[dict[str, Any]]:
|
||||
out: list[dict[str, Any]] = []
|
||||
if not (hasattr(resp, "output") and resp.output and getattr(resp.output, "tool_calls", None)):
|
||||
return out
|
||||
for tc in resp.output.tool_calls:
|
||||
out.append({
|
||||
"id": getattr(tc, "id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": getattr(tc.function, "name", "") if hasattr(tc, "function") else "",
|
||||
"arguments": getattr(tc.function, "arguments", "{}") if hasattr(tc, "function") else "{}",
|
||||
},
|
||||
})
|
||||
return out
|
||||
|
||||
def _list_qwen_models() -> list[str]:
|
||||
from src.vendor_capabilities import list_models_for_vendor
|
||||
return list_models_for_vendor("qwen")
|
||||
|
||||
def _send_qwen(md_content: str, user_message: str, base_dir: str,
|
||||
file_items: list[dict[str, Any]] | None = None,
|
||||
discussion_history: str = "",
|
||||
stream: bool = False,
|
||||
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
|
||||
qa_callback: Optional[Callable[[str], str]] = None,
|
||||
stream_callback: Optional[Callable[[str], None]] = None,
|
||||
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
|
||||
_ensure_qwen_client()
|
||||
with _qwen_history_lock:
|
||||
user_content = user_message
|
||||
if file_items:
|
||||
for fi in file_items:
|
||||
if fi.get("is_image") and fi.get("base64_data"):
|
||||
user_content = f"[IMAGE: {fi.get('path', 'attachment')}]\n{user_content}"
|
||||
if discussion_history and not _qwen_history:
|
||||
_qwen_history.append({"role": "user", "content": f"[DISCUSSION HISTORY]\n\n{discussion_history}\n\n---\n\n{user_message}"})
|
||||
else:
|
||||
_qwen_history.append({"role": "user", "content": user_content})
|
||||
messages = [{"role": "system", "content": f"{_get_combined_system_prompt()}\n\n<context>\n{md_content}\n</context>"}]
|
||||
messages.extend(_qwen_history)
|
||||
resp = _dashscope_call(
|
||||
model=_model,
|
||||
messages=messages,
|
||||
tools=None,
|
||||
max_tokens=_max_tokens,
|
||||
temperature=_temperature,
|
||||
top_p=_top_p,
|
||||
)
|
||||
return resp.get("text", "")
|
||||
|
||||
#endregion: Qwen Provider
|
||||
|
||||
#region: Tier 4 Analysis
|
||||
|
||||
def run_tier4_analysis(stderr: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user