from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Optional

from openai import (
    APIConnectionError,
    APIStatusError,
    AuthenticationError,
    BadRequestError,
    OpenAIError,
    PermissionDeniedError,
    RateLimitError,
)


@dataclass(frozen=True)
class NormalizedResponse:
    """Provider-neutral result of a single chat-completion call."""

    # Concatenated assistant text ("" when the model produced none).
    text: str
    # Tool calls as plain dicts: {"id", "type", "function": {"name", "arguments"}}.
    tool_calls: list[dict[str, Any]]
    # Token counts taken from the provider's ``usage`` block (0 when absent).
    usage_input_tokens: int
    usage_output_tokens: int
    # Cache token counts; this module always sets them to 0.
    usage_cache_read_tokens: int
    usage_cache_creation_tokens: int
    # SDK response object for blocking calls; None for streamed calls.
    raw_response: Any


@dataclass
class OpenAICompatibleRequest:
    """Parameters for one chat-completion call against an OpenAI-compatible API."""

    messages: list[dict[str, Any]]
    model: str
    temperature: float = 0.0
    top_p: float = 1.0
    max_tokens: int = 8192
    # Tool definitions; ``tool_choice`` is only sent when this is not None.
    tools: Optional[list[dict[str, Any]]] = None
    tool_choice: str = "auto"
    stream: bool = False
    # Called with each text fragment as it arrives (only used when streaming).
    stream_callback: Optional[Callable[[str], None]] = None


# Provider label attached to every ProviderError raised from this module.
_PROVIDER = "openai_compatible"

# Substituted when a tool call carries no argument payload, so the blocking and
# streaming paths both hand callers the same JSON-decodable default.
_EMPTY_TOOL_ARGUMENTS = "{}"

# HTTP status codes of APIStatusError instances mapped to ProviderError kinds.
_STATUS_KIND: dict[int, str] = {
    402: "balance",
    429: "rate_limit",
    401: "auth",
    403: "auth",
    500: "network",
    502: "network",
    503: "network",
    504: "network",
}


def _token_count(usage: Any, field: str) -> int:
    """Read an integer token count from a usage object, treating missing/None as 0."""
    return int(getattr(usage, field, 0) or 0)


def _to_dict_tool_call(tc: Any) -> dict[str, Any]:
    """Convert an SDK tool-call object into a plain dict.

    Missing attributes fall back to None (``id``, ``name``) and "function"
    (``type``). A missing, None or empty ``arguments`` value becomes "{}",
    matching what the streaming path returns for argument-less calls.
    """
    fn = getattr(tc, "function", None)
    return {
        "id": getattr(tc, "id", None),
        "type": getattr(tc, "type", "function"),
        "function": {
            "name": getattr(fn, "name", None),
            "arguments": getattr(fn, "arguments", None) or _EMPTY_TOOL_ARGUMENTS,
        },
    }


def _classify_openai_compatible_error(exc: Exception) -> "ProviderError":
    """Map an OpenAI SDK exception onto a ``ProviderError`` with a coarse ``kind``.

    Typed SDK exceptions are checked first. Other ``APIStatusError`` instances
    are classified by HTTP status code through ``_STATUS_KIND``. A
    ``BadRequestError`` whose status is not in that table maps to "quota", and
    anything else maps to "unknown". The original exception is attached as
    ``original``.
    """
    # Imported lazily, presumably to avoid an import cycle with src.ai_client.
    from src.ai_client import ProviderError

    def make(kind: str) -> "ProviderError":
        return ProviderError(kind=kind, provider=_PROVIDER, original=exc)

    if isinstance(exc, RateLimitError):
        return make("rate_limit")
    if isinstance(exc, (AuthenticationError, PermissionDeniedError)):
        return make("auth")
    if isinstance(exc, APIConnectionError):
        return make("network")
    if isinstance(exc, APIStatusError):
        kind = _STATUS_KIND.get(getattr(exc, "status_code", 0))
        if kind is not None:
            return make(kind)
    if isinstance(exc, BadRequestError):
        # NOTE(review): any other 400 is reported as "quota"; confirm this
        # matches how the targeted providers signal quota exhaustion.
        return make("quota")
    return make("unknown")


