feat(tts): ElevenLabs primary + Google Cloud TTS fallback, async speak, env-var config
This commit is contained in:
50
src/rook/tts.py
Normal file
50
src/rook/tts.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class TTSError(Exception):
    """Error raised by this module when speech cannot be produced."""
|
||||||
|
|
||||||
|
|
||||||
|
async def speak(text: str) -> None:
    """Synthesize *text* to speech and pass the audio to ``_play_audio``.

    ElevenLabs is used when ``ELEVENLABS_API_KEY`` is set. Google Cloud TTS is
    used when only ``GOOGLE_TTS_KEY`` is set, or as a fallback when the
    ElevenLabs call raises and ``GOOGLE_TTS_KEY`` is also set.

    Args:
        text: The text to synthesize.

    Raises:
        TTSError: If neither key is configured, if ElevenLabs fails and no
            Google fallback key is configured (the ElevenLabs error is chained
            as ``__cause__``), or if the chosen backend returns ``None``.
        Exception: Errors raised by the Google backend propagate unchanged.
    """
    import logging

    el_key = os.environ.get('ELEVENLABS_API_KEY')
    g_key = os.environ.get('GOOGLE_TTS_KEY')

    if not el_key and not g_key:
        raise TTSError('No TTS keys configured')

    # NOTE(review): the backend helpers are plain synchronous functions called
    # directly, so the event loop is blocked while they run.
    audio: Optional[bytes] = None
    if el_key:
        try:
            audio = _speak_elevenlabs(text)
        except Exception as err:
            if not g_key:
                # Report the real failure instead of claiming that no keys
                # are configured.
                raise TTSError(
                    'ElevenLabs TTS failed and no Google fallback is configured'
                ) from err
            logging.getLogger(__name__).warning(
                'ElevenLabs TTS failed (%s); falling back to Google TTS', err
            )
            audio = _speak_google(text)
    else:
        audio = _speak_google(text)

    if audio is None:
        raise TTSError('TTS backend returned no audio')
    _play_audio(audio)
|
||||||
|
|
||||||
|
|
||||||
|
def _speak_elevenlabs(text: str) -> bytes:
    """Synthesize *text* with the ElevenLabs client and return the audio bytes.

    Reads ``ELEVENLABS_API_KEY`` (required) and ``ELEVENLABS_VOICE_ID``
    (optional) from the environment.

    Args:
        text: The text to synthesize.

    Returns:
        The chunks yielded by ``client.generate`` joined into one ``bytes``.

    Raises:
        KeyError: If ``ELEVENLABS_API_KEY`` is not set.
    """
    from elevenlabs import ElevenLabs

    client = ElevenLabs(api_key=os.environ['ELEVENLABS_API_KEY'])

    request = {'text': text}
    voice_id = os.environ.get('ELEVENLABS_VOICE_ID')
    if voice_id:
        request['voice'] = voice_id
    # An unset or empty voice id is omitted rather than sent as ''.
    # Presumably the SDK then applies its own default voice -- TODO confirm
    # against the installed elevenlabs version.

    return b''.join(client.generate(**request))
|
||||||
|
|
||||||
|
|
||||||
|
def _speak_google(text: str) -> bytes:
    """Synthesize *text* with Google Cloud Text-to-Speech and return MP3 bytes.

    The request asks for an ``en-US`` voice with neutral SSML gender and MP3
    audio encoding.

    NOTE(review): ``GOOGLE_TTS_KEY`` is not read here; the client is built
    with no arguments, so credentials presumably come from the library's
    default resolution -- confirm.
    """
    from google.cloud import texttospeech as gtts

    tts_client = gtts.TextToSpeechClient()

    synthesis_input = gtts.SynthesisInput(text=text)
    voice_params = gtts.VoiceSelectionParams(
        language_code='en-US',
        ssml_gender=gtts.SsmlVoiceGender.NEUTRAL,
    )
    mp3_config = gtts.AudioConfig(audio_encoding=gtts.AudioEncoding.MP3)

    result = tts_client.synthesize_speech(
        input=synthesis_input,
        voice=voice_params,
        audio_config=mp3_config,
    )
    return result.audio_content
