diff --git a/src/rook/tts.py b/src/rook/tts.py new file mode 100644 index 0000000..7e248e6 --- /dev/null +++ b/src/rook/tts.py @@ -0,0 +1,50 @@ +import os +from typing import Optional + + +class TTSError(Exception): + pass + + +async def speak(text: str) -> None: + el_key = os.environ.get('ELEVENLABS_API_KEY') + g_key = os.environ.get('GOOGLE_TTS_KEY') + audio: Optional[bytes] = None + if el_key: + try: + audio = _speak_elevenlabs(text) + except Exception: + if g_key: + audio = _speak_google(text) + elif g_key: + audio = _speak_google(text) + if audio is None: + raise TTSError('No TTS keys configured') + _play_audio(audio) + + +def _speak_elevenlabs(text: str) -> bytes: + from elevenlabs import ElevenLabs + client = ElevenLabs(api_key=os.environ['ELEVENLABS_API_KEY']) + audio_gen = client.generate(text=text, voice=os.environ.get('ELEVENLABS_VOICE_ID', '')) + return b''.join(audio_gen) + + +def _speak_google(text: str) -> bytes: + from google.cloud import texttospeech + client = texttospeech.TextToSpeechClient() + response = client.synthesize_speech( + input=texttospeech.SynthesisInput(text=text), + voice=texttospeech.VoiceSelectionParams( + language_code='en-US', + ssml_gender=texttospeech.SsmlVoiceGender.NEUTRAL, + ), + audio_config=texttospeech.AudioConfig( + audio_encoding=texttospeech.AudioEncoding.MP3, + ), + ) + return response.audio_content + + +def _play_audio(audio_bytes: bytes) -> None: + pass diff --git a/tests/test_tts.py b/tests/test_tts.py new file mode 100644 index 0000000..d47f4c3 --- /dev/null +++ b/tests/test_tts.py @@ -0,0 +1,90 @@ +import asyncio +import inspect +import pytest +from unittest.mock import patch, MagicMock + + +def test_speak_elevenlabs_called_when_key_set(): + from rook.tts import speak + env = {'ELEVENLABS_API_KEY': 'key', 'ELEVENLABS_VOICE_ID': 'vid'} + with patch.dict('os.environ', env, clear=True): + with patch('rook.tts._speak_elevenlabs', return_value=b'audio') as mock_el: + with patch('rook.tts._play_audio'): + asyncio.run(speak('hello')) + mock_el.assert_called_with('hello') + + +def test_speak_falls_back_to_google_on_elevenlabs_error(): + from rook.tts import speak + env = { + 'ELEVENLABS_API_KEY': 'key', + 'ELEVENLABS_VOICE_ID': 'vid', + 'GOOGLE_TTS_KEY': 'gkey', + } + with patch.dict('os.environ', env, clear=True): + with patch('rook.tts._speak_elevenlabs', side_effect=Exception('API error')): + with patch('rook.tts._speak_google', return_value=b'audio') as mock_google: + with patch('rook.tts._play_audio'): + asyncio.run(speak('hello')) + mock_google.assert_called() + + +def test_speak_google_only_when_no_elevenlabs(): + from rook.tts import speak + env = {'GOOGLE_TTS_KEY': 'gkey'} + with patch.dict('os.environ', env, clear=True): + with patch('rook.tts._speak_google', return_value=b'audio') as mock_google: + with patch('rook.tts._play_audio'): + asyncio.run(speak('hi')) + mock_google.assert_called_with('hi') + + +def test_speak_raises_tts_error_when_no_keys(): + from rook.tts import speak, TTSError + with patch.dict('os.environ', {}, clear=True): + with pytest.raises(TTSError): + asyncio.run(speak('hi')) + + +def test_speak_calls_play_audio(): + from rook.tts import speak + env = {'ELEVENLABS_API_KEY': 'key', 'ELEVENLABS_VOICE_ID': 'vid'} + with patch.dict('os.environ', env, clear=True): + with patch('rook.tts._speak_elevenlabs', return_value=b'bytes'): + with patch('rook.tts._play_audio') as mock_play: + asyncio.run(speak('test')) + mock_play.assert_called_with(b'bytes') + + +def test_speak_is_coroutine(): + from rook.tts import speak + assert inspect.iscoroutinefunction(speak) + + +def test_speak_elevenlabs_internal(): + mock_client = MagicMock() + mock_client.generate.return_value = iter([b'chunk1', b'chunk2']) + mock_elevenlabs_cls = MagicMock(return_value=mock_client) + with patch('elevenlabs.ElevenLabs', mock_elevenlabs_cls): + with patch.dict('os.environ', {'ELEVENLABS_API_KEY': 'k', 'ELEVENLABS_VOICE_ID': 'v'}): + from rook.tts import _speak_elevenlabs + result = _speak_elevenlabs('hi') + assert result == b'chunk1chunk2' + + +def test_speak_google_internal(): + mock_response = MagicMock() + mock_response.audio_content = b'speech' + mock_client = MagicMock() + mock_client.synthesize_speech.return_value = mock_response + mock_tts_cls = MagicMock(return_value=mock_client) + with patch('google.cloud.texttospeech.TextToSpeechClient', mock_tts_cls): + from rook.tts import _speak_google + result = _speak_google('hi') + assert result == b'speech' + + +def test_play_audio_noop(): + from rook.tts import _play_audio + result = _play_audio(b'data') + assert result is None