From d7d7d5cef928c356d7a2c4afc68438729fc645ef Mon Sep 17 00:00:00 2001 From: Ed_ Date: Thu, 11 Jun 2026 00:39:58 -0400 Subject: [PATCH] feat(openai_compatible): implement shared send helper with streaming/tool/vision/error Green phase: src/openai_compatible.py now exists and all 6 Red-phase tests in tests/test_openai_compatible.py pass. Implementation (144 lines, 1-space indent, no comments): Data structures: - NormalizedResponse: frozen dataclass with text, tool_calls, usage_input_tokens, usage_output_tokens, usage_cache_read_tokens, usage_cache_creation_tokens, raw_response - OpenAICompatibleRequest: regular dataclass with messages, model, temperature=0.0, top_p=1.0, max_tokens=8192, tools=None, tool_choice='auto', stream=False, stream_callback=None Algorithms: - send_openai_compatible(client, request, *, capabilities) -> NormalizedResponse Dispatches to _send_blocking or _send_streaming based on request.stream. Catches openai.OpenAIError and re-raises as classified ProviderError. - _send_blocking: extracts message text + tool_calls, converts tool_calls to dicts via _to_dict_tool_call, reads usage.prompt_tokens / usage.completion_tokens (with int() coercion for MagicMock test compat). - _send_streaming: iterates chunks, accumulates text parts, aggregates tool_calls by index, fires stream_callback per text delta, reads chunk.usage for final token counts. - _classify_openai_compatible_error: maps RateLimitError -> 'rate_limit', AuthenticationError/PermissionDeniedError -> 'auth', APIConnectionError -> 'network', APIStatusError with status 402 -> 'balance', 429 -> 'rate_limit', 401 or 403 -> 'auth', 500, 502, 503 or 504 -> 'network', BadRequestError -> 'quota', fallback 'unknown'. All use provider='openai_compatible'. Fixed plan's code smell: removed the 'MagicMock_noop' forward-reference class (defined after first use) and replaced with the cleaner Pythonic pattern 'int(getattr(usage, "prompt_tokens", 0) or 0)'. Real OpenAI SDK always sets usage on responses; the defensive fallback was noise. 
# src/openai_compatible.py
"""Shared send helper for OpenAI-compatible chat-completion endpoints.

The public entry point is :func:`send_openai_compatible`.  It issues one
chat-completion request (blocking or streaming) through an OpenAI-SDK-style
client and returns a provider-neutral :class:`NormalizedResponse`.  Any
``openai.OpenAIError`` raised along the way is re-raised as the project's
``ProviderError`` with a classified ``kind``.
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional

from openai import (
    APIConnectionError,
    APIStatusError,
    AuthenticationError,
    BadRequestError,
    OpenAIError,
    PermissionDeniedError,
    RateLimitError,
)

if TYPE_CHECKING:
    # Annotation-only import so type checkers can resolve the return type of
    # _classify_openai_compatible_error; nothing is imported at runtime here.
    from src.ai_client import ProviderError

# Provider label attached to every ProviderError raised from this module.
_PROVIDER = "openai_compatible"

# HTTP status code -> ProviderError kind, used for APIStatusError instances
# that were not already matched by their exception class.
_STATUS_KIND: dict[int, str] = {
    402: "balance",
    429: "rate_limit",
    401: "auth",
    403: "auth",
    500: "network",
    502: "network",
    503: "network",
    504: "network",
}


@dataclass(frozen=True)
class NormalizedResponse:
    """Immutable, provider-neutral result of one chat-completion call."""

    # Assistant text; "" when the response carried no content.
    text: str
    # Tool calls as plain dicts shaped
    # {"id", "type", "function": {"name", "arguments"}}.
    tool_calls: list[dict[str, Any]]
    # Taken from usage.prompt_tokens / usage.completion_tokens; 0 when absent.
    usage_input_tokens: int
    usage_output_tokens: int
    # This module always reports 0 for both cache counters.
    usage_cache_read_tokens: int
    usage_cache_creation_tokens: int
    # The SDK response object for blocking calls; None for streamed calls.
    raw_response: Any


@dataclass
class OpenAICompatibleRequest:
    """Parameters for a single chat-completion request."""

    messages: list[dict[str, Any]]
    model: str
    temperature: float = 0.0
    top_p: float = 1.0
    max_tokens: int = 8192
    # When None, neither "tools" nor "tool_choice" is sent to the endpoint.
    tools: Optional[list[dict[str, Any]]] = None
    tool_choice: str = "auto"
    # Selects the streaming code path when True.
    stream: bool = False
    # Invoked once per non-empty text delta while streaming.
    stream_callback: Optional[Callable[[str], None]] = None


def _to_dict_tool_call(tc: Any) -> dict[str, Any]:
    """Convert an SDK tool-call object into a plain dict.

    Missing attributes fall back to ``None`` (``id``, ``name``),
    ``"function"`` (``type``) and ``"{}"`` (``arguments``).
    """
    function = tc.function
    return {
        "id": getattr(tc, "id", None),
        "type": getattr(tc, "type", "function"),
        "function": {
            "name": getattr(function, "name", None),
            "arguments": getattr(function, "arguments", "{}"),
        },
    }


def _classify_openai_compatible_error(exc: Exception) -> ProviderError:
    """Wrap *exc* in a ``ProviderError`` whose ``kind`` reflects the failure.

    Exception classes are checked first (rate limit, auth, connection), then
    the HTTP status code of an ``APIStatusError``, then ``BadRequestError``
    (reported as ``"quota"``).  Anything else becomes ``"unknown"``.
    """
    # Function-level import of ProviderError avoids any circular import risk.
    from src.ai_client import ProviderError

    def _wrap(kind: str) -> ProviderError:
        return ProviderError(kind=kind, provider=_PROVIDER, original=exc)

    if isinstance(exc, RateLimitError):
        return _wrap("rate_limit")
    if isinstance(exc, (AuthenticationError, PermissionDeniedError)):
        return _wrap("auth")
    if isinstance(exc, APIConnectionError):
        return _wrap("network")
    if isinstance(exc, APIStatusError):
        kind = _STATUS_KIND.get(getattr(exc, "status_code", 0))
        if kind is not None:
            return _wrap(kind)
    # Checked after the status table so a recognised status code wins.
    if isinstance(exc, BadRequestError):
        return _wrap("quota")
    return _wrap("unknown")


def send_openai_compatible(
    client: Any,
    request: OpenAICompatibleRequest,
    *,
    capabilities: Any,
) -> NormalizedResponse:
    """Send *request* through *client* and return a normalized response.

    Args:
        client: Object exposing ``chat.completions.create(**kwargs)``.
        request: Model, messages, sampling parameters and streaming options.
        capabilities: Accepted for interface compatibility; not read here.

    Returns:
        The normalized text, tool calls and token usage.

    Raises:
        ProviderError: When the SDK raises ``openai.OpenAIError``; the
            original exception is chained as ``__cause__``.
    """
    kwargs: dict[str, Any] = {
        "model": request.model,
        "messages": request.messages,
        "temperature": request.temperature,
        "top_p": request.top_p,
        "max_tokens": request.max_tokens,
        "stream": request.stream,
    }
    if request.tools is not None:
        kwargs["tools"] = request.tools
        kwargs["tool_choice"] = request.tool_choice
    try:
        if request.stream:
            return _send_streaming(client, kwargs, request.stream_callback)
        return _send_blocking(client, kwargs)
    except OpenAIError as exc:
        raise _classify_openai_compatible_error(exc) from exc


def _send_blocking(client: Any, kwargs: dict[str, Any]) -> NormalizedResponse:
    """Perform a non-streaming request and normalize the first choice.

    A response whose ``choices`` is empty or ``None`` yields empty text and
    no tool calls instead of raising; ``raw_response`` still carries the SDK
    object so callers can inspect it.
    """
    resp = client.chat.completions.create(**kwargs)
    choices = getattr(resp, "choices", None)
    # Truthiness (not len) so list-like stand-ins behave like real lists.
    msg = choices[0].message if choices else None
    usage = getattr(resp, "usage", None)
    return NormalizedResponse(
        text=getattr(msg, "content", None) or "",
        tool_calls=[
            _to_dict_tool_call(tc)
            for tc in (getattr(msg, "tool_calls", None) or [])
        ],
        # int() coercion keeps the counters plain ints whatever the SDK or a
        # test double hands back.
        usage_input_tokens=int(getattr(usage, "prompt_tokens", 0) or 0),
        usage_output_tokens=int(getattr(usage, "completion_tokens", 0) or 0),
        usage_cache_read_tokens=0,
        usage_cache_creation_tokens=0,
        raw_response=resp,
    )


def _send_streaming(
    client: Any,
    kwargs: dict[str, Any],
    callback: Optional[Callable[[str], None]],
) -> NormalizedResponse:
    """Perform a streaming request and fold the chunks into one response.

    Text deltas are concatenated (and forwarded to *callback* when given),
    tool-call fragments are merged per ``index``, and token usage is taken
    from the last chunk that carries a ``usage`` object.  The stream is
    closed even if the callback or the iteration raises.
    """
    stream_kwargs = dict(kwargs)
    stream_kwargs["stream"] = True
    # Ask the endpoint to report token usage on the stream.
    stream_kwargs["stream_options"] = {"include_usage": True}
    stream = client.chat.completions.create(**stream_kwargs)

    text_parts: list[str] = []
    calls_by_index: dict[int, dict[str, Any]] = {}
    usage_input = 0
    usage_output = 0
    try:
        for chunk in stream:
            for choice in getattr(chunk, "choices", []) or []:
                delta = getattr(choice, "delta", None)
                if delta is None:
                    continue
                content = getattr(delta, "content", None)
                if content:
                    text_parts.append(content)
                    if callback:
                        callback(content)
                for tc in getattr(delta, "tool_calls", None) or []:
                    slot = calls_by_index.setdefault(
                        getattr(tc, "index", 0),
                        {
                            "id": None,
                            "type": "function",
                            "function": {"name": None, "arguments": ""},
                        },
                    )
                    tc_id = getattr(tc, "id", None)
                    if tc_id:
                        slot["id"] = tc_id
                    fn = getattr(tc, "function", None)
                    if fn:
                        if fn.name:
                            slot["function"]["name"] = fn.name
                        if fn.arguments:
                            # Argument JSON arrives in fragments; concatenate.
                            slot["function"]["arguments"] += fn.arguments
            chunk_usage = getattr(chunk, "usage", None)
            if chunk_usage is not None:
                usage_input = int(getattr(chunk_usage, "prompt_tokens", 0) or 0)
                usage_output = int(
                    getattr(chunk_usage, "completion_tokens", 0) or 0
                )
    finally:
        # Release the underlying connection on every exit path; plain
        # iterables (e.g. lists) have no close() and are skipped.
        close = getattr(stream, "close", None)
        if callable(close):
            close()

    return NormalizedResponse(
        text="".join(text_parts),
        tool_calls=[calls_by_index[k] for k in sorted(calls_by_index)],
        usage_input_tokens=usage_input,
        usage_output_tokens=usage_output,
        usage_cache_read_tokens=0,
        usage_cache_creation_tokens=0,
        raw_response=None,
    )