91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
import asyncio
|
|
import inspect
|
|
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
|
|
def test_speak_elevenlabs_called_when_key_set():
    """ElevenLabs synthesizes the text exactly once when its credentials are set.

    ``assert_called_with`` only inspects the most recent call, so it would also
    pass if ``speak`` hit the backend several times; ``assert_called_once_with``
    pins a single synthesis request. ``_speak_google`` is patched too so the
    test can check the fallback backend is not consulted when ElevenLabs
    succeeds.
    """
    from rook.tts import speak

    env = {'ELEVENLABS_API_KEY': 'key', 'ELEVENLABS_VOICE_ID': 'vid'}
    with patch.dict('os.environ', env, clear=True), \
         patch('rook.tts._speak_elevenlabs', return_value=b'audio') as mock_el, \
         patch('rook.tts._speak_google') as mock_google, \
         patch('rook.tts._play_audio'):
        asyncio.run(speak('hello'))

    mock_el.assert_called_once_with('hello')
    mock_google.assert_not_called()


def test_speak_falls_back_to_google_on_elevenlabs_error():
    """A failing ElevenLabs call falls back to Google, and Google's audio is played.

    Previously the test only checked that Google was called at all. It now
    also pins the text forwarded to Google and that the Google bytes (distinct
    from anything ElevenLabs could produce) are what reaches ``_play_audio``.
    """
    from rook.tts import speak

    env = {
        'ELEVENLABS_API_KEY': 'key',
        'ELEVENLABS_VOICE_ID': 'vid',
        'GOOGLE_TTS_KEY': 'gkey',
    }
    api_failure = Exception('API error')
    with patch.dict('os.environ', env, clear=True), \
         patch('rook.tts._speak_elevenlabs', side_effect=api_failure) as mock_el, \
         patch('rook.tts._speak_google', return_value=b'google-audio') as mock_google, \
         patch('rook.tts._play_audio') as mock_play:
        asyncio.run(speak('hello'))

    # ElevenLabs was attempted first (retry count is left unconstrained).
    mock_el.assert_called_with('hello')
    mock_google.assert_called_once_with('hello')
    mock_play.assert_called_once_with(b'google-audio')


def test_speak_google_only_when_no_elevenlabs():
    """Only Google is used when no ElevenLabs credentials are configured.

    ``_speak_elevenlabs`` is patched so the test can assert it is skipped
    entirely. Without that check the test would still pass if ``speak`` tried
    ElevenLabs first (possibly hitting the real API) and merely fell back
    after it failed.
    """
    from rook.tts import speak

    env = {'GOOGLE_TTS_KEY': 'gkey'}
    with patch.dict('os.environ', env, clear=True), \
         patch('rook.tts._speak_elevenlabs') as mock_el, \
         patch('rook.tts._speak_google', return_value=b'audio') as mock_google, \
         patch('rook.tts._play_audio'):
        asyncio.run(speak('hi'))

    mock_el.assert_not_called()
    mock_google.assert_called_once_with('hi')


def test_speak_raises_tts_error_when_no_keys():
    """With every environment variable cleared, ``speak`` raises ``TTSError``."""
    from rook.tts import TTSError, speak

    empty_env = {}
    with patch.dict('os.environ', empty_env, clear=True), pytest.raises(TTSError):
        asyncio.run(speak('hi'))


def test_speak_calls_play_audio():
    """The bytes returned by the ElevenLabs backend are played exactly once."""
    from rook.tts import speak

    env = {'ELEVENLABS_API_KEY': 'key', 'ELEVENLABS_VOICE_ID': 'vid'}
    with patch.dict('os.environ', env, clear=True), \
         patch('rook.tts._speak_elevenlabs', return_value=b'bytes'), \
         patch('rook.tts._play_audio') as mock_play:
        asyncio.run(speak('test'))

    # assert_called_with only inspects the most recent call, so duplicate
    # playback would slip through it; pin a single call instead.
    mock_play.assert_called_once_with(b'bytes')


def test_speak_is_coroutine():
    """``speak`` must be a coroutine function so it can be driven by ``asyncio.run``."""
    from rook.tts import speak

    is_async = inspect.iscoroutinefunction(speak)
    assert is_async


def test_speak_elevenlabs_internal():
    """``_speak_elevenlabs`` concatenates the chunks yielded by the client.

    ``elevenlabs.ElevenLabs`` is swapped for a factory returning a fake client
    whose ``generate`` yields two byte chunks; the helper's result must be
    their concatenation.
    """
    fake_client = MagicMock()
    fake_client.generate.return_value = iter([b'chunk1', b'chunk2'])
    client_factory = MagicMock(return_value=fake_client)
    creds = {'ELEVENLABS_API_KEY': 'k', 'ELEVENLABS_VOICE_ID': 'v'}

    with patch('elevenlabs.ElevenLabs', client_factory), \
         patch.dict('os.environ', creds):
        from rook.tts import _speak_elevenlabs
        audio = _speak_elevenlabs('hi')

    assert audio == b'chunk1chunk2'


def test_speak_google_internal():
    """``_speak_google`` returns the ``audio_content`` of the synthesis response."""
    fake_client = MagicMock()
    fake_client.synthesize_speech.return_value = MagicMock(audio_content=b'speech')
    client_factory = MagicMock(return_value=fake_client)

    with patch('google.cloud.texttospeech.TextToSpeechClient', client_factory):
        from rook.tts import _speak_google
        audio = _speak_google('hi')

    assert audio == b'speech'


def test_play_audio_noop():
    """``_play_audio`` returns ``None`` when given a bytes payload."""
    from rook.tts import _play_audio

    assert _play_audio(b'data') is None