diff --git a/tests/test_grok_provider.py b/tests/test_grok_provider.py new file mode 100644 index 00000000..25f00cef --- /dev/null +++ b/tests/test_grok_provider.py @@ -0,0 +1,28 @@ +from unittest.mock import MagicMock, patch +import pytest +from src import ai_client + +@pytest.fixture(autouse=True) +def _reset_grok_state(): + if hasattr(ai_client, '_grok_client'): + ai_client._grok_client = None + if hasattr(ai_client, '_grok_history'): + ai_client._grok_history = [] + yield + +def test_send_grok_uses_xai_endpoint(monkeypatch: pytest.MonkeyPatch) -> None: + ai_client.set_provider("grok", "grok-2") + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="hi from grok", tool_calls=[]))], + usage=MagicMock(prompt_tokens=10, completion_tokens=5), + ) + with patch("src.ai_client._ensure_grok_client", return_value=mock_client): + result = ai_client._send_grok("system", "user", ".", None, "", False, None, None, None) + assert result == "hi from grok" + assert mock_client.chat.completions.create.called + +def test_grok_2_vision_supports_image() -> None: + from src.vendor_capabilities import get_capabilities + caps = get_capabilities("grok", "grok-2-vision") + assert caps.vision is True \ No newline at end of file diff --git a/tests/test_llama_provider.py b/tests/test_llama_provider.py new file mode 100644 index 00000000..f9f785b1 --- /dev/null +++ b/tests/test_llama_provider.py @@ -0,0 +1,68 @@ +from unittest.mock import MagicMock, patch +import pytest +from src import ai_client + +@pytest.fixture(autouse=True) +def _reset_llama_state(): + if hasattr(ai_client, '_llama_client'): + ai_client._llama_client = None + if hasattr(ai_client, '_llama_history'): + ai_client._llama_history = [] + if hasattr(ai_client, '_llama_base_url'): + ai_client._llama_base_url = "http://localhost:11434/v1" + if hasattr(ai_client, '_llama_api_key'): + ai_client._llama_api_key = "ollama" + yield + +def test_send_llama_ollama_backend(monkeypatch: pytest.MonkeyPatch) -> None: + ai_client._llama_base_url = "http://localhost:11434/v1" + ai_client.set_provider("llama", "llama-3.2-3b-preview") + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="hi from ollama", tool_calls=[]))], + usage=MagicMock(prompt_tokens=5, completion_tokens=3), + ) + with patch("src.ai_client._ensure_llama_client", return_value=mock_client): + result = ai_client._send_llama("system", "user", ".", None, "", False, None, None, None) + assert result == "hi from ollama" + +def test_send_llama_openrouter_backend(monkeypatch: pytest.MonkeyPatch) -> None: + ai_client._llama_base_url = "https://openrouter.ai/api/v1" + ai_client.set_provider("llama", "llama-3.1-70b-versatile") + captured_client = MagicMock() + captured_client.chat.completions.create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="hi from openrouter", tool_calls=[]))], + usage=MagicMock(prompt_tokens=5, completion_tokens=3), + ) + with patch("src.ai_client._ensure_llama_client", return_value=captured_client) as ensure: + result = ai_client._send_llama("system", "user", ".", None, "", False, None, None, None) + assert result == "hi from openrouter" + assert ensure.called + +def test_send_llama_custom_url(monkeypatch: pytest.MonkeyPatch) -> None: + ai_client._llama_base_url = "http://my-server:9999/v1" + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="hi from custom", tool_calls=[]))], + usage=MagicMock(prompt_tokens=5, completion_tokens=3), + ) + with patch("src.ai_client._ensure_llama_client", return_value=mock_client): + result = ai_client._send_llama("system", "user", ".", None, "", False, None, None, None) + assert result == "hi from custom" + +def test_llama_model_discovery_unions_ollama_and_openrouter() -> None: + from src.ai_client import _list_llama_models + models = _list_llama_models() + assert "llama-3.1-8b-instant" in models + assert "llama-3.2-11b-vision-preview" in models + assert "llama-3.3-70b-specdec" in models + +def test_llama_3_2_vision_vision_capability() -> None: + from src.vendor_capabilities import get_capabilities + caps = get_capabilities("llama", "llama-3.2-11b-vision-preview") + assert caps.vision is True + +def test_llama_local_backend_cost_tracking_false_for_ollama() -> None: + ai_client._llama_base_url = "http://localhost:11434/v1" + from src.ai_client import _get_llama_cost_tracking + assert _get_llama_cost_tracking() is False \ No newline at end of file