File size: 5,591 Bytes
7219c67
 
a7c6f3a
7219c67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7c6f3a
 
 
 
7219c67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7c6f3a
 
 
 
 
7219c67
 
a7c6f3a
7219c67
a7c6f3a
 
2f7f045
d4f61bc
7219c67
 
 
a7c6f3a
7219c67
a7c6f3a
 
7219c67
 
 
a7c6f3a
7219c67
 
a7c6f3a
7219c67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7c6f3a
 
 
7219c67
a7c6f3a
7219c67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
TTS Handler for Text Story module.
Handles voice generation and audio processing using Kokoro TTS.
"""

import os
import logging
import aiohttp
from pydub import AudioSegment

# Module-level logger named after this module; used by TTSHandler for all
# informational, warning and error messages.
logger = logging.getLogger(__name__)


class TTSHandler:
    """
    Handles Text-to-Speech generation using Kokoro TTS.
    Also handles silence trimming and duration detection.

    The TTS endpoint base URL is read from the ``HF_TTS`` environment
    variable when the handler is constructed.
    """

    def __init__(self):
        # Normalise before testing for emptiness so that a value such as "/"
        # or "  " is reported as "not configured" instead of being accepted
        # and later producing an empty base URL.
        self.tts_url = os.getenv("HF_TTS", "").strip().rstrip("/")
        if not self.tts_url:
            logger.warning("TTSHandler: HF_TTS not configured, TTS will fail")
        else:
            logger.info("TTSHandler: Using TTS endpoint %s", self.tts_url)

    async def generate_tts(self, text: str, voice: str, output_path: str,
                           speed: float = 1.4) -> float:
        """
        Generate TTS audio for text.

        The audio returned by the endpoint is written to a temporary file
        next to ``output_path``, trimmed of leading/trailing silence, and the
        trimmed result is saved to ``output_path``. The temporary file is
        always removed, whether generation succeeds or fails.

        Args:
            text: Text to speak
            voice: Kokoro voice ID (e.g., 'af_heart', 'am_fenrir')
            output_path: Path to save WAV file
            speed: Speech speed sent to the endpoint (default 1.4, a faster
                voice for engaging content)

        Returns:
            Duration in seconds

        Raises:
            ValueError: If the HF_TTS environment variable was not set.
            Exception: If the endpoint answers with a non-200 status, on
                network errors (chained to the underlying aiohttp error),
                or if the audio cannot be processed.
        """
        if not self.tts_url:
            raise ValueError("HF_TTS environment variable not set")

        # Correct endpoint format (same as video_creator)
        endpoint = f"{self.tts_url}/v1/audio/speech"

        # Correct payload format for Kokoro TTS
        payload = {
            "model": "kokoro",
            "input": text,
            "voice": voice,
            "speed": speed,
        }

        logger.info("TTS: Generating voice '%s' for: %s...", voice, text[:50])

        temp_path = output_path + ".temp.wav"
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    endpoint,
                    json=payload,
                    headers={"Content-Type": "application/json"},
                    timeout=aiohttp.ClientTimeout(total=120),
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise Exception(
                            f"TTS API error ({response.status}): {error_text}"
                        )
                    audio_data = await response.read()

            # The body is fully read at this point, so the connection and
            # session are released before the file and audio work below.
            logger.info("TTS: Received %d bytes", len(audio_data))

            # Save raw audio
            with open(temp_path, "wb") as f:
                f.write(audio_data)

            # Trim silence and get duration
            duration = self.trim_silence(temp_path, output_path)

            logger.info("TTS: Generated %d chars, %.2fs", len(text), duration)
            return duration

        except aiohttp.ClientError as e:
            logger.error("TTS network error: %s: %s", type(e).__name__, e)
            # Chain the cause so the original aiohttp traceback is preserved.
            raise Exception(f"TTS network error: {e}") from e
        except Exception as e:
            logger.error("TTS generation failed: %s: %s", type(e).__name__, e)
            raise
        finally:
            # Cleanup temp on every path, not only on success.
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                pass  # The request failed before anything was written.
            except OSError as e:
                # A leftover temp file must not mask the real outcome.
                logger.warning(
                    "TTS: could not remove temp file %s: %s", temp_path, e
                )

    def trim_silence(self, input_path: str, output_path: str,
                     silence_thresh: int = -40, min_silence_len: int = 100) -> float:
        """
        Trim leading and trailing silence from audio.

        If trimming fails for any reason, the input file is copied to
        ``output_path`` unchanged and its duration is returned instead.

        Args:
            input_path: Input audio file
            output_path: Output audio file
            silence_thresh: Silence threshold in dB
            min_silence_len: Minimum silence length in ms

        Returns:
            Duration of trimmed audio in seconds
        """
        try:
            audio = AudioSegment.from_file(input_path)

            # Detect non-silent parts
            from pydub.silence import detect_nonsilent

            nonsilent_ranges = detect_nonsilent(
                audio,
                min_silence_len=min_silence_len,
                silence_thresh=silence_thresh,
            )

            if nonsilent_ranges:
                # Keep everything from the first to the last non-silent range,
                # with 50ms of padding before and 100ms after, clamped to the
                # bounds of the clip.
                start_ms = max(0, nonsilent_ranges[0][0] - 50)
                end_ms = min(len(audio), nonsilent_ranges[-1][1] + 100)
                trimmed = audio[start_ms:end_ms]
            else:
                # No speech detected, use original
                trimmed = audio

            # Export trimmed audio
            trimmed.export(output_path, format="wav")

            return len(trimmed) / 1000.0  # Convert ms to seconds

        except Exception as e:
            logger.error("Silence trim failed: %s", e)
            # Fallback: just copy the file
            import shutil
            try:
                shutil.copy2(input_path, output_path)
            except shutil.SameFileError:
                # Input and output are the same file; nothing to copy.
                pass
            # May still raise if the file cannot be decoded at all.
            audio = AudioSegment.from_file(output_path)
            return len(audio) / 1000.0

    def get_duration(self, audio_path: str) -> float:
        """
        Get duration of audio file in seconds.

        Returns 2.0 as a default fallback (after logging the error) when
        the file cannot be loaded.
        """
        try:
            audio = AudioSegment.from_file(audio_path)
            return len(audio) / 1000.0
        except Exception as e:
            logger.error("Failed to get audio duration: %s", e)
            return 2.0  # Default fallback