File size: 4,153 Bytes
7fa9d90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee36c8e
7fa9d90
 
 
 
 
 
ee36c8e
7fa9d90
 
 
 
 
 
ee36c8e
7fa9d90
 
 
 
 
 
 
ee36c8e
 
7fa9d90
 
93d86d5
7fa9d90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import aiohttp
import struct
import logging
from typing import Tuple

logger = logging.getLogger(__name__)


class TTSClient:
    """Client for Kokoro TTS via Hugging Face Cloud API.

    Sends OpenAI-style ``/v1/audio/speech`` requests and returns the raw audio
    bytes together with a duration estimate derived from the WAV header.
    """

    # Assumed format when the response cannot be parsed as WAV:
    # 24 kHz, 16-bit (2 bytes) mono, behind a canonical 44-byte PCM header.
    _FALLBACK_SAMPLE_RATE = 24000
    _FALLBACK_BYTES_PER_SAMPLE = 2
    _WAV_HEADER_SIZE = 44

    def __init__(self, api_url: str, timeout: float = 120.0):
        """
        Initialize TTS client

        Args:
            api_url: Base URL for the TTS API (HF_TTS environment variable)
            timeout: Total per-request timeout in seconds (default 120)
        """
        self.api_url = api_url.rstrip('/')
        self.timeout = timeout
        logger.info("Using cloud TTS API at %s", self.api_url)

    async def generate(self, text: str, voice: str, speed: float = 1.0) -> Tuple[bytes, float]:
        """
        Generate speech from text

        Args:
            text: Text to convert to speech
            voice: Voice identifier (e.g., 'af_heart', 'am_adam')
            speed: Speech speed multiplier (0.5-2.0, default 1.0); not
                validated here, it is forwarded to the API as-is

        Returns:
            Tuple of (audio_bytes, duration_seconds)

        Raises:
            Exception: If the API responds with a non-200 status.
            Transport errors and timeouts raised by aiohttp propagate unchanged.
        """
        endpoint = f"{self.api_url}/v1/audio/speech"

        logger.debug(
            "Generating audio with voice=%s, speed=%s, text_length=%d",
            voice, speed, len(text),
        )

        # NOTE(review): no "response_format" is sent, so the audio format is
        # whatever the server defaults to. The duration is only exact for WAV
        # responses; confirm the server actually returns WAV.
        # A fresh session is opened per call. Reusing one session would allow
        # connection pooling if call volume grows.
        async with aiohttp.ClientSession() as session:
            async with session.post(
                endpoint,
                json={
                    "model": "kokoro",
                    "input": text,
                    "voice": voice,
                    "speed": speed
                },
                headers={"Content-Type": "application/json"},
                timeout=aiohttp.ClientTimeout(total=self.timeout)
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"TTS API error ({response.status}): {error_text}")

                audio_data = await response.read()
                duration = self._estimate_audio_duration(audio_data)

                logger.debug("Generated audio: %d bytes, %.2fs", len(audio_data), duration)
                return audio_data, duration

    def _estimate_audio_duration(self, audio_buffer: bytes) -> float:
        """
        Estimate audio duration (seconds) from a WAV buffer.

        The RIFF chunk list is walked to locate the ``fmt `` and ``data``
        chunks, so extra chunks (e.g. ``LIST``) are handled. Data sizes that
        are placeholders or exceed the received bytes are clamped to what is
        actually present. Anything that is not a parseable RIFF/WAVE buffer
        falls back to a 24 kHz / 16-bit / mono estimate, which is never
        negative.
        """
        if len(audio_buffer) < 12 or audio_buffer[:4] != b'RIFF' or audio_buffer[8:12] != b'WAVE':
            return self._fallback_duration(audio_buffer)

        try:
            return self._parse_wav_duration(audio_buffer)
        except (struct.error, ValueError) as e:
            logger.warning("Failed to parse WAV header: %s, using fallback", e)
            return self._fallback_duration(audio_buffer)

    @staticmethod
    def _parse_wav_duration(audio_buffer: bytes) -> float:
        """Return the duration encoded by a RIFF/WAVE buffer; raise ValueError/struct.error if unparseable."""
        total = len(audio_buffer)
        fmt = None
        data_size = None
        offset = 12  # skip "RIFF" <size> "WAVE"

        while offset + 8 <= total:
            chunk_id = audio_buffer[offset:offset + 4]
            chunk_size = struct.unpack_from('<I', audio_buffer, offset + 4)[0]
            body = offset + 8

            if chunk_id == b'fmt ':
                # audio_format, channels, sample_rate, byte_rate, block_align, bits
                _, num_channels, sample_rate, _, block_align, bits_per_sample = (
                    struct.unpack_from('<HHIIHH', audio_buffer, body)
                )
                fmt = (num_channels, sample_rate, block_align, bits_per_sample)
            elif chunk_id == b'data':
                available = total - body
                # Streamed WAVs may carry a placeholder size (0 or 0xFFFFFFFF),
                # and bodies may be truncated: never count more than received.
                data_size = available if (chunk_size == 0 or chunk_size > available) else chunk_size
                break

            # RIFF chunks are word-aligned: odd-sized chunks carry a pad byte.
            offset = body + chunk_size + (chunk_size & 1)

        if fmt is None or data_size is None:
            raise ValueError("WAV buffer is missing a 'fmt ' or 'data' chunk")

        num_channels, sample_rate, block_align, bits_per_sample = fmt
        # block_align is bytes per sample frame; derive it if the header left it at 0.
        bytes_per_frame = block_align or (bits_per_sample // 8) * num_channels
        if bytes_per_frame == 0 or sample_rate == 0:
            raise ValueError("WAV header has zero frame size or sample rate")

        return data_size / bytes_per_frame / sample_rate

    @classmethod
    def _fallback_duration(cls, audio_buffer: bytes) -> float:
        """Rough duration assuming a 44-byte header followed by 24 kHz 16-bit mono PCM; clamped at 0."""
        payload_bytes = len(audio_buffer) - cls._WAV_HEADER_SIZE
        return max(0.0, payload_bytes / (cls._FALLBACK_BYTES_PER_SAMPLE * cls._FALLBACK_SAMPLE_RATE))

    @staticmethod
    def list_available_voices() -> list:
        """Return list of available TTS voices"""
        return [
            "af_heart", "af_alloy", "af_aoede", "af_bella", "af_jessica",
            "af_kore", "af_nicole", "af_nova", "af_river", "af_sarah", "af_sky",
            "am_adam", "am_echo", "am_eric", "am_fenrir", "am_liam",
            "am_michael", "am_onyx", "am_puck", "am_santa",
            "bf_emma", "bf_isabella", "bm_george", "bm_lewis",
            "bf_alice", "bf_lily", "bm_daniel", "bm_fable"
        ]