def send_openai_compatible(
    client: Any,
    request: OpenAICompatibleRequest,
    *,
    capabilities: Any,
) -> NormalizedResponse:
    """Send ``request`` via ``client.chat.completions.create`` and normalize the reply.

    Args:
        client: Object exposing ``chat.completions.create`` (an OpenAI SDK
            client or a compatible one).
        request: Model, messages, sampling and tool parameters.
        capabilities: Accepted for interface compatibility; not used here.

    Returns:
        A ``NormalizedResponse`` built from either the blocking or the
        streamed completion, depending on ``request.stream``.

    Raises:
        ProviderError: For any ``OpenAIError`` raised while creating or
            consuming the completion.
    """
    kwargs: dict[str, Any] = {
        "model": request.model,
        "messages": request.messages,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "max_tokens": request.max_tokens,
        "stream": request.stream,
    }
    if request.tools is not None:
        kwargs["tools"] = request.tools
        kwargs["tool_choice"] = request.tool_choice
    try:
        if request.stream:
            return _send_streaming(client, kwargs, request.stream_callback)
        return _send_blocking(client, kwargs)
    except OpenAIError as exc:
        raise _classify_openai_compatible_error(exc) from exc


def _send_blocking(client: Any, kwargs: dict[str, Any]) -> NormalizedResponse:
    """Issue a non-streaming completion and normalize the first choice."""
    resp = client.chat.completions.create(**kwargs)
    # Tolerate an empty ``choices`` list (as the streaming path does) rather
    # than failing with an IndexError that escapes error classification.
    choices = getattr(resp, "choices", None) or []
    msg = choices[0].message if choices else None
    usage = getattr(resp, "usage", None)
    return NormalizedResponse(
        text=getattr(msg, "content", None) or "",
        tool_calls=[_to_dict_tool_call(tc) for tc in getattr(msg, "tool_calls", None) or []],
        usage_input_tokens=_token_count(usage, "prompt_tokens"),
        usage_output_tokens=_token_count(usage, "completion_tokens"),
        usage_cache_read_tokens=0,
        usage_cache_creation_tokens=0,
        raw_response=resp,
    )


def _merge_tool_call_delta(acc: dict[int, dict[str, Any]], tc: Any) -> None:
    """Fold one streamed tool-call fragment into ``acc``, keyed by its ``index``."""
    idx = getattr(tc, "index", 0)
    entry = acc.get(idx)
    if entry is None:
        entry = acc[idx] = {
            "id": None,
            "type": "function",
            "function": {"name": None, "arguments": ""},
        }
    if getattr(tc, "id", None):
        entry["id"] = tc.id
    fn = getattr(tc, "function", None)
    if fn:
        name = getattr(fn, "name", None)
        if name:
            entry["function"]["name"] = name
        fragment = getattr(fn, "arguments", None)
        if fragment:
            # Argument JSON arrives in pieces; concatenate in arrival order.
            entry["function"]["arguments"] += fragment


def _send_streaming(
    client: Any,
    kwargs: dict[str, Any],
    callback: Optional[Callable[[str], None]],
) -> NormalizedResponse:
    """Issue a streaming completion, accumulating text, tool calls and usage.

    ``callback`` (if given) receives every non-empty text fragment as it
    arrives. The stream is closed when iteration ends, including when it
    ends early because of an exception.
    """
    stream_kwargs = dict(kwargs)
    stream_kwargs["stream"] = True
    # Request a trailing usage chunk so token counts are available when streaming.
    stream_kwargs["stream_options"] = {"include_usage": True}
    stream = client.chat.completions.create(**stream_kwargs)

    text_parts: list[str] = []
    tool_calls_acc: dict[int, dict[str, Any]] = {}
    usage_input = 0
    usage_output = 0
    try:
        for chunk in stream:
            for choice in getattr(chunk, "choices", None) or []:
                delta = getattr(choice, "delta", None)
                if delta is None:
                    continue
                content = getattr(delta, "content", None)
                if content:
                    text_parts.append(content)
                    if callback:
                        callback(content)
                for tc in getattr(delta, "tool_calls", None) or []:
                    _merge_tool_call_delta(tool_calls_acc, tc)
            chunk_usage = getattr(chunk, "usage", None)
            if chunk_usage is not None:
                usage_input = _token_count(chunk_usage, "prompt_tokens")
                usage_output = _token_count(chunk_usage, "completion_tokens")
    finally:
        # Release the underlying HTTP response even if iteration stopped early
        # (e.g. the callback raised or the connection dropped mid-stream).
        close = getattr(stream, "close", None)
        if callable(close):
            close()

    tool_calls: list[dict[str, Any]] = []
    for idx in sorted(tool_calls_acc):
        entry = tool_calls_acc[idx]
        # Match the blocking path: a call that streamed no argument text gets "{}".
        if not entry["function"]["arguments"]:
            entry["function"]["arguments"] = _EMPTY_TOOL_ARGUMENTS
        tool_calls.append(entry)

    return NormalizedResponse(
        text="".join(text_parts),
        tool_calls=tool_calls,
        usage_input_tokens=usage_input,
        usage_output_tokens=usage_output,
        usage_cache_read_tokens=0,
        usage_cache_creation_tokens=0,
        raw_response=None,
    )