Harshil748 committed
Commit 51b23f6 · 1 Parent(s): b0dbe7f

Add voice cloning endpoint and XTTS model integration

Files changed (3):
  1. ARCHITECTURE.md +238 -0
  2. src/api.py +100 -0
  3. src/engine.py +106 -11
ARCHITECTURE.md ADDED
@@ -0,0 +1,238 @@
+# 🏗️ VoiceAPI System Architecture
+
+## High-Level System Diagram
+
+```mermaid
+flowchart TB
+    subgraph Client["📱 Client Applications"]
+        Web["🌐 Web App"]
+        Mobile["📱 Mobile App"]
+        Healthcare["🏥 Healthcare Assistant"]
+    end
+
+    subgraph API["🚀 FastAPI Server (Port 7860)"]
+        Endpoint["/Get_Inference API"]
+        LangRouter["Language Router"]
+    end
+
+    subgraph Engine["⚙️ TTS Engine"]
+        Normalizer["Text Normalizer"]
+        Tokenizer["Tokenizer"]
+        StyleProc["Style Processor"]
+
+        subgraph Models["🧠 Model Types"]
+            VITS["VITS JIT Models\n(.pt files)"]
+            Coqui["Coqui TTS\n(.pth files)"]
+            MMS["Facebook MMS\n(HuggingFace)"]
+        end
+    end
+
+    subgraph Languages["🗣️ 11 Languages"]
+        Hindi["🇮🇳 Hindi"]
+        Bengali["🇧🇩 Bengali"]
+        Marathi["Marathi"]
+        Telugu["Telugu"]
+        Kannada["Kannada"]
+        Gujarati["Gujarati"]
+        Bhojpuri["Bhojpuri"]
+        Others["+ 4 more"]
+    end
+
+    subgraph Output["🔊 Audio Output"]
+        WAV["WAV File\n22050 Hz"]
+    end
+
+    Client -->|HTTP GET/POST| Endpoint
+    Endpoint -->|text, lang| LangRouter
+    LangRouter --> Normalizer
+    Normalizer --> Tokenizer
+    Tokenizer --> Models
+    VITS --> StyleProc
+    Coqui --> StyleProc
+    MMS --> StyleProc
+    StyleProc --> WAV
+    WAV -->|Response| Client
+
+    Models --> Languages
+```
+
+## Data Flow Diagram
+
+```mermaid
+sequenceDiagram
+    participant C as Client
+    participant A as API Server
+    participant E as TTS Engine
+    participant M as Model
+    participant S as Style Processor
+
+    C->>A: GET /Get_Inference?text=नमस्ते&lang=hindi
+    A->>A: Parse parameters
+    A->>E: synthesize(text, voice)
+    E->>E: Normalize text
+    E->>E: Tokenize to IDs
+    E->>M: Load model (if not cached)
+    M->>M: Forward pass (inference)
+    M-->>E: Raw audio tensor
+    E->>S: Apply style (pitch, speed, energy)
+    S-->>E: Processed audio
+    E-->>A: TTSOutput (audio, sample_rate)
+    A->>A: Convert to WAV bytes
+    A-->>C: audio/wav response
+```
+
+## Model Architecture
+
+```mermaid
+flowchart LR
+    subgraph Input["📝 Input"]
+        Text["Text Input"]
+    end
+
+    subgraph TextEncoder["🔤 Text Encoder"]
+        Embed["Character Embedding"]
+        TransEnc["Transformer Encoder\n(6 layers, 192 hidden)"]
+    end
+
+    subgraph FlowModel["🌊 Flow Model"]
+        Prior["Prior Encoder"]
+        Flow["Normalizing Flow"]
+        Duration["Duration Predictor"]
+    end
+
+    subgraph Decoder["🔊 HiFi-GAN Decoder"]
+        Upsample["Upsampling Layers"]
+        ResBlocks["Residual Blocks"]
+        Output["Audio Waveform"]
+    end
+
+    Text --> Embed --> TransEnc
+    TransEnc --> Prior
+    TransEnc --> Duration
+    Prior --> Flow
+    Duration --> Flow
+    Flow --> Upsample --> ResBlocks --> Output
+```
+
+## Training Pipeline
+
+```mermaid
+flowchart TD
+    subgraph Data["📊 Training Data"]
+        OpenSLR["OpenSLR Datasets"]
+        CommonVoice["Mozilla Common Voice"]
+        IndicTTS["IndicTTS Corpus"]
+        AI4Bharat["AI4Bharat Indic-Voices"]
+    end
+
+    subgraph Prep["🔧 Data Preparation"]
+        Download["Download Audio"]
+        Normalize["Normalize to 22050 Hz"]
+        Transcript["Generate Transcripts"]
+        Split["Train/Val Split"]
+    end
+
+    subgraph Train["🏋️ Training"]
+        Config["Load Config YAML"]
+        VITS_Train["VITS Training\n(1000 epochs)"]
+        Checkpoint["Save Checkpoints"]
+    end
+
+    subgraph Export["📦 Export"]
+        JIT["JIT Trace Model"]
+        Chars["Generate chars.txt"]
+        Package["Package for Inference"]
+    end
+
+    Data --> Download --> Normalize --> Transcript --> Split
+    Split --> Config --> VITS_Train --> Checkpoint
+    Checkpoint --> JIT --> Chars --> Package
+```
+
+## Deployment Architecture
+
+```mermaid
+flowchart TB
+    subgraph HF["☁️ HuggingFace Infrastructure"]
+        subgraph Space["🚀 HF Space (Docker)"]
+            Docker["Docker Container"]
+            FastAPI["FastAPI Server\n:7860"]
+            Models_Dir["models/ directory"]
+        end
+
+        subgraph ModelRepo["📦 Model Repository"]
+            ModelFiles["Harshil748/VoiceAPI-Models\n(~8GB)"]
+        end
+    end
+
+    subgraph External["🌐 External Services"]
+        MMS_HF["facebook/mms-tts-guj\n(Gujarati)"]
+    end
+
+    User["👤 User"] -->|HTTPS| FastAPI
+    Docker -->|Build time| ModelFiles
+    FastAPI -->|Runtime| MMS_HF
+    Models_Dir -.->|Loaded from| ModelFiles
+```
+
+## Voice Configuration Map
+
+```mermaid
+mindmap
+  root((VoiceAPI))
+    Hindi
+      hi_male
+      hi_female
+    Bengali
+      bn_male
+      bn_female
+    Marathi
+      mr_male
+      mr_female
+    Telugu
+      te_male
+      te_female
+    Kannada
+      kn_male
+      kn_female
+    Gujarati
+      gu_mms
+    Bhojpuri
+      bho_male
+      bho_female
+    Chhattisgarhi
+      hne_male
+      hne_female
+    Maithili
+      mai_male
+      mai_female
+    Magahi
+      mag_male
+      mag_female
+    English
+      en_male
+      en_female
+```
+
+## Component Interaction
+
+| Component | File | Purpose |
+|-----------|------|---------|
+| API Server | `src/api.py` | FastAPI REST endpoints |
+| TTS Engine | `src/engine.py` | Model loading & inference |
+| Tokenizer | `src/tokenizer.py` | Text → Token IDs |
+| Config | `src/config.py` | Language & model configs |
+| Model Loader | `src/model_loader.py` | Model file management |
+
+## Performance Characteristics
+
+| Metric | Value |
+|--------|-------|
+| Inference Time | ~200-500 ms per sentence |
+| Model Load Time | ~2-5 s per voice |
+| Audio Sample Rate | 22050 Hz (16000 Hz for Gujarati) |
+| Supported Formats | WAV |
+| Concurrent Requests | Limited by memory |
+
+---
+*Built for the Voice Tech for All Hackathon*
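The data-flow diagram above ends with the engine's raw audio tensor being converted to WAV bytes for the HTTP response. A minimal sketch of that final step using only the standard library (the server itself uses `soundfile`; the sine-wave input, helper name, and 16-bit quantization here are illustrative assumptions):

```python
import io
import math
import struct
import wave


def audio_to_wav_bytes(samples, sample_rate=22050):
    """Pack float samples in [-1, 1] into mono 16-bit PCM WAV bytes."""
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav:
        wav.setnchannels(1)           # mono, as in the API's output
        wav.setsampwidth(2)           # 16-bit PCM
        wav.setframerate(sample_rate)
        frames = b"".join(
            struct.pack("<h", int(max(-1.0, min(1.0, s)) * 32767))
            for s in samples
        )
        wav.writeframes(frames)
    return buffer.getvalue()


# A 0.1 s 440 Hz tone standing in for the model's raw audio tensor
tone = [math.sin(2 * math.pi * 440 * n / 22050) for n in range(2205)]
wav_bytes = audio_to_wav_bytes(tone)
print(wav_bytes[:4])  # WAV files begin with the RIFF magic
```

The 44-byte RIFF/WAVE header produced here is also why the `/clone` endpoint below treats uploads shorter than 44 bytes as invalid.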
src/api.py CHANGED
@@ -37,6 +37,17 @@ from .config import (
     STYLE_PRESETS,
 )
 
+# Language mapping for XTTS voice cloning
+XTTS_LANG_MAP = {
+    "english": "en",
+    "hindi": "hi",
+    "bengali": "bn",
+    "gujarati": "gu",
+    "marathi": "mr",
+    "telugu": "te",
+    "kannada": "kn",
+}
+
 # Language name to voice key mapping (for hackathon API)
 LANG_TO_VOICE = {
     "hindi": "hi_female",
@@ -152,6 +163,16 @@ class SynthesizeResponse(BaseModel):
     inference_time: float
 
 
+class CloneResponse(BaseModel):
+    """Response metadata for voice cloning"""
+
+    success: bool
+    duration: float
+    sample_rate: int
+    inference_time: float
+    language: str
+
+
 class VoiceInfo(BaseModel):
     """Information about a voice"""
 
@@ -332,6 +353,85 @@ async def synthesize_stream(request: SynthesizeRequest):
         raise HTTPException(status_code=500, detail=str(e))
 
 
+@app.post("/clone", response_class=Response)
+async def clone_voice(
+    text: str = Query(..., description="Text to synthesize with cloned voice"),
+    lang: str = Query(
+        "english",
+        description="Language name (english, hindi, bengali, gujarati, marathi, telugu, kannada)",
+    ),
+    speed: float = Query(1.0, description="Speech speed", ge=0.5, le=2.0),
+    pitch: float = Query(1.0, description="Pitch", ge=0.5, le=2.0),
+    energy: float = Query(1.0, description="Energy", ge=0.5, le=2.0),
+    style: Optional[str] = Query(None, description="Style preset"),
+    speaker_wav: UploadFile = File(
+        ..., description="Reference speaker WAV (3-15 seconds recommended)"
+    ),
+):
+    """
+    Clone a custom voice from an uploaded sample using XTTS v2.
+    """
+    engine = get_engine()
+    lang_lower = lang.lower().strip()
+
+    if lang_lower not in XTTS_LANG_MAP:
+        supported = ", ".join(sorted(XTTS_LANG_MAP.keys()))
+        raise HTTPException(
+            status_code=400,
+            detail=f"Unsupported clone language: {lang}. Supported: {supported}",
+        )
+
+    temp_path = None
+    try:
+        data = await speaker_wav.read()
+        if len(data) < 44:
+            raise HTTPException(status_code=400, detail="Invalid speaker_wav file")
+
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+            tmp.write(data)
+            temp_path = tmp.name
+
+        start_time = time.time()
+        output = engine.clone_voice(
+            text=text,
+            speaker_wav_path=temp_path,
+            language_code=XTTS_LANG_MAP[lang_lower],
+            speed=speed,
+            pitch=pitch,
+            energy=energy,
+            style=style,
+            normalize_text=True,
+        )
+        inference_time = time.time() - start_time
+
+        buffer = io.BytesIO()
+        sf.write(buffer, output.audio, output.sample_rate, format="WAV")
+        buffer.seek(0)
+
+        return Response(
+            content=buffer.read(),
+            media_type="audio/wav",
+            headers={
+                "X-Duration": str(output.duration),
+                "X-Sample-Rate": str(output.sample_rate),
+                "X-Language": lang_lower,
+                "X-Voice": "custom_cloned",
+                "X-Inference-Time": str(inference_time),
+            },
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Clone error: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+    finally:
+        if temp_path and os.path.exists(temp_path):
+            try:
+                os.remove(temp_path)
+            except OSError:
+                pass
+
+
 @app.get("/synthesize/get")
 async def synthesize_get(
     text: str = Query(
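The `/clone` endpoint rejects unsupported languages before touching the model. That lookup can be exercised in isolation (a standalone sketch with an assumed helper name `resolve_clone_language`; the real handler raises `HTTPException` rather than `ValueError`):

```python
# Mirrors the XTTS_LANG_MAP added in src/api.py
XTTS_LANG_MAP = {
    "english": "en",
    "hindi": "hi",
    "bengali": "bn",
    "gujarati": "gu",
    "marathi": "mr",
    "telugu": "te",
    "kannada": "kn",
}


def resolve_clone_language(lang: str) -> str:
    """Normalize a user-supplied language name to an XTTS language code."""
    lang_lower = lang.lower().strip()
    if lang_lower not in XTTS_LANG_MAP:
        supported = ", ".join(sorted(XTTS_LANG_MAP))
        raise ValueError(
            f"Unsupported clone language: {lang}. Supported: {supported}"
        )
    return XTTS_LANG_MAP[lang_lower]


print(resolve_clone_language("  Hindi "))  # -> hi
```

Lower-casing and stripping first means clients can send `"Hindi"` or `" telugu "` and still hit the map, while anything outside the seven XTTS languages fails fast with a clear message.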
src/engine.py CHANGED
@@ -1,7 +1,7 @@
 """
 TTS Engine for Multi-lingual Indian Language Speech Synthesis
 
-This engine uses VITS (Variational Inference with adversarial learning
+This engine uses VITS (Variational Inference with adversarial learning
 for Text-to-Speech) models trained on various Indian language datasets.
 
 Supported Languages:
@@ -25,7 +25,11 @@ from dataclasses import dataclass
 
 from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR, STYLE_PRESETS
 from .tokenizer import TTSTokenizer, CharactersConfig, TextNormalizer
-from .model_loader import _ensure_models_available, get_model_path, list_available_models
+from .model_loader import (
+    _ensure_models_available,
+    get_model_path,
+    list_available_models,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -33,6 +37,7 @@ logger = logging.getLogger(__name__)
 @dataclass
 class TTSOutput:
     """Output from TTS synthesis"""
+
     audio: np.ndarray
     sample_rate: int
     duration: float
@@ -48,13 +53,16 @@ class StyleProcessor:
     """
 
     @staticmethod
-    def apply_pitch_shift(audio: np.ndarray, sample_rate: int, pitch_factor: float) -> np.ndarray:
+    def apply_pitch_shift(
+        audio: np.ndarray, sample_rate: int, pitch_factor: float
+    ) -> np.ndarray:
         """Shift pitch without changing duration"""
         if pitch_factor == 1.0:
             return audio
 
         try:
             import librosa
+
             semitones = 12 * np.log2(pitch_factor)
             shifted = librosa.effects.pitch_shift(
                 audio.astype(np.float32), sr=sample_rate, n_steps=semitones
@@ -62,23 +70,28 @@ class StyleProcessor:
             return shifted
         except ImportError:
             from scipy import signal
+
             stretched = signal.resample(audio, int(len(audio) / pitch_factor))
             return signal.resample(stretched, len(audio))
 
     @staticmethod
-    def apply_speed_change(audio: np.ndarray, sample_rate: int, speed_factor: float) -> np.ndarray:
+    def apply_speed_change(
+        audio: np.ndarray, sample_rate: int, speed_factor: float
+    ) -> np.ndarray:
         """Change speed/tempo without changing pitch"""
         if speed_factor == 1.0:
             return audio
 
         try:
             import librosa
+
             stretched = librosa.effects.time_stretch(
                 audio.astype(np.float32), rate=speed_factor
             )
             return stretched
         except ImportError:
             from scipy import signal
+
             target_length = int(len(audio) / speed_factor)
             return signal.resample(audio, target_length)
 
@@ -160,6 +173,7 @@ class TTSEngine:
         self._coqui_models: Dict[str, Any] = {}
         self._mms_models: Dict[str, Any] = {}
         self._mms_tokenizers: Dict[str, Any] = {}
+        self._xtts_model: Optional[Any] = None
 
         # Text normalizer
         self.normalizer = TextNormalizer()
@@ -216,7 +230,9 @@ class TTSEngine:
         else:
             raise FileNotFoundError(f"No model file found in {model_dir}")
 
-    def _load_jit_voice(self, voice_key: str, model_dir: Path, model_path: Path) -> bool:
+    def _load_jit_voice(
+        self, voice_key: str, model_dir: Path, model_path: Path
+    ) -> bool:
         """Load a JIT traced VITS model"""
         chars_path = model_dir / "chars.txt"
         if chars_path.exists():
@@ -238,7 +254,9 @@ class TTSEngine:
         logger.info(f"Loaded voice: {voice_key}")
         return True
 
-    def _load_coqui_voice(self, voice_key: str, model_dir: Path, checkpoint_path: Path) -> bool:
+    def _load_coqui_voice(
+        self, voice_key: str, model_dir: Path, checkpoint_path: Path
+    ) -> bool:
         """Load a Coqui TTS checkpoint model"""
         config_path = model_dir / "config.json"
         if not config_path.exists():
@@ -333,6 +351,71 @@ class TTSEngine:
         torch.cuda.empty_cache() if self.device.type == "cuda" else None
         logger.info(f"Unloaded voice: {voice_key}")
 
+    def _get_xtts_model(self):
+        """Lazy-load the Coqui XTTS model for voice cloning."""
+        if self._xtts_model is not None:
+            return self._xtts_model
+
+        try:
+            from TTS.api import TTS
+        except ImportError as exc:
+            raise ImportError(
+                "Coqui TTS is required for voice cloning. Install with: pip install TTS"
+            ) from exc
+
+        logger.info("Loading XTTS v2 voice cloning model...")
+        self._xtts_model = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
+        if self.device.type == "cuda":
+            self._xtts_model = self._xtts_model.to("cuda")
+        logger.info("XTTS v2 loaded")
+        return self._xtts_model
+
+    def clone_voice(
+        self,
+        text: str,
+        speaker_wav_path: str,
+        language_code: str = "en",
+        speed: float = 1.0,
+        pitch: float = 1.0,
+        energy: float = 1.0,
+        style: Optional[str] = None,
+        normalize_text: bool = True,
+    ) -> TTSOutput:
+        """Clone a speaker voice from a reference WAV using XTTS v2."""
+        xtts = self._get_xtts_model()
+
+        if normalize_text:
+            text = self.normalizer.clean_text(text, language_code)
+
+        wav = xtts.tts(
+            text=text,
+            speaker_wav=speaker_wav_path,
+            language=language_code,
+        )
+
+        audio_np = np.array(wav, dtype=np.float32)
+        sample_rate = 24000  # XTTS v2 outputs 24 kHz audio
+
+        if style and style in STYLE_PRESETS:
+            preset = STYLE_PRESETS[style]
+            speed = speed * preset["speed"]
+            pitch = pitch * preset["pitch"]
+            energy = energy * preset["energy"]
+
+        audio_np = self.style_processor.apply_style(
+            audio_np, sample_rate, speed=speed, pitch=pitch, energy=energy
+        )
+
+        duration = len(audio_np) / sample_rate
+        return TTSOutput(
+            audio=audio_np,
+            sample_rate=sample_rate,
+            duration=duration,
+            voice="custom_cloned",
+            text=text,
+            style=style,
+        )
+
     def synthesize(
         self,
         text: str,
@@ -423,7 +506,9 @@ class TTSEngine:
         """Synthesize speech and save to file"""
         import soundfile as sf
 
-        output = self.synthesize(text, voice, speed, pitch, energy, style, normalize_text)
+        output = self.synthesize(
+            text, voice, speed, pitch, energy, style, normalize_text
+        )
         sf.write(output_path, output.audio, output.sample_rate)
 
         logger.info(f"Saved audio to {output_path} (duration: {output.duration:.2f}s)")
@@ -454,8 +539,14 @@ class TTSEngine:
             voices[key] = {
                 "name": config.name,
                 "code": config.code,
-                "gender": "male" if "male" in key else ("female" if "female" in key else "neutral"),
-                "loaded": key in self._models or key in self._coqui_models or key in self._mms_models,
+                "gender": (
+                    "female"  # check "female" first: "male" is a substring of "female"
+                    if "female" in key
+                    else ("male" if "male" in key else "neutral")
+                ),
+                "loaded": key in self._models
+                or key in self._coqui_models
+                or key in self._mms_models,
                 "downloaded": is_mms or get_model_path(key) is not None,
                 "type": model_type,
             }
@@ -465,12 +556,16 @@ class TTSEngine:
         """Get available style presets"""
         return STYLE_PRESETS
 
-    def batch_synthesize(self, texts: List[str], voice: str = "hi_male", speed: float = 1.0) -> List[TTSOutput]:
+    def batch_synthesize(
+        self, texts: List[str], voice: str = "hi_male", speed: float = 1.0
+    ) -> List[TTSOutput]:
         """Synthesize multiple texts"""
         return [self.synthesize(text, voice, speed) for text in texts]
 
 
-def synthesize(text: str, voice: str = "hi_male", output_path: Optional[str] = None) -> Union[TTSOutput, str]:
+def synthesize(
+    text: str, voice: str = "hi_male", output_path: Optional[str] = None
+) -> Union[TTSOutput, str]:
     """Quick synthesis function"""
     engine = TTSEngine()
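`StyleProcessor.apply_pitch_shift` converts a multiplicative pitch factor into semitones via `12 * np.log2(pitch_factor)` before calling `librosa.effects.pitch_shift`. The mapping on its own (a pure-math sketch of the formula; the helper name is for illustration only):

```python
import math


def pitch_factor_to_semitones(pitch_factor: float) -> float:
    """12 * log2(f): doubling the pitch factor raises pitch by one octave."""
    return 12 * math.log2(pitch_factor)


print(pitch_factor_to_semitones(2.0))  # one octave up -> 12.0
print(pitch_factor_to_semitones(1.0))  # unchanged -> 0.0
print(pitch_factor_to_semitones(0.5))  # one octave down -> -12.0
```

This is also why the API clamps `pitch` to [0.5, 2.0]: the cloned or synthesized voice can be shifted at most one octave in either direction.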
571