d7d7d5cef9
Green phase: src/openai_compatible.py now exists, and all 6 Red-phase tests in tests/test_openai_compatible.py pass. Implementation (144 lines, 1-space indent, no comments): Data structures: - NormalizedResponse: frozen dataclass with text, tool_calls, usage_input_tokens, usage_output_tokens, usage_cache_read_tokens, usage_cache_creation_tokens, raw_response. - OpenAICompatibleRequest: regular dataclass with messages, model, temperature=0.0, top_p=1.0, max_tokens=8192, tools=None, tool_choice='auto', stream=False, stream_callback=None. Algorithms: - send_openai_compatible(client, request, *, capabilities) -> NormalizedResponse: dispatches to _send_blocking or _send_streaming based on request.stream; catches openai.OpenAIError and re-raises it as a classified ProviderError. - _send_blocking: extracts the message text and tool_calls, converts tool_calls to dicts via _to_dict_tool_call, and reads usage.prompt_tokens / usage.completion_tokens (with int() coercion for MagicMock test compatibility). - _send_streaming: iterates chunks, accumulates text parts, aggregates tool_calls by index, fires stream_callback for each text delta, and reads chunk.usage for the final token counts. - _classify_openai_compatible_error: maps RateLimitError -> 'rate_limit', AuthenticationError/PermissionDeniedError -> 'auth', APIConnectionError -> 'network', APIStatusError with status 402 -> 'balance', 429 -> 'rate_limit', 401/403 -> 'auth', 500/502/503/504 -> 'network', any remaining BadRequestError -> 'quota', and everything else -> 'unknown'. All use provider='openai_compatible'. Fixed a code smell from the plan: removed the 'MagicMock_noop' forward-reference class (it was used before its definition) and replaced it with the cleaner, Pythonic pattern 'int(getattr(usage, "prompt_tokens", 0) or 0)'. The real OpenAI SDK always sets usage on responses, so the defensive fallback was noise. The function-level import of ProviderError inside _classify_openai_compatible_error avoids any circular-import risk.
144 lines
5.2 KiB
Python
144 lines
5.2 KiB
Python
from __future__ import annotations
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Optional
|
|
|
|
from openai import OpenAIError, RateLimitError, AuthenticationError, PermissionDeniedError, APIConnectionError, APIStatusError, BadRequestError
|
|
|
|
@dataclass(frozen=True)
class NormalizedResponse:
    """Provider-neutral result of a single chat-completion call.

    Fields cannot be reassigned after construction (``frozen=True``), but
    ``tool_calls`` is still a mutable list whose contents can change in place.
    """

    # Assistant text; the senders in this module use "" when there is none.
    text: str
    # Tool calls as plain dicts: {"id", "type", "function": {"name", "arguments"}}.
    tool_calls: list[dict[str, Any]]
    # Prompt-side token count; the senders in this module fall back to 0.
    usage_input_tokens: int
    # Completion-side token count; the senders in this module fall back to 0.
    usage_output_tokens: int
    # Prompt-cache counters; always 0 from the OpenAI-compatible senders here.
    usage_cache_read_tokens: int
    usage_cache_creation_tokens: int
    # SDK response object for blocking calls; None for streamed calls.
    raw_response: Any
|
|
|
|
@dataclass
class OpenAICompatibleRequest:
    """Inputs for one chat-completion request to an OpenAI-compatible API."""

    # Chat history, forwarded unchanged as the ``messages`` argument.
    messages: list[dict[str, Any]]
    # Model identifier, forwarded as ``model``.
    model: str
    # Sampling controls, forwarded as ``temperature`` and ``top_p``.
    temperature: float = 0.0
    top_p: float = 1.0
    # Forwarded as ``max_tokens``.
    max_tokens: int = 8192
    # Tool schemas; when None, neither ``tools`` nor ``tool_choice`` is sent.
    tools: Optional[list[dict[str, Any]]] = None
    # Only ever sent together with ``tools``.
    tool_choice: str = "auto"
    # True selects the streaming code path.
    stream: bool = False
    # Receives each streamed text delta; not used by the blocking path.
    stream_callback: Optional[Callable[[str], None]] = None
|
|
|
|
def _to_dict_tool_call(tc: Any) -> dict[str, Any]:
|
|
return {
|
|
"id": getattr(tc, "id", None),
|
|
"type": getattr(tc, "type", "function"),
|
|
"function": {
|
|
"name": getattr(tc.function, "name", None),
|
|
"arguments": getattr(tc.function, "arguments", "{}"),
|
|
},
|
|
}
|
|
|
|
def _classify_openai_compatible_error(exc: Exception) -> "ProviderError":
    """Translate an ``openai`` SDK exception into a ``ProviderError``.

    Exception classes are checked first (rate limit, auth, connection).  For
    any other ``APIStatusError`` the ``status_code`` is matched against a
    small table.  A ``BadRequestError`` that matched nothing becomes
    "quota", and everything else becomes "unknown".  The result always has
    ``provider="openai_compatible"`` and carries the original exception.
    """
    # Imported here rather than at module level (the commit notes this is to
    # sidestep a circular-import risk with src.ai_client).
    from src.ai_client import ProviderError

    status_rules = (
        ((402,), "balance"),
        ((429,), "rate_limit"),
        ((401, 403), "auth"),
        ((500, 502, 503, 504), "network"),
    )

    kind = None
    if isinstance(exc, RateLimitError):
        kind = "rate_limit"
    elif isinstance(exc, (AuthenticationError, PermissionDeniedError)):
        kind = "auth"
    elif isinstance(exc, APIConnectionError):
        kind = "network"
    elif isinstance(exc, APIStatusError):
        status = getattr(exc, "status_code", 0)
        for codes, status_kind in status_rules:
            if status in codes:
                kind = status_kind
                break

    if kind is None:
        kind = "quota" if isinstance(exc, BadRequestError) else "unknown"

    return ProviderError(kind=kind, provider="openai_compatible", original=exc)
