"""Tests for llama provider support: ``src.ai_client`` backends, model discovery, and capabilities."""
from unittest.mock import MagicMock, patch
|
|
import pytest
|
|
from src import ai_client
|
|
|
|
@pytest.fixture(autouse=True)
def _reset_llama_state():
    """Set the llama-related globals on ``ai_client`` to fixed values before each test.

    Only attributes that already exist on ``ai_client`` are overwritten; a
    missing attribute is left missing. Nothing is done after the test runs.
    """
    # Built inside the function so every test gets its own fresh history list.
    baseline = (
        ("_llama_client", None),
        ("_llama_history", []),
        ("_llama_base_url", "http://localhost:11434/v1"),
        ("_llama_api_key", "ollama"),
    )
    for attr_name, value in baseline:
        if hasattr(ai_client, attr_name):
            setattr(ai_client, attr_name, value)
    yield
|
|
|
|
def test_send_llama_ollama_backend(monkeypatch: pytest.MonkeyPatch) -> None:
    """With the Ollama base URL, ``_send_llama`` posts to ``/api/chat``.

    ``_require_warmed`` is patched to return a mock whose ``post`` yields a
    canned chat payload. The test checks that the reply text reaches the
    result and that the first positional argument of ``post`` is the
    ``/api/chat`` URL on the same host.
    """
    # Use monkeypatch so the previous value is restored at teardown instead of
    # leaving the module global overwritten. raising=False keeps the old
    # plain-assignment behaviour if the attribute does not exist yet.
    monkeypatch.setattr(
        ai_client, "_llama_base_url", "http://localhost:11434/v1", raising=False
    )
    ai_client.set_provider("llama", "llama-3.2-3b-preview")

    mock_response = MagicMock()
    mock_response.json.return_value = {
        "message": {"role": "assistant", "content": "hi from ollama"},
        "done": True,
    }
    mock_requests = MagicMock()
    mock_requests.post.return_value = mock_response

    with patch("src.ai_client._require_warmed", return_value=mock_requests):
        result = ai_client._send_llama("system", "user", ".", None, "", False, None, None, None)

    # Separate asserts so a failure points at the exact condition that broke.
    assert result.ok
    assert "hi from ollama" in result.data
    called_url = mock_requests.post.call_args.args[0]
    assert called_url == "http://localhost:11434/api/chat"
|
|
|
|
def test_send_llama_openrouter_backend(monkeypatch: pytest.MonkeyPatch) -> None:
    """With the OpenRouter base URL, ``_send_llama`` goes through ``_ensure_llama_client``.

    ``_ensure_llama_client`` is patched to return a mock client whose
    ``chat.completions.create`` yields one choice with fixed content. The test
    checks that the content is returned unchanged and that the patched
    factory was actually called.
    """
    # Use monkeypatch so the previous value is restored at teardown instead of
    # leaving the module global overwritten. raising=False keeps the old
    # plain-assignment behaviour if the attribute does not exist yet.
    monkeypatch.setattr(
        ai_client, "_llama_base_url", "https://openrouter.ai/api/v1", raising=False
    )
    ai_client.set_provider("llama", "llama-3.1-70b-versatile")

    captured_client = MagicMock()
    captured_client.chat.completions.create.return_value = MagicMock(
        choices=[MagicMock(message=MagicMock(content="hi from openrouter", tool_calls=[]))],
        usage=MagicMock(prompt_tokens=5, completion_tokens=3),
    )

    with patch("src.ai_client._ensure_llama_client", return_value=captured_client) as ensure:
        result = ai_client._send_llama("system", "user", ".", None, "", False, None, None, None)

    # Separate asserts so a failure points at the exact condition that broke.
    assert result.ok
    assert result.data == "hi from openrouter"
    assert ensure.called
|
|
|
|
def test_send_llama_custom_url(monkeypatch: pytest.MonkeyPatch) -> None:
    """With a custom base URL, ``_send_llama`` returns the mocked client's reply.

    ``_ensure_llama_client`` is patched to return a mock client whose
    ``chat.completions.create`` yields one choice with fixed content; the test
    checks that this content is returned unchanged.
    """
    # Use monkeypatch so the previous value is restored at teardown instead of
    # leaving the module global overwritten. raising=False keeps the old
    # plain-assignment behaviour if the attribute does not exist yet.
    monkeypatch.setattr(
        ai_client, "_llama_base_url", "http://my-server:9999/v1", raising=False
    )

    mock_client = MagicMock()
    mock_client.chat.completions.create.return_value = MagicMock(
        choices=[MagicMock(message=MagicMock(content="hi from custom", tool_calls=[]))],
        usage=MagicMock(prompt_tokens=5, completion_tokens=3),
    )

    with patch("src.ai_client._ensure_llama_client", return_value=mock_client):
        result = ai_client._send_llama("system", "user", ".", None, "", False, None, None, None)

    # Separate asserts so a failure points at the exact condition that broke.
    assert result.ok
    assert result.data == "hi from custom"
|
|
|
|
def test_llama_model_discovery_unions_ollama_and_openrouter() -> None:
    """``_list_llama_models()`` includes each of three expected model identifiers."""
    from src.ai_client import _list_llama_models

    discovered = _list_llama_models()
    expected_ids = (
        "llama-3.1-8b-instant",
        "llama-3.2-11b-vision-preview",
        "llama-3.3-70b-specdec",
    )
    # Checked in order, so the first missing identifier is the one reported.
    for model_id in expected_ids:
        assert model_id in discovered
|
|
|
|
def test_llama_3_2_vision_vision_capability() -> None:
    """Capabilities for ``llama-3.2-11b-vision-preview`` report ``vision`` as ``True``."""
    from src.vendor_capabilities import get_capabilities

    vision_flag = get_capabilities("llama", "llama-3.2-11b-vision-preview").vision
    # Identity check: the flag must be the bool True, not merely truthy.
    assert vision_flag is True
|
|
|
|
def test_llama_local_backend_cost_tracking_false_for_ollama() -> None:
    """With the localhost Ollama base URL set, ``_get_llama_cost_tracking()`` is ``False``."""
    from src.ai_client import _get_llama_cost_tracking

    local_ollama_url = "http://localhost:11434/v1"
    ai_client._llama_base_url = local_ollama_url

    tracking_enabled = _get_llama_cost_tracking()
    # Identity check: the result must be the bool False, not merely falsy.
    assert tracking_enabled is False