54 lines
3.0 KiB
Python
54 lines
3.0 KiB
Python
from unittest.mock import MagicMock, patch
|
|
import pytest
|
|
from src import ai_client
|
|
|
|
@pytest.fixture(autouse=True)
def _reset_grok_state():
    """Reset ai_client's cached Grok client and history around every test.

    ``ai_client`` stores ``_grok_client`` and ``_grok_history`` as module-level
    attributes. They are cleared before the test runs, to isolate it from earlier
    tests, and again after it finishes. The second clear keeps whatever state
    this test left behind from leaking into tests that run later.
    The ``hasattr`` guards make this a no-op for attributes the module does not
    define.
    """

    def _clear() -> None:
        # Only touch attributes that already exist on the module.
        if hasattr(ai_client, "_grok_client"):
            ai_client._grok_client = None
        if hasattr(ai_client, "_grok_history"):
            # Rebind to a fresh list rather than clearing in place, as before.
            ai_client._grok_history = []

    _clear()
    yield
    # Teardown: runs after the test body, even when the test fails.
    _clear()
|
|
|
|
def test_send_grok_uses_xai_endpoint(monkeypatch: pytest.MonkeyPatch) -> None:
    """``_send_grok`` should return the text produced by the Grok chat client.

    ``_ensure_grok_client`` is patched to hand back a mock client whose
    ``chat.completions.create`` yields a single choice with content
    ``"hi from grok"`` and no tool calls. The test checks that the call succeeds,
    that the content surfaces as ``result.data``, and that the mock's
    ``create`` was invoked.

    NOTE(review): despite the name, nothing here asserts the xAI base URL. That
    would need a check on how ``_ensure_grok_client`` builds the client.
    ``monkeypatch`` is currently unused.
    """
    ai_client.set_provider("grok", "grok-2")

    reply_message = MagicMock(content="hi from grok", tool_calls=[])
    token_usage = MagicMock(prompt_tokens=10, completion_tokens=5)

    fake_grok = MagicMock()
    completions_create = fake_grok.chat.completions.create
    completions_create.return_value = MagicMock(
        choices=[MagicMock(message=reply_message)],
        usage=token_usage,
    )

    with patch("src.ai_client._ensure_grok_client", return_value=fake_grok):
        result = ai_client._send_grok("system", "user", ".", None, "", False, None, None, None)

    assert result.ok
    assert result.data == "hi from grok"
    completions_create.assert_called()
|
|
|
|
def test_grok_2_vision_supports_image() -> None:
    """The ``grok-2-vision`` model must report ``vision`` capability as exactly ``True``."""
    from src import vendor_capabilities

    grok_vision_caps = vendor_capabilities.get_capabilities("grok", "grok-2-vision")
    # Identity check on purpose: a merely truthy value (e.g. a mock) must not pass.
    assert grok_vision_caps.vision is True
|
|
|
|
def test_grok_web_search_adds_search_parameters_to_extra_body() -> None:
    """caps.web_search=True should populate search_parameters.mode=auto in extra_body.

    ``send_openai_compatible`` is replaced with a fake that records each request's
    ``extra_body`` and ``model`` and returns a minimal successful response. The
    assertion therefore inspects exactly what ``_send_grok`` tried to send.
    ``_ensure_grok_client`` and ``_get_deepseek_tools`` are stubbed out so no
    real client or tool list is built.
    """
    from src import openai_compatible as oc

    # _send_grok is called without a model argument, so the model presumably
    # comes from module-level provider state. Pin it here instead of relying on
    # whatever an earlier test left behind.
    # NOTE(review): assumes grok-2's capabilities have web_search enabled. This
    # matches the provider that test_send_grok_uses_xai_endpoint sets.
    ai_client.set_provider("grok", "grok-2")

    captured_kwargs: list[dict] = []

    def _fake_send(client, request, *, capabilities):
        captured_kwargs.append({"extra_body": request.extra_body, "model": request.model})
        return MagicMock(text="ok", tool_calls=[], usage_input_tokens=0, usage_output_tokens=0, usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=None)

    with patch.object(oc, "send_openai_compatible", side_effect=_fake_send), \
            patch("src.ai_client._ensure_grok_client", return_value=MagicMock()), \
            patch("src.ai_client._get_deepseek_tools", return_value=[]):
        ai_client._send_grok("system", "user", ".", None, "", False, None, None, None)

    assert any(kw.get("extra_body") is not None and kw["extra_body"].get("search_parameters", {}).get("mode") == "auto" for kw in captured_kwargs), f"web_search extra_body not found in {captured_kwargs}"
|
|
|
|
def test_grok_x_search_adds_x_source_to_extra_body() -> None:
    """caps.x_search=True should add sources=[{type:x}] to search_parameters.

    ``send_openai_compatible`` is replaced with a fake that records each request's
    ``extra_body``, so the first recorded request shows what ``_send_grok`` built.
    The guard assertions turn "never called" and "no extra_body" into readable
    failures instead of an ``IndexError``/``TypeError``.
    """
    from src import openai_compatible as oc

    # _send_grok is called without a model argument, so the model presumably
    # comes from module-level provider state. Pin it here instead of relying on
    # whatever an earlier test left behind.
    # NOTE(review): assumes grok-2's capabilities have x_search enabled. This
    # matches the provider that test_send_grok_uses_xai_endpoint sets.
    ai_client.set_provider("grok", "grok-2")

    captured_kwargs: list[dict] = []

    def _fake_send(client, request, *, capabilities):
        captured_kwargs.append({"extra_body": request.extra_body})
        return MagicMock(text="ok", tool_calls=[], usage_input_tokens=0, usage_output_tokens=0, usage_cache_read_tokens=0, usage_cache_creation_tokens=0, raw_response=None)

    with patch.object(oc, "send_openai_compatible", side_effect=_fake_send), \
            patch("src.ai_client._ensure_grok_client", return_value=MagicMock()), \
            patch("src.ai_client._get_deepseek_tools", return_value=[]):
        ai_client._send_grok("system", "user", ".", None, "", False, None, None, None)

    assert captured_kwargs, "send_openai_compatible was never called by _send_grok"
    extra_body = captured_kwargs[0]["extra_body"]
    assert extra_body is not None, f"no extra_body sent; captured {captured_kwargs}"
    assert extra_body["search_parameters"]["sources"] == [{"type": "x"}]