30c8b26381
Phase 2 deferred t2_6: update src/ai_client.py _send_grok + _send_minimax +
_send_llama + _send_gemini_cli (4 functions) to use the new
dataclass API after NormalizedResponse was refactored to
(text, tool_calls: tuple[ToolCall, ...], usage: UsageStats, raw_response).
These 4 callers were still passing the old keyword args
(usage_input_tokens, usage_output_tokens, ...), which broke at
runtime: ai_client.send() raised
TypeError: NormalizedResponse.__init__() got an unexpected keyword
argument 'usage_input_tokens'.
FIXES:
- src/ai_client.py L2054: gemini_cli 'adapter unavailable' branch
- src/ai_client.py L2088: gemini_cli normal response branch
- Added: from src.openai_schemas import UsageStats (module level)
- Added backward-compat in src/openai_compatible.py:
messages_dicts = [m.to_dict() if hasattr(m, 'to_dict') else m for m in request.messages]
(accepts both ChatMessage dataclass and dict for backward compat
with existing tests that pass raw dicts)
TEST FIXES:
- tests/test_ai_client_tool_loop.py: _make_normalized_response helper
uses UsageStats instead of usage_*_tokens kwargs
- tests/test_ai_client_tool_loop_builder.py: same
- tests/test_ai_client_tool_loop_send_func.py: same
- tests/test_openai_compatible.py: NormalizedResponse(text=..., usage=UsageStats(...))
+ tool_calls[0].function.name (attribute access) instead of ['function']['name']
- tests/test_auto_whitelist.py: use update_session_metadata() instead of
dict subscript assignment (Session dataclass doesn't support item assignment)
VERIFIED:
uv run pytest tests/test_ai_client_*.py tests/test_openai_*.py \
tests/test_auto_whitelist.py --timeout=30
56 passed in 4.49s (19 previously failing tests now pass)
uv run python scripts/audit_weak_types.py --strict
STRICT OK: 115 weak sites <= baseline 115
uv run python scripts/audit_dataclass_coverage.py --strict
STRICT OK: 200 weak sites <= baseline 207
This commit closes the t2_6 deferred task. The 41-site Phase 3 call-site
migration remains deferred (separate provider_state_migration track).
182 lines
6.0 KiB
Python
182 lines
6.0 KiB
Python
"""OpenAI-compatible API client for the Manual Slop ai_client layer.
|
|
|
|
Provides `send_openai_compatible(client, request, *, capabilities)` which
|
|
calls any OpenAI-compatible chat completion endpoint and returns a
|
|
`NormalizedResponse` (re-exported from src.openai_schemas).
|
|
|
|
CONVENTION: 1-space indentation. NO COMMENTS.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Callable, Optional
|
|
|
|
from openai import (
|
|
APIConnectionError,
|
|
APIStatusError,
|
|
AuthenticationError,
|
|
BadRequestError,
|
|
OpenAIError,
|
|
PermissionDeniedError,
|
|
RateLimitError,
|
|
)
|
|
|
|
from src.openai_schemas import (
|
|
ChatMessage,
|
|
NormalizedResponse,
|
|
OpenAICompatibleRequest,
|
|
ToolCall,
|
|
ToolCallFunction,
|
|
UsageStats,
|
|
)
|
|
from src.result_types import ErrorInfo, ErrorKind, Result
|
|
|
|
__all__ = [
 "ChatMessage",
 "NormalizedResponse",
 "OpenAICompatibleRequest",
 "ToolCall",
 "ToolCallFunction",
 "UsageStats",
 "send_openai_compatible",
]
|
|
|
|
|
|
def _to_typed_tool_call(tc: Any) -> ToolCall:
 """Convert an SDK tool-call object into a typed `ToolCall`.

 Reads `id`, `type`, `function.name` and `function.arguments` by attribute and
 tolerates any of them (including `function` itself) being missing or None:
 `id` and `name` fall back to "", `type` falls back to "function" (the same value
 the streaming path always uses), and `arguments` falls back to "{}" so callers
 never receive an empty arguments string.
 """
 fn = getattr(tc, "function", None)
 return ToolCall(
  id=getattr(tc, "id", None) or "",
  type=getattr(tc, "type", None) or "function",
  function=ToolCallFunction(
   name=getattr(fn, "name", None) or "",
   arguments=getattr(fn, "arguments", None) or "{}",
  ),
 )
|
|
|
|
|
|
def _to_dict_tool_call(tc: ToolCall) -> dict[str, Any]:
 """Serialize a typed `ToolCall` to plain-dict form by delegating to its `to_dict()`."""
 as_dict = tc.to_dict()
 return as_dict
|
|
|
|
|
|
def _classify_openai_compatible_error(exc: Exception, source: str = "openai_compatible") -> ErrorInfo:
 """Translate an OpenAI SDK exception into an `ErrorInfo` tagged with an `ErrorKind`.

 Typed SDK exceptions are checked first: rate limit -> RATE_LIMIT,
 authentication/permission -> AUTH, connection -> NETWORK. A remaining
 `APIStatusError` is classified by its `status_code`: 402 -> BALANCE,
 429 -> RATE_LIMIT, 401/403 -> AUTH, 500/502/503/504 -> NETWORK. A
 `BadRequestError` whose status matched none of those maps to QUOTA, and
 anything else maps to UNKNOWN. The result carries `str(exc)` as the message,
 the given `source`, and the original exception.
 """
 def build(kind: ErrorKind) -> ErrorInfo:
  return ErrorInfo(kind=kind, message=str(exc), source=source, original=exc)
 if isinstance(exc, RateLimitError):
  return build(ErrorKind.RATE_LIMIT)
 if isinstance(exc, (AuthenticationError, PermissionDeniedError)):
  return build(ErrorKind.AUTH)
 if isinstance(exc, APIConnectionError):
  return build(ErrorKind.NETWORK)
 if isinstance(exc, APIStatusError):
  code = getattr(exc, "status_code", 0)
  for codes, kind in (
   ((402,), ErrorKind.BALANCE),
   ((429,), ErrorKind.RATE_LIMIT),
   ((401, 403), ErrorKind.AUTH),
   ((500, 502, 503, 504), ErrorKind.NETWORK),
  ):
   if code in codes:
    return build(kind)
 return build(ErrorKind.QUOTA if isinstance(exc, BadRequestError) else ErrorKind.UNKNOWN)
|
|
|
|
|
|
def send_openai_compatible(
 client: Any,
 request: OpenAICompatibleRequest,
 *,
 capabilities: Any,
) -> Result[NormalizedResponse]:
 """Send `request` to an OpenAI-compatible chat completion endpoint.

 Messages may be `ChatMessage` objects (serialized with `to_dict()`) or plain
 dicts, which are passed through unchanged for backward compatibility.
 `tools` and `tool_choice` are attached only when `request.tools` is non-empty:
 the OpenAI API rejects an empty `tools` array and rejects `tool_choice` sent
 without tools. `extra_body` is attached only when truthy. Streaming requests go
 through `_send_streaming` (text deltas are forwarded to
 `request.stream_callback`); all others go through `_send_blocking`.

 `capabilities` is accepted for interface compatibility and is not read here.

 Returns a `Result` whose `data` is the normalized response. Any `OpenAIError`
 is caught: `data` is then an empty response (no text, no tool calls, zero
 usage, no raw response) and `errors` holds the classified `ErrorInfo`.
 """
 messages_dicts = [m.to_dict() if hasattr(m, "to_dict") else m for m in request.messages]
 kwargs: dict[str, Any] = {
  "model": request.model,
  "messages": messages_dicts,
  "temperature": request.temperature,
  "top_p": request.top_p,
  "max_tokens": request.max_tokens,
  "stream": request.stream,
 }
 if request.tools:
  kwargs["tools"] = request.tools
  kwargs["tool_choice"] = request.tool_choice
 if request.extra_body:
  kwargs["extra_body"] = request.extra_body
 try:
  if request.stream:
   response = _send_streaming(client, kwargs, request.stream_callback)
  else:
   response = _send_blocking(client, kwargs)
 except OpenAIError as exc:
  empty_resp = NormalizedResponse(
   text="",
   tool_calls=(),
   usage=UsageStats(input_tokens=0, output_tokens=0),
   raw_response=None,
  )
  return Result(data=empty_resp, errors=[_classify_openai_compatible_error(exc, source="openai_compatible")])
 return Result(data=response)
|
|
|
|
|
|
def _send_blocking(client: Any, kwargs: dict[str, Any]) -> NormalizedResponse:
 """Issue a non-streaming chat completion and normalize the reply.

 Text and tool calls come from the first choice's message. A reply whose
 `choices` is missing, None or empty yields empty text and no tool calls instead
 of raising IndexError/TypeError (which would escape the caller's `OpenAIError`
 handling). This matches how `_send_streaming` tolerates chunks without choices.
 Token counts come from `resp.usage.prompt_tokens` / `completion_tokens` and
 default to 0 when absent. The SDK response object is kept as `raw_response`.
 """
 resp = client.chat.completions.create(**kwargs)
 choices = getattr(resp, "choices", None) or []
 msg = choices[0].message if choices else None
 tool_calls: tuple[ToolCall, ...] = tuple(
  _to_typed_tool_call(tc) for tc in (getattr(msg, "tool_calls", None) or [])
 )
 usage = getattr(resp, "usage", None)
 return NormalizedResponse(
  text=getattr(msg, "content", None) or "",
  tool_calls=tool_calls,
  usage=UsageStats(
   input_tokens=int(getattr(usage, "prompt_tokens", 0) or 0),
   output_tokens=int(getattr(usage, "completion_tokens", 0) or 0),
  ),
  raw_response=resp,
 )
|
|
|
|
|
|
def _send_streaming(client: Any, kwargs: dict[str, Any], callback: Optional[Callable[[str], None]]) -> NormalizedResponse:
 """Issue a streaming chat completion and assemble the complete reply.

 Works on a copy of `kwargs` (the caller's dict is not mutated), forcing
 `stream=True` and `stream_options={"include_usage": True}`. Each text delta is
 appended to the result and, when `callback` is given, passed to it as it
 arrives.

 Tool-call fragments are merged by their `index`; a missing or None index is
 treated as 0 so the keys stay sortable. For each merged call, `id` and
 `function.name` keep the last non-empty value seen, and `function.arguments`
 fragments are concatenated (falling back to "{}" when none arrived). Merged
 calls are emitted in ascending index order.

 Usage comes from the last chunk that carries a `usage` object; it is 0 when
 no chunk does. If the stream object exposes `close()`, that is called in a
 `finally`, so the connection is released even when iteration or `callback`
 raises. `raw_response` is always None for streamed replies.
 """
 kwargs_stream = {**kwargs, "stream": True, "stream_options": {"include_usage": True}}
 chunks_iter = client.chat.completions.create(**kwargs_stream)
 text_parts: list[str] = []
 tool_calls_acc: dict[int, dict[str, Any]] = {}
 usage_input = 0
 usage_output = 0
 try:
  for chunk in chunks_iter:
   for choice in getattr(chunk, "choices", []) or []:
    delta = getattr(choice, "delta", None)
    if delta is None:
     continue
    content = getattr(delta, "content", None)
    if content:
     text_parts.append(content)
     if callback:
      callback(content)
    for tc in getattr(delta, "tool_calls", None) or []:
     idx = getattr(tc, "index", None) or 0
     acc = tool_calls_acc.setdefault(
      idx, {"id": None, "type": "function", "function": {"name": None, "arguments": ""}}
     )
     if getattr(tc, "id", None):
      acc["id"] = tc.id
     fn = getattr(tc, "function", None)
     if fn:
      fn_name = getattr(fn, "name", None)
      if fn_name:
       acc["function"]["name"] = fn_name
      fn_args = getattr(fn, "arguments", None)
      if fn_args:
       acc["function"]["arguments"] += fn_args
   chunk_usage = getattr(chunk, "usage", None)
   if chunk_usage is not None:
    usage_input = int(getattr(chunk_usage, "prompt_tokens", 0) or 0)
    usage_output = int(getattr(chunk_usage, "completion_tokens", 0) or 0)
 finally:
  close = getattr(chunks_iter, "close", None)
  if callable(close):
   close()
 tool_calls_typed: tuple[ToolCall, ...] = tuple(
  ToolCall(
   id=acc["id"] or "",
   type=acc["type"],
   function=ToolCallFunction(
    name=acc["function"]["name"] or "",
    arguments=acc["function"]["arguments"] or "{}",
   ),
  )
  for acc in (tool_calls_acc[k] for k in sorted(tool_calls_acc))
 )
 return NormalizedResponse(
  text="".join(text_parts),
  tool_calls=tool_calls_typed,
  usage=UsageStats(input_tokens=usage_input, output_tokens=usage_output),
  raw_response=None,
 )