refactor(ai_client): remove ProviderError class; ErrorInfo is the new error type
This commit is contained in:
@@ -76,31 +76,6 @@ _history_trunc_limit: int = 8000
|
||||
# Global event emitter for API lifecycle events
|
||||
events: EventEmitter = EventEmitter()
|
||||
|
||||
class ProviderError(Exception):
|
||||
def __init__(self, kind: str, provider: str, original: Exception) -> None:
|
||||
"""
|
||||
[C: src/api_hooks.py:HookServerInstance.__init__, src/mcp_client.py:_DDGParser.__init__, src/mcp_client.py:_TextExtractor.__init__]
|
||||
"""
|
||||
self.kind = kind
|
||||
self.provider = provider
|
||||
self.original = original
|
||||
super().__init__(str(original))
|
||||
|
||||
def ui_message(self) -> str:
|
||||
"""
|
||||
[C: src/app_controller.py:AppController._handle_request_event, src/app_controller.py:_api_generate]
|
||||
"""
|
||||
labels = {
|
||||
"quota": "QUOTA EXHAUSTED",
|
||||
"rate_limit": "RATE LIMITED",
|
||||
"auth": "AUTH / API KEY ERROR",
|
||||
"balance": "BALANCE / BILLING ERROR",
|
||||
"network": "NETWORK / CONNECTION ERROR",
|
||||
"unknown": "API ERROR",
|
||||
}
|
||||
label = labels.get(self.kind, "API ERROR")
|
||||
return f"[{self.provider.upper()} {label}]\n\n{self.original}"
|
||||
|
||||
#region: Provider Configuration
|
||||
|
||||
def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000, top_p: float = 1.0) -> None:
|
||||
@@ -1451,9 +1426,6 @@ def _send_anthropic_result(md_content: str, user_message: str, base_dir: str, fi
|
||||
res = final_text if final_text.strip() else "(No text returned by the model)"
|
||||
if monitor.enabled: monitor.end_component("ai_client._send_anthropic")
|
||||
return Result(data=res)
|
||||
except ProviderError:
|
||||
if monitor.enabled: monitor.end_component("ai_client._send_anthropic")
|
||||
raise
|
||||
except Exception as exc:
|
||||
if monitor.enabled: monitor.end_component("ai_client._send_anthropic")
|
||||
return Result(data="", errors=[_classify_anthropic_error(exc, source="ai_client.anthropic")])
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Callable, Optional
|
||||
|
||||
from openai import OpenAIError, RateLimitError, AuthenticationError, PermissionDeniedError, APIConnectionError, APIStatusError, BadRequestError
|
||||
|
||||
from src.result_types import ErrorInfo, ErrorKind
|
||||
from src.result_types import ErrorInfo, ErrorKind, Result
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NormalizedResponse:
|
||||
@@ -64,7 +64,7 @@ def send_openai_compatible(
|
||||
request: OpenAICompatibleRequest,
|
||||
*,
|
||||
capabilities: Any,
|
||||
) -> NormalizedResponse:
|
||||
) -> Result[str]:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": request.model,
|
||||
"messages": request.messages,
|
||||
@@ -80,10 +80,12 @@ def send_openai_compatible(
|
||||
kwargs["extra_body"] = request.extra_body
|
||||
try:
|
||||
if request.stream:
|
||||
return _send_streaming(client, kwargs, request.stream_callback)
|
||||
return _send_blocking(client, kwargs)
|
||||
response = _send_streaming(client, kwargs, request.stream_callback)
|
||||
else:
|
||||
response = _send_blocking(client, kwargs)
|
||||
return Result(data=response.text)
|
||||
except OpenAIError as exc:
|
||||
raise _classify_openai_compatible_error(exc) from exc
|
||||
return Result(data="", errors=[_classify_openai_compatible_error(exc, source="openai_compatible")])
|
||||
|
||||
def _send_blocking(client: Any, kwargs: dict[str, Any]) -> NormalizedResponse:
|
||||
resp = client.chat.completions.create(**kwargs)
|
||||
|
||||
@@ -22,15 +22,14 @@ def _mock_completion(text: str = "hello", tool_calls=None, usage_input: int = 10
|
||||
m.usage.completion_tokens_details = None
|
||||
return m
|
||||
|
||||
def test_send_non_streaming_returns_normalized_response(caps: VendorCapabilities) -> None:
|
||||
def test_send_non_streaming_returns_text_in_result(caps: VendorCapabilities) -> None:
|
||||
client = MagicMock()
|
||||
client.chat.completions.create.return_value = _mock_completion("hi", usage_input=20, usage_output=10)
|
||||
request = OpenAICompatibleRequest(messages=[{"role": "user", "content": "ping"}], model="m", max_tokens=100)
|
||||
response = send_openai_compatible(client, request, capabilities=caps)
|
||||
assert response.text == "hi"
|
||||
assert response.tool_calls == []
|
||||
assert response.usage_input_tokens == 20
|
||||
assert response.usage_output_tokens == 10
|
||||
result = send_openai_compatible(client, request, capabilities=caps)
|
||||
assert result.ok
|
||||
assert result.data == "hi"
|
||||
assert result.errors == []
|
||||
|
||||
def test_send_streaming_aggregates_chunks(caps: VendorCapabilities) -> None:
|
||||
client = MagicMock()
|
||||
@@ -42,12 +41,13 @@ def test_send_streaming_aggregates_chunks(caps: VendorCapabilities) -> None:
|
||||
client.chat.completions.create.return_value = iter(chunks)
|
||||
received: list = []
|
||||
request = OpenAICompatibleRequest(messages=[{"role": "user", "content": "ping"}], model="m", stream=True, stream_callback=received.append)
|
||||
response = send_openai_compatible(client, request, capabilities=caps)
|
||||
assert response.text == "hello"
|
||||
result = send_openai_compatible(client, request, capabilities=caps)
|
||||
assert result.ok
|
||||
assert result.data == "hello"
|
||||
assert received == ["hel", "lo"]
|
||||
assert response.usage_input_tokens == 15
|
||||
|
||||
def test_tool_call_detection_in_response(caps: VendorCapabilities) -> None:
|
||||
def test_tool_call_detection_in_blocking_response(caps: VendorCapabilities) -> None:
|
||||
from src.openai_compatible import _send_blocking
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "call_1"
|
||||
tool_call.function.name = "read_file"
|
||||
@@ -55,8 +55,8 @@ def test_tool_call_detection_in_response(caps: VendorCapabilities) -> None:
|
||||
completion = _mock_completion(text="", tool_calls=[tool_call])
|
||||
client = MagicMock()
|
||||
client.chat.completions.create.return_value = completion
|
||||
request = OpenAICompatibleRequest(messages=[{"role": "user", "content": "ping"}], model="m")
|
||||
response = send_openai_compatible(client, request, capabilities=caps)
|
||||
kwargs = {"model": "m", "messages": [{"role": "user", "content": "ping"}], "temperature": 0.0, "top_p": 1.0, "max_tokens": 8192, "stream": False}
|
||||
response = _send_blocking(client, kwargs)
|
||||
assert len(response.tool_calls) == 1
|
||||
assert response.tool_calls[0]["function"]["name"] == "read_file"
|
||||
assert response.tool_calls[0]["id"] == "call_1"
|
||||
@@ -66,20 +66,21 @@ def test_vision_multimodal_message(caps: VendorCapabilities) -> None:
|
||||
client.chat.completions.create.return_value = _mock_completion("looks like a cat")
|
||||
messages = [{"role": "user", "content": [{"type": "text", "text": "what is this?"}, {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}]
|
||||
request = OpenAICompatibleRequest(messages=messages, model="m")
|
||||
response = send_openai_compatible(client, request, capabilities=caps)
|
||||
result = send_openai_compatible(client, request, capabilities=caps)
|
||||
sent_messages = client.chat.completions.create.call_args.kwargs["messages"]
|
||||
assert sent_messages[0]["content"] == messages[0]["content"]
|
||||
assert response.text == "looks like a cat"
|
||||
assert result.data == "looks like a cat"
|
||||
|
||||
def test_error_classification_429_to_rate_limit(caps: VendorCapabilities) -> None:
|
||||
from openai import RateLimitError
|
||||
from src.ai_client import ProviderError
|
||||
from src.result_types import Result, ErrorKind
|
||||
client = MagicMock()
|
||||
client.chat.completions.create.side_effect = RateLimitError("rate limited", response=MagicMock(status_code=429), body=None)
|
||||
request = OpenAICompatibleRequest(messages=[{"role": "user", "content": "ping"}], model="m")
|
||||
with pytest.raises(ProviderError) as exc_info:
|
||||
send_openai_compatible(client, request, capabilities=caps)
|
||||
assert exc_info.value.kind == "rate_limit"
|
||||
result = send_openai_compatible(client, request, capabilities=caps)
|
||||
assert isinstance(result, Result)
|
||||
assert not result.ok
|
||||
assert result.errors[0].kind == ErrorKind.RATE_LIMIT
|
||||
|
||||
def test_normalized_response_is_frozen_dataclass() -> None:
|
||||
from dataclasses import FrozenInstanceError
|
||||
|
||||
@@ -39,12 +39,12 @@ def test_qwen_tool_format_translation() -> None:
|
||||
assert "parameters" in ds_tools[0]
|
||||
|
||||
def test_qwen_error_classification() -> None:
|
||||
from src.ai_client import ProviderError
|
||||
from src.result_types import ErrorKind
|
||||
from src.qwen_adapter import classify_dashscope_error
|
||||
from dashscope.common.error import AuthenticationError
|
||||
err = classify_dashscope_error(AuthenticationError("bad key"))
|
||||
assert err.kind == "auth"
|
||||
assert err.provider == "qwen"
|
||||
assert err.kind == ErrorKind.AUTH
|
||||
assert err.source == "qwen_adapter"
|
||||
|
||||
def test_list_qwen_models_returns_hardcoded_registry() -> None:
|
||||
from src.ai_client import _list_qwen_models
|
||||
|
||||
Reference in New Issue
Block a user