Private
Public Access
0
0
Files
manual_slop/src/openai_compatible.py
T

149 lines
5.5 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Optional
from openai import OpenAIError, RateLimitError, AuthenticationError, PermissionDeniedError, APIConnectionError, APIStatusError, BadRequestError
from src.result_types import ErrorInfo, ErrorKind, Result
@dataclass(frozen=True)
class NormalizedResponse:
    """Immutable, provider-neutral view of one chat-completion result.

    Both the blocking and the streaming sender in this module build one of
    these so that callers see the same shape regardless of transport.
    """

    # Assistant text; an empty string when the model produced no content.
    text: str
    # Tool calls as plain dicts shaped like
    # {"id": ..., "type": ..., "function": {"name": ..., "arguments": ...}}.
    tool_calls: list[dict[str, Any]]
    # Prompt-side token count reported by the provider (0 when not reported).
    usage_input_tokens: int
    # Completion-side token count reported by the provider (0 when not reported).
    usage_output_tokens: int
    # Cache counters; the senders in this module always set both to 0.
    usage_cache_read_tokens: int
    usage_cache_creation_tokens: int
    # The SDK response object for blocking calls; None for streamed calls.
    raw_response: Any
@dataclass
class OpenAICompatibleRequest:
    """Parameters for a single chat-completion call.

    Only ``messages`` and ``model`` are required; every other field has a
    default. Instances are mutable.
    """

    # Conversation history, one dict per message.
    messages: list[dict[str, Any]]
    # Model identifier forwarded to the endpoint unchanged.
    model: str

    # Sampling controls.
    temperature: float = 0.0
    top_p: float = 1.0
    max_tokens: int = 8192

    # Tool definitions and the selection policy sent alongside them.
    tools: Optional[list[dict[str, Any]]] = None
    tool_choice: str = "auto"

    # When True the streaming sender is used; the callback, if given,
    # receives each text fragment as it arrives.
    stream: bool = False
    stream_callback: Optional[Callable[[str], None]] = None

    # Extra request fields passed through as ``extra_body`` when non-empty.
    extra_body: Optional[dict[str, Any]] = None
def _to_dict_tool_call(tc: Any) -> dict[str, Any]:
return {
"id": getattr(tc, "id", None),
"type": getattr(tc, "type", "function"),
"function": {
"name": getattr(tc.function, "name", None),
"arguments": getattr(tc.function, "arguments", "{}"),
},
}
def _classify_openai_compatible_error(exc: Exception, source: str = "openai_compatible") -> ErrorInfo:
    """Translate an exception raised by the OpenAI SDK into an ``ErrorInfo``.

    The exception class is checked first; for a generic ``APIStatusError``
    the HTTP status code decides. Anything unmatched becomes
    ``ErrorKind.UNKNOWN``.

    Args:
        exc: The exception to classify.
        source: Label stored on the resulting ``ErrorInfo``.

    Returns:
        An ``ErrorInfo`` whose message is ``str(exc)`` and whose ``original``
        is ``exc`` itself.
    """
    if isinstance(exc, RateLimitError):
        kind = ErrorKind.RATE_LIMIT
    elif isinstance(exc, (AuthenticationError, PermissionDeniedError)):
        kind = ErrorKind.AUTH
    elif isinstance(exc, APIConnectionError):
        kind = ErrorKind.NETWORK
    else:
        kind = None
        if isinstance(exc, APIStatusError):
            # Fall back to the HTTP status for status errors that did not
            # match one of the dedicated exception classes above.
            status = getattr(exc, "status_code", 0)
            if status == 402:
                kind = ErrorKind.BALANCE
            elif status == 429:
                kind = ErrorKind.RATE_LIMIT
            elif status in (401, 403):
                kind = ErrorKind.AUTH
            elif status in (500, 502, 503, 504):
                kind = ErrorKind.NETWORK
        if kind is None:
            # NOTE(review): bad requests are reported as QUOTA; confirm that
            # this mapping is intentional.
            kind = ErrorKind.QUOTA if isinstance(exc, BadRequestError) else ErrorKind.UNKNOWN
    return ErrorInfo(kind=kind, message=str(exc), source=source, original=exc)
def send_openai_compatible(
    client: Any,
    request: OpenAICompatibleRequest,
    *,
    capabilities: Any,
) -> Result[str]:
    """Send one chat-completion request and return the assistant text.

    Args:
        client: SDK client exposing ``chat.completions.create``.
        request: Request parameters; ``request.stream`` selects the
            streaming or the blocking sender.
        capabilities: Accepted for interface compatibility; this function
            does not read it.

    Returns:
        ``Result(data=<text>)`` on success. If the SDK raises an
        ``OpenAIError`` the result carries ``data=""`` and a single
        classified ``ErrorInfo``. Only the text is returned; tool calls and
        usage gathered by the senders are not propagated.
    """
    kwargs: dict[str, Any] = {
        "model": request.model,
        "messages": request.messages,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "max_tokens": request.max_tokens,
        "stream": request.stream,
    }
    # Forward tools only when at least one is defined: the Chat Completions
    # API rejects an empty `tools` array, and `tool_choice` is only valid
    # together with `tools`. An empty list is therefore treated like None.
    if request.tools:
        kwargs["tools"] = request.tools
        kwargs["tool_choice"] = request.tool_choice
    if request.extra_body:
        kwargs["extra_body"] = request.extra_body

    try:
        if request.stream:
            response = _send_streaming(client, kwargs, request.stream_callback)
        else:
            response = _send_blocking(client, kwargs)
    except OpenAIError as exc:
        return Result(data="", errors=[_classify_openai_compatible_error(exc, source="openai_compatible")])
    return Result(data=response.text)
