diff --git a/tests/test_ai_client_result.py b/tests/test_ai_client_result.py new file mode 100644 index 00000000..6a863300 --- /dev/null +++ b/tests/test_ai_client_result.py @@ -0,0 +1,59 @@ +from unittest.mock import MagicMock, patch +import pytest +from src import ai_client +from src.result_types import Result, ErrorInfo, ErrorKind + + +def test_send_result_public_api_returns_result() -> None: + with patch.object(ai_client, "set_provider"): + with patch.object(ai_client, "_send_anthropic_result", return_value=Result(data="hello")) as mock_send: + r = ai_client.send_result("system", "user") + assert isinstance(r, Result) + assert r.ok + assert r.data == "hello" + + +def test_send_deprecated_emits_warning() -> None: + import warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with patch.object(ai_client, "set_provider"): + with patch.object(ai_client, "_send_anthropic_result", return_value=Result(data="hi")): + result = ai_client.send("system", "user") + assert result == "hi" + assert any(issubclass(x.category, DeprecationWarning) for x in w) + + +def test_send_result_preserves_errors() -> None: + err = ErrorInfo(kind=ErrorKind.RATE_LIMIT, message="slow down", source="test") + with patch.object(ai_client, "set_provider"): + with patch.object(ai_client, "_send_anthropic_result", return_value=Result(data="", errors=[err])): + r = ai_client.send_result("system", "user") + assert not r.ok + assert r.errors == [err] + + +def test_send_extracts_data_from_result() -> None: + with patch.object(ai_client, "set_provider"): + with patch.object(ai_client, "_send_anthropic_result", return_value=Result(data="result text")): + result = ai_client.send("system", "user") + assert result == "result text" + + +def test_send_returns_empty_string_on_error_result() -> None: + err = ErrorInfo(kind=ErrorKind.AUTH, message="bad key", source="test") + with patch.object(ai_client, "set_provider"): + with patch.object(ai_client, "_send_anthropic_result", return_value=Result(data="", errors=[err])): + result = ai_client.send("system", "user") + assert result == "" + + +def test_classify_gemini_error_returns_error_info() -> None: + from src.ai_client import _classify_gemini_error + class FakeRateLimitError(Exception): pass + e = FakeRateLimitError("rate limited") + info = _classify_gemini_error(e, source="test.gemini") + assert isinstance(info, ErrorInfo) + assert info.kind == ErrorKind.RATE_LIMIT + assert info.source == "test.gemini" + assert info.original is e diff --git a/tests/test_deprecation_warnings.py b/tests/test_deprecation_warnings.py new file mode 100644 index 00000000..7a7a1638 --- /dev/null +++ b/tests/test_deprecation_warnings.py @@ -0,0 +1,25 @@ +import warnings +from unittest.mock import patch +from src import ai_client +from src.result_types import Result + + +def test_send_deprecated_warning_emitted_once_per_site() -> None: + with patch.object(ai_client, "set_provider"): + with patch.object(ai_client, "_send_anthropic_result", return_value=Result(data="x")): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ai_client.send("s", "u") + ai_client.send("s", "u") + deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecation_warnings) >= 1 + + +def test_send_result_does_not_emit_deprecation() -> None: + with patch.object(ai_client, "set_provider"): + with patch.object(ai_client, "_send_anthropic_result", return_value=Result(data="x")): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + ai_client.send_result("s", "u") + deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(deprecation_warnings) == 0