|
||||||
|
|
||||||
|
|
||||||
|
def _play_audio(audio_bytes: bytes) -> None:
|
||||||
|
pass
|
||||||
90
tests/test_tts.py
Normal file
90
tests/test_tts.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_elevenlabs_called_when_key_set():
    """With only ElevenLabs configured, ``speak`` calls the ElevenLabs helper."""
    from rook.tts import speak

    fake_env = {'ELEVENLABS_API_KEY': 'key', 'ELEVENLABS_VOICE_ID': 'vid'}
    env_patch = patch.dict('os.environ', fake_env, clear=True)
    backend_patch = patch('rook.tts._speak_elevenlabs', return_value=b'audio')
    playback_patch = patch('rook.tts._play_audio')

    with env_patch, backend_patch as elevenlabs_mock, playback_patch:
        asyncio.run(speak('hello'))

    elevenlabs_mock.assert_called_with('hello')
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_falls_back_to_google_on_elevenlabs_error():
    """When ElevenLabs raises and a Google key exists, Google is used."""
    from rook.tts import speak

    fake_env = dict(
        ELEVENLABS_API_KEY='key',
        ELEVENLABS_VOICE_ID='vid',
        GOOGLE_TTS_KEY='gkey',
    )
    env_patch = patch.dict('os.environ', fake_env, clear=True)
    failing_backend = patch(
        'rook.tts._speak_elevenlabs', side_effect=Exception('API error')
    )
    google_patch = patch('rook.tts._speak_google', return_value=b'audio')
    playback_patch = patch('rook.tts._play_audio')

    with env_patch, failing_backend, google_patch as google_mock, playback_patch:
        asyncio.run(speak('hello'))

    google_mock.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_google_only_when_no_elevenlabs():
    """With only the Google key set, ``speak`` calls the Google helper."""
    from rook.tts import speak

    env_patch = patch.dict('os.environ', {'GOOGLE_TTS_KEY': 'gkey'}, clear=True)
    google_patch = patch('rook.tts._speak_google', return_value=b'audio')
    playback_patch = patch('rook.tts._play_audio')

    with env_patch, google_patch as google_mock, playback_patch:
        asyncio.run(speak('hi'))

    google_mock.assert_called_with('hi')
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_raises_tts_error_when_no_keys():
    """An empty environment makes ``speak`` raise ``TTSError``."""
    from rook.tts import TTSError, speak

    empty_env = patch.dict('os.environ', {}, clear=True)
    with empty_env, pytest.raises(TTSError):
        asyncio.run(speak('hi'))
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_calls_play_audio():
    """The bytes returned by the backend are passed to ``_play_audio``."""
    from rook.tts import speak

    fake_env = {'ELEVENLABS_API_KEY': 'key', 'ELEVENLABS_VOICE_ID': 'vid'}
    env_patch = patch.dict('os.environ', fake_env, clear=True)
    backend_patch = patch('rook.tts._speak_elevenlabs', return_value=b'bytes')
    playback_patch = patch('rook.tts._play_audio')

    with env_patch, backend_patch, playback_patch as playback_mock:
        asyncio.run(speak('test'))

    playback_mock.assert_called_with(b'bytes')
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_is_coroutine():
    """``speak`` must be a coroutine function (declared with ``async def``)."""
    from rook.tts import speak as speak_fn

    is_async = inspect.iscoroutinefunction(speak_fn)
    assert is_async
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_elevenlabs_internal():
    """``_speak_elevenlabs`` joins the chunks yielded by ``client.generate``."""
    fake_client = MagicMock()
    fake_client.generate.return_value = iter([b'chunk1', b'chunk2'])
    client_factory = MagicMock(return_value=fake_client)

    fake_env = {'ELEVENLABS_API_KEY': 'k', 'ELEVENLABS_VOICE_ID': 'v'}
    sdk_patch = patch('elevenlabs.ElevenLabs', client_factory)
    env_patch = patch.dict('os.environ', fake_env)

    with sdk_patch, env_patch:
        from rook.tts import _speak_elevenlabs

        joined = _speak_elevenlabs('hi')

    assert joined == b'chunk1chunk2'
|
||||||
|
|
||||||
|
|
||||||
|
def test_speak_google_internal():
    """``_speak_google`` returns ``audio_content`` from the synthesis response."""
    fake_response = MagicMock(audio_content=b'speech')
    fake_client = MagicMock()
    fake_client.synthesize_speech.return_value = fake_response
    client_factory = MagicMock(return_value=fake_client)

    with patch('google.cloud.texttospeech.TextToSpeechClient', client_factory):
        from rook.tts import _speak_google

        synthesized = _speak_google('hi')

    assert synthesized == b'speech'
|
||||||
|
|
||||||
|
|
||||||
|
def test_play_audio_noop():
    """The placeholder ``_play_audio`` accepts bytes and returns ``None``."""
    from rook.tts import _play_audio

    assert _play_audio(b'data') is None
|
||||||
Reference in New Issue
Block a user