def _send_blocking(client: Any, kwargs: dict[str, Any]) -> NormalizedResponse:
    """Perform a non-streaming completion call and normalise the reply.

    Args:
        client: SDK client exposing ``chat.completions.create``.
        kwargs: Keyword arguments forwarded to ``create`` unchanged.

    Returns:
        A ``NormalizedResponse`` built from the first choice, with the SDK
        response object kept in ``raw_response``.
    """
    completion = client.chat.completions.create(**kwargs)
    message = completion.choices[0].message

    # `tool_calls` may be None when the model called no tools.
    converted_calls: list[dict[str, Any]] = [
        _to_dict_tool_call(call) for call in (message.tool_calls or [])
    ]

    # Usage may be absent, or individual counters may be None; both become 0.
    usage = getattr(completion, "usage", None)
    return NormalizedResponse(
        text=message.content or "",
        tool_calls=converted_calls,
        usage_input_tokens=int(getattr(usage, "prompt_tokens", 0) or 0),
        usage_output_tokens=int(getattr(usage, "completion_tokens", 0) or 0),
        usage_cache_read_tokens=0,
        usage_cache_creation_tokens=0,
        raw_response=completion,
    )
def _send_streaming(client: Any, kwargs: dict[str, Any], callback: Optional[Callable[[str], None]]) -> NormalizedResponse:
    """Perform a streaming completion call and assemble the full reply.

    Text fragments are concatenated in arrival order and, when ``callback``
    is given, forwarded to it one by one. Tool-call fragments are merged per
    ``index``; their ``arguments`` strings are concatenated.

    Args:
        client: SDK client exposing ``chat.completions.create``.
        kwargs: Base keyword arguments; a copy is made, so the caller's dict
            is not modified.
        callback: Optional receiver for each non-empty text fragment.

    Returns:
        A ``NormalizedResponse`` with ``raw_response=None``. Usage counters
        come from the last chunk that carried a ``usage`` object, or stay 0.
    """
    stream_kwargs = dict(kwargs)
    stream_kwargs["stream"] = True
    # Ask the endpoint to report token usage within the stream.
    stream_kwargs["stream_options"] = {"include_usage": True}
    stream = client.chat.completions.create(**stream_kwargs)

    text_parts: list[str] = []
    pending_calls: dict[int, dict[str, Any]] = {}
    usage_input = 0
    usage_output = 0
    try:
        for chunk in stream:
            for choice in getattr(chunk, "choices", []) or []:
                delta = getattr(choice, "delta", None)
                if delta is None:
                    continue
                if delta.content:
                    text_parts.append(delta.content)
                    if callback:
                        callback(delta.content)
                for tc in getattr(delta, "tool_calls", None) or []:
                    # A missing index already mapped to slot 0; do the same
                    # for an explicit None so that the keys stay sortable.
                    idx = getattr(tc, "index", None)
                    if idx is None:
                        idx = 0
                    entry = pending_calls.setdefault(
                        idx,
                        {"id": None, "type": "function", "function": {"name": None, "arguments": ""}},
                    )
                    if getattr(tc, "id", None):
                        entry["id"] = tc.id
                    fn = getattr(tc, "function", None)
                    if fn:
                        if fn.name:
                            entry["function"]["name"] = fn.name
                        if fn.arguments:
                            # Arguments arrive as string fragments.
                            entry["function"]["arguments"] += fn.arguments
            chunk_usage = getattr(chunk, "usage", None)
            if chunk_usage is not None:
                usage_input = int(getattr(chunk_usage, "prompt_tokens", 0) or 0)
                usage_output = int(getattr(chunk_usage, "completion_tokens", 0) or 0)
    finally:
        # Release the underlying connection even when iteration or the
        # callback raises part-way through. Plain iterables without a
        # close() method are left alone.
        close = getattr(stream, "close", None)
        if callable(close):
            close()

    return NormalizedResponse(
        text="".join(text_parts),
        tool_calls=[pending_calls[k] for k in sorted(pending_calls)],
        usage_input_tokens=usage_input,
        usage_output_tokens=usage_output,
        usage_cache_read_tokens=0,
        usage_cache_creation_tokens=0,
        raw_response=None,
    )