|
|
|
|
def send_openai_compatible(
    client: Any,
    request: OpenAICompatibleRequest,
    *,
    capabilities: Any,
) -> NormalizedResponse:
    """Send ``request`` through an OpenAI-compatible ``client`` and normalize the reply.

    Args:
        client: Object exposing ``chat.completions.create``.
        request: Model, messages, sampling and tool settings for the call.
        capabilities: Accepted for interface compatibility; not read here.

    Returns:
        The ``NormalizedResponse`` produced by the streaming path when
        ``request.stream`` is true, otherwise by the blocking path.

    Raises:
        ProviderError: Raised for any ``openai.OpenAIError`` thrown while
            creating or consuming the completion, and chained from it.
    """
    kwargs: dict[str, Any] = {
        "model": request.model,
        "messages": request.messages,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "max_tokens": request.max_tokens,
        "stream": request.stream,
    }
    # Send tools only when there is at least one. OpenAI-style endpoints reject
    # an empty ``tools`` array, and ``tool_choice`` is invalid without tools.
    # Sending them would surface as a BadRequestError and be misclassified.
    if request.tools:
        kwargs["tools"] = request.tools
        kwargs["tool_choice"] = request.tool_choice

    # The try also covers iteration of the stream, so mid-stream SDK errors
    # are classified as well.
    try:
        if request.stream:
            return _send_streaming(client, kwargs, request.stream_callback)
        return _send_blocking(client, kwargs)
    except OpenAIError as exc:
        raise _classify_openai_compatible_error(exc) from exc
|
|
|
|
def _send_blocking(client: Any, kwargs: dict[str, Any]) -> NormalizedResponse:
    """Issue a non-streaming completion and normalize its first choice.

    Only ``choices[0]`` is read.  Absent or falsy token counts in ``usage``
    become 0, and the cache counters are always 0.  The SDK response object
    is kept in ``raw_response``.
    """
    response = client.chat.completions.create(**kwargs)
    message = response.choices[0].message
    converted_calls = [_to_dict_tool_call(call) for call in (message.tool_calls or [])]

    usage = getattr(response, "usage", None)
    # ``or 0`` covers a missing usage object as well as None counts.
    prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
    completion_tokens = getattr(usage, "completion_tokens", 0) or 0

    return NormalizedResponse(
        text=message.content or "",
        tool_calls=converted_calls,
        usage_input_tokens=int(prompt_tokens),
        usage_output_tokens=int(completion_tokens),
        usage_cache_read_tokens=0,
        usage_cache_creation_tokens=0,
        raw_response=response,
    )
|
|
|
|
def _send_streaming(client: Any, kwargs: dict[str, Any], callback: Optional[Callable[[str], None]]) -> NormalizedResponse:
    """Issue a streaming completion and fold its chunks into one response.

    Non-empty text deltas are concatenated and each one is passed to
    ``callback`` when a callback is given.  Tool-call fragments are merged
    by their ``index``:
      * a later non-empty ``id`` or ``name`` overwrites the earlier one;
      * ``arguments`` strings are appended.

    ``stream_options.include_usage`` is requested so the provider can report
    token usage, and the last non-None ``usage`` seen wins.  ``raw_response``
    is None because a stream has no single response object.
    """
    request_params = {**kwargs, "stream": True, "stream_options": {"include_usage": True}}
    stream = client.chat.completions.create(**request_params)

    pieces: list[str] = []
    calls_by_index: dict[int, dict[str, Any]] = {}
    prompt_tokens = 0
    completion_tokens = 0

    for chunk in stream:
        for choice in getattr(chunk, "choices", []) or []:
            delta = getattr(choice, "delta", None)
            if delta is None:
                continue

            if delta.content:
                pieces.append(delta.content)
                if callback:
                    callback(delta.content)

            for fragment in getattr(delta, "tool_calls", None) or []:
                entry = calls_by_index.setdefault(
                    getattr(fragment, "index", 0),
                    {"id": None, "type": "function", "function": {"name": None, "arguments": ""}},
                )
                if getattr(fragment, "id", None):
                    entry["id"] = fragment.id
                if getattr(fragment, "function", None):
                    function_delta = fragment.function
                    if function_delta.name:
                        entry["function"]["name"] = function_delta.name
                    if function_delta.arguments:
                        entry["function"]["arguments"] += function_delta.arguments

        # With include_usage, the usage-bearing chunk typically has no choices,
        # so usage is read per chunk rather than per choice.
        usage = getattr(chunk, "usage", None)
        if usage is not None:
            prompt_tokens = int(getattr(usage, "prompt_tokens", 0) or 0)
            completion_tokens = int(getattr(usage, "completion_tokens", 0) or 0)

    ordered_calls = [calls_by_index[position] for position in sorted(calls_by_index)]
    return NormalizedResponse(
        text="".join(pieces),
        tool_calls=ordered_calls,
        usage_input_tokens=prompt_tokens,
        usage_output_tokens=completion_tokens,
        usage_cache_read_tokens=0,
        usage_cache_creation_tokens=0,
        raw_response=None,
    )