diff --git a/src/ai_client.py b/src/ai_client.py index 9684f243..d068f3a9 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -76,31 +76,6 @@ _history_trunc_limit: int = 8000 # Global event emitter for API lifecycle events events: EventEmitter = EventEmitter() -class ProviderError(Exception): - def __init__(self, kind: str, provider: str, original: Exception) -> None: - """ - [C: src/api_hooks.py:HookServerInstance.__init__, src/mcp_client.py:_DDGParser.__init__, src/mcp_client.py:_TextExtractor.__init__] - """ - self.kind = kind - self.provider = provider - self.original = original - super().__init__(str(original)) - - def ui_message(self) -> str: - """ - [C: src/app_controller.py:AppController._handle_request_event, src/app_controller.py:_api_generate] - """ - labels = { - "quota": "QUOTA EXHAUSTED", - "rate_limit": "RATE LIMITED", - "auth": "AUTH / API KEY ERROR", - "balance": "BALANCE / BILLING ERROR", - "network": "NETWORK / CONNECTION ERROR", - "unknown": "API ERROR", - } - label = labels.get(self.kind, "API ERROR") - return f"[{self.provider.upper()} {label}]\n\n{self.original}" - #region: Provider Configuration def set_model_params(temp: float, max_tok: int, trunc_limit: int = 8000, top_p: float = 1.0) -> None: @@ -1451,9 +1426,6 @@ def _send_anthropic_result(md_content: str, user_message: str, base_dir: str, fi res = final_text if final_text.strip() else "(No text returned by the model)" if monitor.enabled: monitor.end_component("ai_client._send_anthropic") return Result(data=res) - except ProviderError: - if monitor.enabled: monitor.end_component("ai_client._send_anthropic") - raise except Exception as exc: if monitor.enabled: monitor.end_component("ai_client._send_anthropic") return Result(data="", errors=[_classify_anthropic_error(exc, source="ai_client.anthropic")]) diff --git a/src/openai_compatible.py b/src/openai_compatible.py index a484ee38..7fd5be24 100644 --- a/src/openai_compatible.py +++ b/src/openai_compatible.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Optional from openai import OpenAIError, RateLimitError, AuthenticationError, PermissionDeniedError, APIConnectionError, APIStatusError, BadRequestError -from src.result_types import ErrorInfo, ErrorKind +from src.result_types import ErrorInfo, ErrorKind, Result @dataclass(frozen=True) class NormalizedResponse: @@ -64,7 +64,7 @@ def send_openai_compatible( request: OpenAICompatibleRequest, *, capabilities: Any, -) -> NormalizedResponse: +) -> Result[str]: kwargs: dict[str, Any] = { "model": request.model, "messages": request.messages, @@ -80,10 +80,12 @@ def send_openai_compatible( kwargs["extra_body"] = request.extra_body try: if request.stream: - return _send_streaming(client, kwargs, request.stream_callback) - return _send_blocking(client, kwargs) + response = _send_streaming(client, kwargs, request.stream_callback) + else: + response = _send_blocking(client, kwargs) + return Result(data=response.text) except OpenAIError as exc: - raise _classify_openai_compatible_error(exc) from exc + return Result(data="", errors=[_classify_openai_compatible_error(exc, source="openai_compatible")]) def _send_blocking(client: Any, kwargs: dict[str, Any]) -> NormalizedResponse: resp = client.chat.completions.create(**kwargs) diff --git a/tests/test_openai_compatible.py b/tests/test_openai_compatible.py index 5b6f11a2..9425dc19 100644 --- a/tests/test_openai_compatible.py +++ b/tests/test_openai_compatible.py @@ -22,15 +22,14 @@ def _mock_completion(text: str = "hello", tool_calls=None, usage_input: int = 10 m.usage.completion_tokens_details = None return m -def test_send_non_streaming_returns_normalized_response(caps: VendorCapabilities) -> None: +def test_send_non_streaming_returns_text_in_result(caps: VendorCapabilities) -> None: client = MagicMock() client.chat.completions.create.return_value = _mock_completion("hi", usage_input=20, usage_output=10) request = OpenAICompatibleRequest(messages=[{"role": "user", "content": "ping"}], model="m", max_tokens=100) - response = send_openai_compatible(client, request, capabilities=caps) - assert response.text == "hi" - assert response.tool_calls == [] - assert response.usage_input_tokens == 20 - assert response.usage_output_tokens == 10 + result = send_openai_compatible(client, request, capabilities=caps) + assert result.ok + assert result.data == "hi" + assert result.errors == [] def test_send_streaming_aggregates_chunks(caps: VendorCapabilities) -> None: client = MagicMock() @@ -42,12 +41,13 @@ def test_send_streaming_aggregates_chunks(caps: VendorCapabilities) -> None: client.chat.completions.create.return_value = iter(chunks) received: list = [] request = OpenAICompatibleRequest(messages=[{"role": "user", "content": "ping"}], model="m", stream=True, stream_callback=received.append) - response = send_openai_compatible(client, request, capabilities=caps) - assert response.text == "hello" + result = send_openai_compatible(client, request, capabilities=caps) + assert result.ok + assert result.data == "hello" assert received == ["hel", "lo"] - assert response.usage_input_tokens == 15 -def test_tool_call_detection_in_response(caps: VendorCapabilities) -> None: +def test_tool_call_detection_in_blocking_response(caps: VendorCapabilities) -> None: + from src.openai_compatible import _send_blocking tool_call = MagicMock() tool_call.id = "call_1" tool_call.function.name = "read_file" @@ -55,8 +55,8 @@ def test_tool_call_detection_in_response(caps: VendorCapabilities) -> None: completion = _mock_completion(text="", tool_calls=[tool_call]) client = MagicMock() client.chat.completions.create.return_value = completion - request = OpenAICompatibleRequest(messages=[{"role": "user", "content": "ping"}], model="m") - response = send_openai_compatible(client, request, capabilities=caps) + kwargs = {"model": "m", "messages": [{"role": "user", "content": "ping"}], "temperature": 0.0, "top_p": 1.0, "max_tokens": 8192, "stream": False} + response = _send_blocking(client, kwargs) assert len(response.tool_calls) == 1 assert response.tool_calls[0]["function"]["name"] == "read_file" assert response.tool_calls[0]["id"] == "call_1" @@ -66,20 +66,21 @@ def test_vision_multimodal_message(caps: VendorCapabilities) -> None: client.chat.completions.create.return_value = _mock_completion("looks like a cat") messages = [{"role": "user", "content": [{"type": "text", "text": "what is this?"}, {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}] request = OpenAICompatibleRequest(messages=messages, model="m") - response = send_openai_compatible(client, request, capabilities=caps) + result = send_openai_compatible(client, request, capabilities=caps) sent_messages = client.chat.completions.create.call_args.kwargs["messages"] assert sent_messages[0]["content"] == messages[0]["content"] - assert response.text == "looks like a cat" + assert result.data == "looks like a cat" def test_error_classification_429_to_rate_limit(caps: VendorCapabilities) -> None: from openai import RateLimitError - from src.ai_client import ProviderError + from src.result_types import Result, ErrorKind client = MagicMock() client.chat.completions.create.side_effect = RateLimitError("rate limited", response=MagicMock(status_code=429), body=None) request = OpenAICompatibleRequest(messages=[{"role": "user", "content": "ping"}], model="m") - with pytest.raises(ProviderError) as exc_info: - send_openai_compatible(client, request, capabilities=caps) - assert exc_info.value.kind == "rate_limit" + result = send_openai_compatible(client, request, capabilities=caps) + assert isinstance(result, Result) + assert not result.ok + assert result.errors[0].kind == ErrorKind.RATE_LIMIT def test_normalized_response_is_frozen_dataclass() -> None: from dataclasses import FrozenInstanceError diff --git a/tests/test_qwen_provider.py b/tests/test_qwen_provider.py index a297dfe6..ee14eebc 100644 --- a/tests/test_qwen_provider.py +++ b/tests/test_qwen_provider.py @@ -39,12 +39,12 @@ def test_qwen_tool_format_translation() -> None: assert "parameters" in ds_tools[0] def test_qwen_error_classification() -> None: - from src.ai_client import ProviderError + from src.result_types import ErrorKind from src.qwen_adapter import classify_dashscope_error from dashscope.common.error import AuthenticationError err = classify_dashscope_error(AuthenticationError("bad key")) - assert err.kind == "auth" - assert err.provider == "qwen" + assert err.kind == ErrorKind.AUTH + assert err.source == "qwen_adapter" def test_list_qwen_models_returns_hardcoded_registry() -> None: from src.ai_client import _list_qwen_models