ShadowHunter222 commited on
Commit
b725430
Β·
verified Β·
1 Parent(s): d61edf1

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ 3cpo_prompt.wav filter=lfs diff=lfs merge=lfs -text
37
+ aave_female_prompt.wav filter=lfs diff=lfs merge=lfs -text
38
+ her_prompt.wav filter=lfs diff=lfs merge=lfs -text
39
+ ivr_female_prompt.wav filter=lfs diff=lfs merge=lfs -text
3cpo_prompt.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a830bbf5494096e593dcfb6e099cfa334cb8b0b34d1403c69d36c02649c5ab15
3
+ size 513452
Dockerfile CHANGED
@@ -17,8 +17,9 @@ RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/wh
17
  COPY requirements.txt .
18
  RUN pip install --no-cache-dir -r requirements.txt
19
 
20
- # Copy application code
21
  COPY config.py text_processor.py chatterbox_wrapper.py app.py ./
 
22
 
23
  # Pre-download ONNX models + tokenizer at build time
24
  RUN python -c "\
 
17
  COPY requirements.txt .
18
  RUN pip install --no-cache-dir -r requirements.txt
19
 
20
+ # Copy application code + local built-in voice samples from repo root
21
  COPY config.py text_processor.py chatterbox_wrapper.py app.py ./
22
+ COPY *.wav ./
23
 
24
  # Pre-download ONNX models + tokenizer at build time
25
  RUN python -c "\
aave_female_prompt.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:971a3568a5a1521612bff565ed416aea62e30da3e00a53d771ff2c26da78276d
3
+ size 1217636
app.py CHANGED
@@ -1,32 +1,16 @@
1
- """
2
- Chatterbox Turbo TTS -- FastAPI Server
3
- ======================================
4
- Production-ready API with true real-time MP3 streaming,
5
- in-memory voice cloning, and fully non-blocking inference.
6
-
7
- Endpoints:
8
- GET /health -> health check + optional warmup
9
- GET /info -> model info, supported tags, parameters
10
- POST /tts -> full audio response (WAV/MP3/FLAC)
11
- POST /tts/stream -> chunked MP3 streaming (MediaSource-ready)
12
- POST /tts/true-stream -> alias for /tts/stream (Kokoro compat)
13
- POST /tts/stop/{stream_id}-> cancel a specific active stream
14
- POST /tts/stop -> cancel ALL active streams
15
- POST /v1/audio/speech -> OpenAI-compatible streaming
16
- """
17
  import asyncio
 
18
  import io
19
  import json
20
  import logging
21
  import queue as stdlib_queue
22
  import threading
23
  import time
24
- import urllib.error
25
  import urllib.parse
26
- import urllib.request
27
  import uuid
28
  from concurrent.futures import ThreadPoolExecutor
29
- from typing import Generator, Optional
 
30
 
31
  import numpy as np
32
  import soundfile as sf
@@ -111,11 +95,15 @@ async def cors_middleware(request: Request, call_next):
111
 
112
  async def _resolve_voice(
113
  voice_ref: Optional[UploadFile],
 
114
  wrapper: ChatterboxWrapper,
115
  ) -> VoiceProfile:
116
- """Return a VoiceProfile from uploaded audio or the default voice."""
117
  if voice_ref is None or voice_ref.filename == "":
118
- return wrapper.default_voice
 
 
 
119
 
120
  audio_bytes = await voice_ref.read()
121
  if len(audio_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
@@ -164,32 +152,165 @@ def _encode_mp3_chunk(audio: np.ndarray) -> bytes:
164
  return data
165
 
166
 
167
- def _build_helper_endpoint(base_url: str, path: str) -> str:
168
- return f"{base_url.rstrip('/')}{path}"
 
 
 
 
169
 
170
 
171
- def _internal_headers() -> dict[str, str]:
172
- headers = {"Content-Type": "application/json", "Accept": "audio/mpeg"}
 
 
 
 
 
 
173
  if Config.INTERNAL_SHARED_SECRET:
174
  headers["X-Internal-Secret"] = Config.INTERNAL_SHARED_SECRET
175
  return headers
176
 
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def _helper_request_chunk(
179
  helper_base_url: str,
180
  payload: dict,
181
  timeout_sec: float,
 
182
  ) -> bytes:
183
- url = _build_helper_endpoint(helper_base_url, "/internal/chunk/synthesize")
184
- body = json.dumps(payload).encode("utf-8")
185
- req = urllib.request.Request(
186
- url=url,
187
- data=body,
188
- headers=_internal_headers(),
189
- method="POST",
190
- )
191
- with urllib.request.urlopen(req, timeout=timeout_sec) as resp:
192
- return resp.read()
193
 
194
 
195
  def _helper_register_voice(
@@ -197,44 +318,45 @@ def _helper_register_voice(
197
  stream_id: str,
198
  audio_bytes: bytes,
199
  timeout_sec: float,
 
200
  ) -> str:
201
  """Register reference voice on helper once, return voice_key for chunk calls."""
202
- query = urllib.parse.urlencode({"stream_id": stream_id})
203
- url = _build_helper_endpoint(helper_base_url, f"/internal/voice/register?{query}")
204
- headers = {"Content-Type": "application/octet-stream", "Accept": "application/json"}
205
- if Config.INTERNAL_SHARED_SECRET:
206
- headers["X-Internal-Secret"] = Config.INTERNAL_SHARED_SECRET
207
-
208
- req = urllib.request.Request(
209
- url=url,
210
- data=audio_bytes,
211
- headers=headers,
212
- method="POST",
213
- )
214
- with urllib.request.urlopen(req, timeout=timeout_sec) as resp:
215
- data = json.loads(resp.read().decode("utf-8"))
216
- voice_key = (data.get("voice_key") or "").strip()
217
- if not voice_key:
218
- raise RuntimeError("helper voice registration returned no voice_key")
219
- return voice_key
220
 
221
 
222
  def _helper_cancel_stream(helper_base_url: str, stream_id: str):
223
  """Best-effort cancellation signal to helper."""
224
  try:
225
- url = _build_helper_endpoint(helper_base_url, f"/internal/chunk/cancel/{stream_id}")
226
- req = urllib.request.Request(
227
- url=url,
228
- data=b"",
229
- headers=_internal_headers(),
230
- method="POST",
231
- )
232
- with urllib.request.urlopen(req, timeout=3.0):
233
- pass
234
  except Exception:
235
  pass
236
 
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  # ═══════════════════════════════════════════════════════════════════
239
  # Endpoints
240
  # ═══════════════════════════════════════════════════════════════════
@@ -242,12 +364,19 @@ def _helper_cancel_stream(helper_base_url: str, stream_id: str):
242
  @app.get("/health")
243
  async def health(warm_up: bool = False):
244
  wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
 
 
 
 
 
245
  status = {
246
  "status": "healthy" if wrapper else "loading",
247
  "model_loaded": wrapper is not None,
248
  "model_dtype": Config.MODEL_DTYPE,
249
  "streaming_supported": True,
250
  "voice_cache_entries": wrapper._voice_cache.size if wrapper else 0,
 
 
251
  }
252
  if warm_up and wrapper:
253
  try:
@@ -259,27 +388,22 @@ async def health(warm_up: bool = False):
259
  return status
260
 
261
 
262
- @app.get("/info")
263
- async def info():
 
 
 
 
 
 
264
  return {
265
- "model": Config.MODEL_ID,
266
- "dtype": Config.MODEL_DTYPE,
267
- "sample_rate": Config.SAMPLE_RATE,
268
- "paralinguistic_tags": list(Config.PARALINGUISTIC_TAGS),
269
- "tag_usage": "Insert tags directly in text, e.g. 'That is so funny! [laugh] Anyway…'",
270
- "parameters": {
271
- "max_new_tokens": {"default": Config.MAX_NEW_TOKENS, "range": "64–2048"},
272
- "repetition_penalty": {"default": Config.REPETITION_PENALTY, "range": "1.0–2.0"},
273
- },
274
- "voice_cloning": {
275
- "description": "Upload 3–30s reference WAV/MP3 as 'voice_ref' field",
276
- "max_upload_mb": Config.MAX_VOICE_UPLOAD_BYTES // (1024 * 1024),
277
- },
278
- "parallel_mode": {
279
- "enabled": Config.ENABLE_PARALLEL_MODE,
280
- "helper_configured": bool(Config.HELPER_BASE_URL),
281
- "helper_base_url": Config.HELPER_BASE_URL or None,
282
- "supports_voice_ref": True,
283
  },
284
  }
285
 
@@ -290,6 +414,7 @@ async def info():
290
  async def text_to_speech(
291
  text: str = Form(...),
292
  voice_ref: Optional[UploadFile] = File(None),
 
293
  output_format: str = Form("wav"),
294
  max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
295
  repetition_penalty: float = Form(Config.REPETITION_PENALTY),
@@ -302,7 +427,7 @@ async def text_to_speech(
302
  if not text or not text.strip():
303
  raise HTTPException(400, "Text is required")
304
 
305
- voice = await _resolve_voice(voice_ref, wrapper)
306
 
307
  loop = asyncio.get_running_loop()
308
  try:
@@ -329,9 +454,51 @@ async def text_to_speech(
329
  # ═══════════════════════════════════════════════════════════════════
330
 
331
  _active_streams: dict[str, threading.Event] = {}
332
- _internal_cancelled_streams: set[str] = set()
 
333
  _internal_cancel_lock = threading.Lock()
334
- _internal_stream_voice_keys: dict[str, set[str]] = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
 
336
 
337
  # ═══════════════════════════════════════════════════════════════════
@@ -431,7 +598,7 @@ def _pipeline_stream_generator(
431
  _active_streams.pop(stream_id, None)
432
 
433
 
434
- def _parallel_odd_even_stream_generator(
435
  wrapper: ChatterboxWrapper,
436
  text: str,
437
  local_voice: VoiceProfile,
@@ -441,26 +608,43 @@ def _parallel_odd_even_stream_generator(
441
  stream_id: str,
442
  helper_base_url: str,
443
  ) -> Generator[bytes, None, None]:
444
- """Additive odd/even split streamer (primary handles odd, helper handles even)."""
 
 
 
 
 
445
  cancel_event = threading.Event()
446
  _active_streams[stream_id] = cancel_event
447
 
 
 
 
 
 
 
448
  clean_text = text_processor.sanitize(text.strip()[: Config.MAX_TEXT_LENGTH])
449
  chunks = text_processor.split_for_streaming(clean_text)
450
  total_chunks = len(chunks)
451
  if total_chunks == 0:
 
 
452
  _active_streams.pop(stream_id, None)
453
  return
454
 
455
  lock = threading.Lock()
456
  cond = threading.Condition(lock)
457
- ready: dict[int, bytes] = {}
458
  first_error: Optional[Exception] = None
459
  workers_done = 0
 
 
460
 
461
- def _publish(idx: int, data: bytes):
462
  with cond:
463
- ready[idx] = data
 
 
464
  cond.notify_all()
465
 
466
  def _set_error(err: Exception):
@@ -485,23 +669,46 @@ def _parallel_odd_even_stream_generator(
485
  )
486
  return _encode_mp3_chunk(audio)
487
 
488
- def _odd_worker():
489
  try:
490
  for idx in range(0, total_chunks, 2):
491
  if cancel_event.is_set():
492
  break
493
  data = _synth_local(chunks[idx])
494
- _publish(idx, data)
 
 
 
 
 
 
 
495
  except Exception as e:
496
  _set_error(e)
497
  finally:
498
  _worker_done()
499
 
500
- def _even_worker():
501
- helper_available = True
502
  helper_voice_key: Optional[str] = None
 
 
 
503
  try:
504
- if helper_voice_bytes:
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  attempts = 2 if Config.HELPER_RETRY_ONCE else 1
506
  last_err: Optional[Exception] = None
507
  for _ in range(attempts):
@@ -510,19 +717,25 @@ def _parallel_odd_even_stream_generator(
510
  helper_base_url=helper_base_url,
511
  stream_id=stream_id,
512
  audio_bytes=helper_voice_bytes,
513
- timeout_sec=max(1.0, Config.HELPER_TIMEOUT_SEC),
 
514
  )
515
  last_err = None
516
  break
517
  except Exception as reg_err:
518
  last_err = reg_err
519
  continue
 
520
  if last_err is not None:
521
  helper_available = False
522
  logger.warning(
523
- f"[{stream_id}] Helper voice registration failed; "
524
- "falling back to local synthesis for even chunks"
525
  )
 
 
 
 
526
 
527
  for idx in range(1, total_chunks, 2):
528
  if cancel_event.is_set():
@@ -547,9 +760,17 @@ def _parallel_odd_even_stream_generator(
547
  helper_data = _helper_request_chunk(
548
  helper_base_url=helper_base_url,
549
  payload=payload,
550
- timeout_sec=max(1.0, Config.HELPER_TIMEOUT_SEC),
 
 
 
 
 
 
 
 
 
551
  )
552
- _publish(idx, helper_data)
553
  last_err = None
554
  break
555
  except Exception as helper_err:
@@ -561,22 +782,31 @@ def _parallel_odd_even_stream_generator(
561
 
562
  helper_available = False
563
  logger.warning(
564
- f"[{stream_id}] Helper failed at chunk {idx}; "
565
- "falling back to local synthesis for remaining even chunks"
566
  )
567
 
568
- # Local fallback for even chunks
569
  data = _synth_local(chunks[idx])
570
- _publish(idx, data)
 
 
 
 
 
 
 
571
  except Exception as e:
572
  _set_error(e)
573
  finally:
 
 
574
  _worker_done()
575
 
576
- odd_thread = threading.Thread(target=_odd_worker, daemon=True)
577
- even_thread = threading.Thread(target=_even_worker, daemon=True)
578
- odd_thread.start()
579
- even_thread.start()
580
 
581
  next_idx = 0
582
  try:
@@ -586,7 +816,7 @@ def _parallel_odd_even_stream_generator(
586
  next_idx not in ready
587
  and first_error is None
588
  and not cancel_event.is_set()
589
- and workers_done < 2
590
  ):
591
  cond.wait(timeout=0.1)
592
 
@@ -594,11 +824,12 @@ def _parallel_odd_even_stream_generator(
594
  break
595
 
596
  if next_idx in ready:
597
- data = ready.pop(next_idx)
 
598
  elif first_error is not None:
599
  logger.error(f"[{stream_id}] Parallel stream error: {first_error}")
600
  break
601
- elif workers_done >= 2:
602
  logger.error(
603
  f"[{stream_id}] Parallel stream ended with missing chunk index {next_idx}"
604
  )
@@ -606,13 +837,39 @@ def _parallel_odd_even_stream_generator(
606
  else:
607
  continue
608
 
609
- yield data
 
 
 
 
 
 
 
 
610
  next_idx += 1
 
 
 
 
 
611
  finally:
612
  cancel_event.set()
613
- _helper_cancel_stream(helper_base_url, stream_id)
614
- odd_thread.join(timeout=1.0)
615
- even_thread.join(timeout=1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
616
  _active_streams.pop(stream_id, None)
617
 
618
 
@@ -623,6 +880,7 @@ def _parallel_odd_even_stream_generator(
623
  async def stream_text_to_speech(
624
  text: str = Form(...),
625
  voice_ref: Optional[UploadFile] = File(None),
 
626
  max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
627
  repetition_penalty: float = Form(Config.REPETITION_PENALTY),
628
  ):
@@ -638,7 +896,7 @@ async def stream_text_to_speech(
638
  if not text or not text.strip():
639
  raise HTTPException(400, "Text is required")
640
 
641
- voice = await _resolve_voice(voice_ref, wrapper)
642
  stream_id = uuid.uuid4().hex[:12]
643
 
644
  return StreamingResponse(
@@ -660,11 +918,12 @@ async def stream_text_to_speech(
660
  async def parallel_stream_text_to_speech(
661
  text: str = Form(...),
662
  voice_ref: Optional[UploadFile] = File(None),
 
663
  max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
664
  repetition_penalty: float = Form(Config.REPETITION_PENALTY),
665
  helper_url: Optional[str] = Form(None),
666
  ):
667
- """Additive odd/even split stream mode (primary + helper)."""
668
  wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
669
  if not wrapper:
670
  raise HTTPException(503, "Model not loaded")
@@ -689,17 +948,32 @@ async def parallel_stream_text_to_speech(
689
  except Exception as e:
690
  logger.error(f"Parallel voice encoding failed: {e}")
691
  raise HTTPException(400, "Could not process voice file for parallel mode")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
692
 
693
  resolved_helper = (helper_url or Config.HELPER_BASE_URL).strip()
694
  if not resolved_helper:
695
  raise HTTPException(
696
  400,
697
- "Helper URL not configured. Set CB_HELPER_BASE_URL or pass helper_url.",
698
  )
699
 
700
  stream_id = uuid.uuid4().hex[:12]
701
  return StreamingResponse(
702
- _parallel_odd_even_stream_generator(
703
  wrapper=wrapper,
704
  text=text,
705
  local_voice=local_voice,
@@ -714,7 +988,7 @@ async def parallel_stream_text_to_speech(
714
  "Content-Disposition": "attachment; filename=tts_parallel_stream.mp3",
715
  "Transfer-Encoding": "chunked",
716
  "X-Stream-Id": stream_id,
717
- "X-Streaming-Type": "parallel-odd-even",
718
  "Cache-Control": "no-cache",
719
  },
720
  )
@@ -777,8 +1051,13 @@ async def internal_voice_register(http_request: Request):
777
  stream_id = (http_request.query_params.get("stream_id") or "").strip()
778
  if stream_id:
779
  with _internal_cancel_lock:
780
- keys = _internal_stream_voice_keys.setdefault(stream_id, set())
 
781
  keys.add(voice_key)
 
 
 
 
782
 
783
  return {"status": "registered", "voice_key": voice_key}
784
 
@@ -795,8 +1074,10 @@ async def internal_chunk_synthesize(
795
  raise HTTPException(403, "Forbidden")
796
 
797
  with _internal_cancel_lock:
 
798
  if request.stream_id in _internal_cancelled_streams:
799
  raise HTTPException(409, "Stream already cancelled")
 
800
 
801
  wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
802
  if not wrapper:
@@ -805,6 +1086,9 @@ async def internal_chunk_synthesize(
805
  voice_profile = wrapper.default_voice
806
  if request.voice_key:
807
  cached_voice = wrapper._voice_cache.get(request.voice_key)
 
 
 
808
  if cached_voice is None:
809
  raise HTTPException(409, "Voice key expired or not found")
810
  voice_profile = cached_voice
@@ -845,26 +1129,48 @@ async def internal_chunk_cancel(stream_id: str, http_request: Request):
845
  raise HTTPException(403, "Forbidden")
846
 
847
  with _internal_cancel_lock:
848
- _internal_cancelled_streams.add(stream_id)
 
 
 
849
  _internal_stream_voice_keys.pop(stream_id, None)
850
  return {"status": "cancelled", "stream_id": stream_id}
851
 
852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
853
  @app.post("/v1/audio/speech")
854
  async def openai_compatible_tts(request: TTSJsonRequest):
855
  """OpenAI-compatible streaming endpoint (JSON body, no file upload).
856
 
857
- Uses the default voice. For voice cloning, use /tts/stream with FormData.
858
  """
859
  wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
860
  if not wrapper:
861
  raise HTTPException(503, "Model not loaded")
862
 
 
 
 
 
 
863
  stream_id = uuid.uuid4().hex[:12]
864
 
865
  return StreamingResponse(
866
  _pipeline_stream_generator(
867
- wrapper, request.text, wrapper.default_voice,
868
  request.max_new_tokens, request.repetition_penalty, stream_id,
869
  ),
870
  media_type="audio/mpeg",
@@ -889,6 +1195,10 @@ async def stop_stream(stream_id: str):
889
  event = _active_streams.get(stream_id)
890
  if event:
891
  event.set()
 
 
 
 
892
  logger.info(f"Stream {stream_id} cancelled by client")
893
  return {"status": "stopped", "stream_id": stream_id}
894
  return {"status": "not_found", "stream_id": stream_id}
@@ -897,9 +1207,16 @@ async def stop_stream(stream_id: str):
897
  @app.post("/tts/stop")
898
  async def stop_all_streams():
899
  """Emergency stop: cancel ALL active TTS streams."""
900
- count = len(_active_streams)
901
- for sid, event in list(_active_streams.items()):
 
 
 
 
 
902
  event.set()
 
 
903
  _active_streams.clear()
904
  logger.info(f"Stopped all streams ({count} active)")
905
  return {"status": "stopped_all", "count": count}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
+ import http.client
3
  import io
4
  import json
5
  import logging
6
  import queue as stdlib_queue
7
  import threading
8
  import time
 
9
  import urllib.parse
 
10
  import uuid
11
  from concurrent.futures import ThreadPoolExecutor
12
+ from dataclasses import dataclass
13
+ from typing import Any, Generator, Optional
14
 
15
  import numpy as np
16
  import soundfile as sf
 
95
 
96
  async def _resolve_voice(
97
  voice_ref: Optional[UploadFile],
98
+ voice_name: Optional[str],
99
  wrapper: ChatterboxWrapper,
100
  ) -> VoiceProfile:
101
+ """Return a VoiceProfile from uploaded audio or built-in voice selection."""
102
  if voice_ref is None or voice_ref.filename == "":
103
+ try:
104
+ return wrapper.get_builtin_voice(voice_name)
105
+ except ValueError as e:
106
+ raise HTTPException(status_code=400, detail=str(e))
107
 
108
  audio_bytes = await voice_ref.read()
109
  if len(audio_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
 
152
  return data
153
 
154
 
155
+ @dataclass(frozen=True)
156
+ class _ChunkPacket:
157
+ index: int
158
+ data: bytes
159
+ lane: str
160
+ produced_at: float
161
 
162
 
163
+ def _internal_headers(
164
+ *,
165
+ content_type: Optional[str] = "application/json",
166
+ accept: str = "audio/mpeg",
167
+ ) -> dict[str, str]:
168
+ headers: dict[str, str] = {"Accept": accept, "Connection": "keep-alive"}
169
+ if content_type:
170
+ headers["Content-Type"] = content_type
171
  if Config.INTERNAL_SHARED_SECRET:
172
  headers["X-Internal-Secret"] = Config.INTERNAL_SHARED_SECRET
173
  return headers
174
 
175
 
176
+ class _HelperHttpClient:
177
+ """Small persistent HTTP client for helper server keep-alive calls."""
178
+
179
+ def __init__(self, base_url: str, default_timeout: float):
180
+ parsed = urllib.parse.urlparse((base_url or "").strip())
181
+ if parsed.scheme not in {"http", "https"} or not parsed.hostname:
182
+ raise ValueError(f"Invalid helper URL: {base_url!r}")
183
+
184
+ self._scheme = parsed.scheme
185
+ self._host = parsed.hostname
186
+ self._port = parsed.port
187
+ self._base_path = (parsed.path or "").rstrip("/")
188
+ self._default_timeout = max(1.0, float(default_timeout))
189
+ self._conn: Optional[http.client.HTTPConnection] = None
190
+
191
+ def __enter__(self):
192
+ return self
193
+
194
+ def __exit__(self, exc_type, exc, tb):
195
+ self.close()
196
+
197
+ def close(self):
198
+ if self._conn is not None:
199
+ try:
200
+ self._conn.close()
201
+ except Exception:
202
+ pass
203
+ self._conn = None
204
+
205
+ def _target(self, path: str, query: Optional[str] = None) -> str:
206
+ normalized = path if path.startswith("/") else f"/{path}"
207
+ target = f"{self._base_path}{normalized}"
208
+ if query:
209
+ target = f"{target}?{query}"
210
+ return target
211
+
212
+ def _make_connection(self, timeout_sec: float) -> http.client.HTTPConnection:
213
+ if self._scheme == "https":
214
+ return http.client.HTTPSConnection(self._host, self._port, timeout=timeout_sec)
215
+ return http.client.HTTPConnection(self._host, self._port, timeout=timeout_sec)
216
+
217
+ def _ensure_connection(self, timeout_sec: float) -> http.client.HTTPConnection:
218
+ if self._conn is None:
219
+ self._conn = self._make_connection(timeout_sec)
220
+ else:
221
+ self._conn.timeout = timeout_sec
222
+ return self._conn
223
+
224
+ def _request(
225
+ self,
226
+ method: str,
227
+ path: str,
228
+ *,
229
+ body: Optional[bytes] = None,
230
+ headers: Optional[dict[str, str]] = None,
231
+ timeout_sec: Optional[float] = None,
232
+ query: Optional[str] = None,
233
+ ) -> tuple[int, bytes, dict[str, str]]:
234
+ timeout = max(1.0, float(timeout_sec or self._default_timeout))
235
+ target = self._target(path, query=query)
236
+ req_headers = headers or {}
237
+
238
+ conn = self._ensure_connection(timeout)
239
+ try:
240
+ conn.request(method=method, url=target, body=body, headers=req_headers)
241
+ resp = conn.getresponse()
242
+ payload = resp.read()
243
+ resp_headers = {k.lower(): v for k, v in resp.getheaders()}
244
+ except Exception:
245
+ # Force reconnect on next attempt if socket is stale/reset.
246
+ self.close()
247
+ raise
248
+
249
+ if resp.status >= 400:
250
+ snippet = payload[:256].decode("utf-8", errors="replace")
251
+ raise RuntimeError(
252
+ f"helper {method} {target} returned {resp.status}: {snippet}"
253
+ )
254
+ return resp.status, payload, resp_headers
255
+
256
+ def request_chunk(self, payload: dict[str, Any], timeout_sec: float) -> bytes:
257
+ _, data, _ = self._request(
258
+ "POST",
259
+ "/internal/chunk/synthesize",
260
+ body=json.dumps(payload).encode("utf-8"),
261
+ headers=_internal_headers(content_type="application/json", accept="audio/mpeg"),
262
+ timeout_sec=timeout_sec,
263
+ )
264
+ return data
265
+
266
+ def register_voice(self, stream_id: str, audio_bytes: bytes, timeout_sec: float) -> str:
267
+ query = urllib.parse.urlencode({"stream_id": stream_id})
268
+ _, data, _ = self._request(
269
+ "POST",
270
+ "/internal/voice/register",
271
+ query=query,
272
+ body=audio_bytes,
273
+ headers=_internal_headers(
274
+ content_type="application/octet-stream",
275
+ accept="application/json",
276
+ ),
277
+ timeout_sec=timeout_sec,
278
+ )
279
+ payload = json.loads(data.decode("utf-8"))
280
+ voice_key = (payload.get("voice_key") or "").strip()
281
+ if not voice_key:
282
+ raise RuntimeError("helper voice registration returned no voice_key")
283
+ return voice_key
284
+
285
+ def cancel_stream(self, stream_id: str, timeout_sec: float = 3.0):
286
+ self._request(
287
+ "POST",
288
+ f"/internal/chunk/cancel/{stream_id}",
289
+ body=b"",
290
+ headers=_internal_headers(),
291
+ timeout_sec=timeout_sec,
292
+ )
293
+
294
+ def complete_stream(self, stream_id: str, timeout_sec: float = 3.0):
295
+ self._request(
296
+ "POST",
297
+ f"/internal/chunk/complete/{stream_id}",
298
+ body=b"",
299
+ headers=_internal_headers(),
300
+ timeout_sec=timeout_sec,
301
+ )
302
+
303
+
304
  def _helper_request_chunk(
305
  helper_base_url: str,
306
  payload: dict,
307
  timeout_sec: float,
308
+ helper_client: Optional[_HelperHttpClient] = None,
309
  ) -> bytes:
310
+ if helper_client is not None:
311
+ return helper_client.request_chunk(payload, timeout_sec=timeout_sec)
312
+ with _HelperHttpClient(helper_base_url, default_timeout=timeout_sec) as helper_client_single:
313
+ return helper_client_single.request_chunk(payload, timeout_sec=timeout_sec)
 
 
 
 
 
 
314
 
315
 
316
  def _helper_register_voice(
 
318
  stream_id: str,
319
  audio_bytes: bytes,
320
  timeout_sec: float,
321
+ helper_client: Optional[_HelperHttpClient] = None,
322
  ) -> str:
323
  """Register reference voice on helper once, return voice_key for chunk calls."""
324
+ if helper_client is not None:
325
+ return helper_client.register_voice(
326
+ stream_id=stream_id,
327
+ audio_bytes=audio_bytes,
328
+ timeout_sec=timeout_sec,
329
+ )
330
+ with _HelperHttpClient(helper_base_url, default_timeout=timeout_sec) as helper_client_single:
331
+ return helper_client_single.register_voice(
332
+ stream_id=stream_id,
333
+ audio_bytes=audio_bytes,
334
+ timeout_sec=timeout_sec,
335
+ )
 
 
 
 
 
 
336
 
337
 
338
  def _helper_cancel_stream(helper_base_url: str, stream_id: str):
339
  """Best-effort cancellation signal to helper."""
340
  try:
341
+ with _HelperHttpClient(helper_base_url, default_timeout=3.0) as helper_client:
342
+ helper_client.cancel_stream(stream_id=stream_id, timeout_sec=3.0)
 
 
 
 
 
 
 
343
  except Exception:
344
  pass
345
 
346
 
347
+ def _helper_complete_stream(helper_base_url: str, stream_id: str):
348
+ """Best-effort stream completion cleanup on helper.
349
+
350
+ Falls back to cancel for backwards compatibility if helper does not expose
351
+ the completion endpoint yet.
352
+ """
353
+ try:
354
+ with _HelperHttpClient(helper_base_url, default_timeout=3.0) as helper_client:
355
+ helper_client.complete_stream(stream_id=stream_id, timeout_sec=3.0)
356
+ except Exception:
357
+ _helper_cancel_stream(helper_base_url, stream_id)
358
+
359
+
360
  # ═══════════════════════════════════════════════════════════════════
361
  # Endpoints
362
  # ═══════════════════════════════════════════════════════════════════
 
364
  @app.get("/health")
365
  async def health(warm_up: bool = False):
366
  wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
367
+ with _internal_cancel_lock:
368
+ _purge_internal_stream_state_locked()
369
+ cancelled_count = len(_internal_cancelled_streams)
370
+ voice_state_count = len(_internal_stream_voice_keys)
371
+
372
  status = {
373
  "status": "healthy" if wrapper else "loading",
374
  "model_loaded": wrapper is not None,
375
  "model_dtype": Config.MODEL_DTYPE,
376
  "streaming_supported": True,
377
  "voice_cache_entries": wrapper._voice_cache.size if wrapper else 0,
378
+ "internal_cancelled_streams": cancelled_count,
379
+ "internal_stream_voice_states": voice_state_count,
380
  }
381
  if warm_up and wrapper:
382
  try:
 
388
  return status
389
 
390
 
391
+
392
+ @app.get("/voices")
393
+ async def list_voices():
394
+ wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
395
+ if not wrapper:
396
+ raise HTTPException(503, "Model not loaded")
397
+
398
+ voices = wrapper.list_builtin_voices()
399
  return {
400
+ "count": len(voices),
401
+ "default_voice": wrapper.default_voice_name,
402
+ "voices": voices,
403
+ "usage": {
404
+ "form_field": "voice_name",
405
+ "json_field": "voice",
406
+ "note": "If voice_ref is uploaded, it overrides voice_name.",
 
 
 
 
 
 
 
 
 
 
 
407
  },
408
  }
409
 
 
414
  async def text_to_speech(
415
  text: str = Form(...),
416
  voice_ref: Optional[UploadFile] = File(None),
417
+ voice_name: str = Form("default"),
418
  output_format: str = Form("wav"),
419
  max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
420
  repetition_penalty: float = Form(Config.REPETITION_PENALTY),
 
427
  if not text or not text.strip():
428
  raise HTTPException(400, "Text is required")
429
 
430
+ voice = await _resolve_voice(voice_ref, voice_name, wrapper)
431
 
432
  loop = asyncio.get_running_loop()
433
  try:
 
454
  # ═══════════════════════════════════════════════════════════════════
455
 
456
  _active_streams: dict[str, threading.Event] = {}
457
+ # stream_id -> expires_at epoch seconds
458
+ _internal_cancelled_streams: dict[str, float] = {}
459
  _internal_cancel_lock = threading.Lock()
460
+ # stream_id -> (voice_keys, expires_at)
461
+ _internal_stream_voice_keys: dict[str, tuple[set[str], float]] = {}
462
+
463
+ # stream_id -> helper base URLs (used to cancel helpers quickly on /tts/stop)
464
+ _stream_helper_routes: dict[str, set[str]] = {}
465
+ _stream_routes_lock = threading.Lock()
466
+
467
+
468
+ def _purge_internal_stream_state_locked(now: Optional[float] = None):
469
+ now_ts = now if now is not None else time.time()
470
+
471
+ expired_cancel_ids = [
472
+ sid for sid, expires_at in _internal_cancelled_streams.items()
473
+ if expires_at <= now_ts
474
+ ]
475
+ for sid in expired_cancel_ids:
476
+ _internal_cancelled_streams.pop(sid, None)
477
+
478
+ expired_voice_state_ids = [
479
+ sid for sid, (_, expires_at) in _internal_stream_voice_keys.items()
480
+ if expires_at <= now_ts
481
+ ]
482
+ for sid in expired_voice_state_ids:
483
+ _internal_stream_voice_keys.pop(sid, None)
484
+
485
+
486
+ def _touch_internal_stream_voice_keys_locked(stream_id: str):
487
+ if not stream_id:
488
+ return
489
+ entry = _internal_stream_voice_keys.get(stream_id)
490
+ if entry is None:
491
+ return
492
+ keys, _ = entry
493
+ _internal_stream_voice_keys[stream_id] = (
494
+ keys,
495
+ time.time() + max(1, Config.INTERNAL_STREAM_STATE_TTL_SEC),
496
+ )
497
+
498
+
499
+ def _clear_internal_stream_state_locked(stream_id: str):
500
+ _internal_cancelled_streams.pop(stream_id, None)
501
+ _internal_stream_voice_keys.pop(stream_id, None)
502
 
503
 
504
  # ═══════════════════════════════════════════════════════════════════
 
598
  _active_streams.pop(stream_id, None)
599
 
600
 
601
+ def _parallel_two_way_stream_generator(
602
  wrapper: ChatterboxWrapper,
603
  text: str,
604
  local_voice: VoiceProfile,
 
608
  stream_id: str,
609
  helper_base_url: str,
610
  ) -> Generator[bytes, None, None]:
611
+ """Additive 2-way split streamer (primary + helper).
612
+
613
+ Routing pattern:
614
+ - chunk 0,2,4... -> primary (local)
615
+ - chunk 1,3,5... -> helper
616
+ """
617
  cancel_event = threading.Event()
618
  _active_streams[stream_id] = cancel_event
619
 
620
+ helper_base_url = (helper_base_url or "").strip()
621
+ helper_route_set = {helper_base_url} if helper_base_url else set()
622
+ if helper_route_set:
623
+ with _stream_routes_lock:
624
+ _stream_helper_routes[stream_id] = set(helper_route_set)
625
+
626
  clean_text = text_processor.sanitize(text.strip()[: Config.MAX_TEXT_LENGTH])
627
  chunks = text_processor.split_for_streaming(clean_text)
628
  total_chunks = len(chunks)
629
  if total_chunks == 0:
630
+ with _stream_routes_lock:
631
+ _stream_helper_routes.pop(stream_id, None)
632
  _active_streams.pop(stream_id, None)
633
  return
634
 
635
  lock = threading.Lock()
636
  cond = threading.Condition(lock)
637
+ ready: dict[int, _ChunkPacket] = {}
638
  first_error: Optional[Exception] = None
639
  workers_done = 0
640
+ expected_workers = 2
641
+ stream_completed = False
642
 
643
+ def _publish(packet: _ChunkPacket):
644
  with cond:
645
+ # First write wins for an index to avoid duplicate retry races.
646
+ if packet.index not in ready:
647
+ ready[packet.index] = packet
648
  cond.notify_all()
649
 
650
  def _set_error(err: Exception):
 
669
  )
670
  return _encode_mp3_chunk(audio)
671
 
672
+ def _local_worker():
673
  try:
674
  for idx in range(0, total_chunks, 2):
675
  if cancel_event.is_set():
676
  break
677
  data = _synth_local(chunks[idx])
678
+ _publish(
679
+ _ChunkPacket(
680
+ index=idx,
681
+ data=data,
682
+ lane="primary",
683
+ produced_at=time.perf_counter(),
684
+ )
685
+ )
686
  except Exception as e:
687
  _set_error(e)
688
  finally:
689
  _worker_done()
690
 
691
+ def _helper_worker():
692
+ helper_available = bool(helper_base_url)
693
  helper_voice_key: Optional[str] = None
694
+ helper_timeout = max(1.0, Config.HELPER_TIMEOUT_SEC)
695
+ helper_client: Optional[_HelperHttpClient] = None
696
+
697
  try:
698
+ if helper_available:
699
+ try:
700
+ helper_client = _HelperHttpClient(
701
+ helper_base_url,
702
+ default_timeout=helper_timeout,
703
+ )
704
+ except Exception as conn_err:
705
+ helper_available = False
706
+ logger.warning(
707
+ f"[{stream_id}] helper keep-alive init failed ({conn_err}); "
708
+ "using local fallback for helper lane"
709
+ )
710
+
711
+ if helper_available and helper_voice_bytes:
712
  attempts = 2 if Config.HELPER_RETRY_ONCE else 1
713
  last_err: Optional[Exception] = None
714
  for _ in range(attempts):
 
717
  helper_base_url=helper_base_url,
718
  stream_id=stream_id,
719
  audio_bytes=helper_voice_bytes,
720
+ timeout_sec=helper_timeout,
721
+ helper_client=helper_client,
722
  )
723
  last_err = None
724
  break
725
  except Exception as reg_err:
726
  last_err = reg_err
727
  continue
728
+
729
  if last_err is not None:
730
  helper_available = False
731
  logger.warning(
732
+ f"[{stream_id}] helper voice registration failed; "
733
+ "falling back to local synthesis for helper lane"
734
  )
735
+ elif not helper_available:
736
+ logger.info(
737
+ f"[{stream_id}] helper URL not configured; using local fallback"
738
+ )
739
 
740
  for idx in range(1, total_chunks, 2):
741
  if cancel_event.is_set():
 
760
  helper_data = _helper_request_chunk(
761
  helper_base_url=helper_base_url,
762
  payload=payload,
763
+ timeout_sec=helper_timeout,
764
+ helper_client=helper_client,
765
+ )
766
+ _publish(
767
+ _ChunkPacket(
768
+ index=idx,
769
+ data=helper_data,
770
+ lane="helper",
771
+ produced_at=time.perf_counter(),
772
+ )
773
  )
 
774
  last_err = None
775
  break
776
  except Exception as helper_err:
 
782
 
783
  helper_available = False
784
  logger.warning(
785
+ f"[{stream_id}] helper failed at chunk {idx}; "
786
+ "falling back to local synthesis for remaining helper chunks"
787
  )
788
 
789
+ # Local fallback for helper lane
790
  data = _synth_local(chunks[idx])
791
+ _publish(
792
+ _ChunkPacket(
793
+ index=idx,
794
+ data=data,
795
+ lane="helper-local-fallback",
796
+ produced_at=time.perf_counter(),
797
+ )
798
+ )
799
  except Exception as e:
800
  _set_error(e)
801
  finally:
802
+ if helper_client is not None:
803
+ helper_client.close()
804
  _worker_done()
805
 
806
+ local_thread = threading.Thread(target=_local_worker, daemon=True)
807
+ helper_thread = threading.Thread(target=_helper_worker, daemon=True)
808
+ local_thread.start()
809
+ helper_thread.start()
810
 
811
  next_idx = 0
812
  try:
 
816
  next_idx not in ready
817
  and first_error is None
818
  and not cancel_event.is_set()
819
+ and workers_done < expected_workers
820
  ):
821
  cond.wait(timeout=0.1)
822
 
 
824
  break
825
 
826
  if next_idx in ready:
827
+ packet = ready.pop(next_idx)
828
+ buffered_chunks = len(ready)
829
  elif first_error is not None:
830
  logger.error(f"[{stream_id}] Parallel stream error: {first_error}")
831
  break
832
+ elif workers_done >= expected_workers:
833
  logger.error(
834
  f"[{stream_id}] Parallel stream ended with missing chunk index {next_idx}"
835
  )
 
837
  else:
838
  continue
839
 
840
+ logger.debug(
841
+ "[%s] stitch emit chunk %s/%s from %s (buffered=%s)",
842
+ stream_id,
843
+ next_idx + 1,
844
+ total_chunks,
845
+ packet.lane,
846
+ buffered_chunks,
847
+ )
848
+ yield packet.data
849
  next_idx += 1
850
+ stream_completed = (
851
+ next_idx >= total_chunks
852
+ and first_error is None
853
+ and not cancel_event.is_set()
854
+ )
855
  finally:
856
  cancel_event.set()
857
+
858
+ # For fast stop/cancel, signal helpers first; for normal completion, wait for
859
+ # workers to flush and then ask helpers to clear stream state.
860
+ if not stream_completed:
861
+ for base_url in helper_route_set:
862
+ _helper_cancel_stream(base_url, stream_id)
863
+
864
+ local_thread.join(timeout=1.0)
865
+ helper_thread.join(timeout=1.0)
866
+
867
+ if stream_completed:
868
+ for base_url in helper_route_set:
869
+ _helper_complete_stream(base_url, stream_id)
870
+
871
+ with _stream_routes_lock:
872
+ _stream_helper_routes.pop(stream_id, None)
873
  _active_streams.pop(stream_id, None)
874
 
875
 
 
880
  async def stream_text_to_speech(
881
  text: str = Form(...),
882
  voice_ref: Optional[UploadFile] = File(None),
883
+ voice_name: str = Form("default"),
884
  max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
885
  repetition_penalty: float = Form(Config.REPETITION_PENALTY),
886
  ):
 
896
  if not text or not text.strip():
897
  raise HTTPException(400, "Text is required")
898
 
899
+ voice = await _resolve_voice(voice_ref, voice_name, wrapper)
900
  stream_id = uuid.uuid4().hex[:12]
901
 
902
  return StreamingResponse(
 
918
  async def parallel_stream_text_to_speech(
919
  text: str = Form(...),
920
  voice_ref: Optional[UploadFile] = File(None),
921
+ voice_name: str = Form("default"),
922
  max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
923
  repetition_penalty: float = Form(Config.REPETITION_PENALTY),
924
  helper_url: Optional[str] = Form(None),
925
  ):
926
+ """Additive 2-way split stream mode (primary + helper)."""
927
  wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
928
  if not wrapper:
929
  raise HTTPException(503, "Model not loaded")
 
948
  except Exception as e:
949
  logger.error(f"Parallel voice encoding failed: {e}")
950
  raise HTTPException(400, "Could not process voice file for parallel mode")
951
+ else:
952
+ try:
953
+ selected_voice_id = wrapper.resolve_voice_id(voice_name)
954
+ local_voice = wrapper.get_builtin_voice(selected_voice_id)
955
+ except ValueError as e:
956
+ raise HTTPException(status_code=400, detail=str(e))
957
+
958
+ # Ensure helper uses the same selected built-in voice.
959
+ if selected_voice_id != wrapper.default_voice_name:
960
+ helper_voice_bytes = wrapper.get_builtin_voice_bytes(selected_voice_id)
961
+ if not helper_voice_bytes:
962
+ raise HTTPException(
963
+ status_code=400,
964
+ detail=f"Selected voice '{voice_name}' is unavailable for helper registration",
965
+ )
966
 
967
  resolved_helper = (helper_url or Config.HELPER_BASE_URL).strip()
968
  if not resolved_helper:
969
  raise HTTPException(
970
  400,
971
+ "No helper configured. Set CB_HELPER_BASE_URL or pass helper_url.",
972
  )
973
 
974
  stream_id = uuid.uuid4().hex[:12]
975
  return StreamingResponse(
976
+ _parallel_two_way_stream_generator(
977
  wrapper=wrapper,
978
  text=text,
979
  local_voice=local_voice,
 
988
  "Content-Disposition": "attachment; filename=tts_parallel_stream.mp3",
989
  "Transfer-Encoding": "chunked",
990
  "X-Stream-Id": stream_id,
991
+ "X-Streaming-Type": "parallel-2way",
992
  "Cache-Control": "no-cache",
993
  },
994
  )
 
1051
  stream_id = (http_request.query_params.get("stream_id") or "").strip()
1052
  if stream_id:
1053
  with _internal_cancel_lock:
1054
+ _purge_internal_stream_state_locked()
1055
+ keys, _ = _internal_stream_voice_keys.get(stream_id, (set(), 0.0))
1056
  keys.add(voice_key)
1057
+ _internal_stream_voice_keys[stream_id] = (
1058
+ keys,
1059
+ time.time() + max(1, Config.INTERNAL_STREAM_STATE_TTL_SEC),
1060
+ )
1061
 
1062
  return {"status": "registered", "voice_key": voice_key}
1063
 
 
1074
  raise HTTPException(403, "Forbidden")
1075
 
1076
  with _internal_cancel_lock:
1077
+ _purge_internal_stream_state_locked()
1078
  if request.stream_id in _internal_cancelled_streams:
1079
  raise HTTPException(409, "Stream already cancelled")
1080
+ _touch_internal_stream_voice_keys_locked(request.stream_id)
1081
 
1082
  wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
1083
  if not wrapper:
 
1086
  voice_profile = wrapper.default_voice
1087
  if request.voice_key:
1088
  cached_voice = wrapper._voice_cache.get(request.voice_key)
1089
+ if cached_voice is None:
1090
+ # Built-in voices are permanent in wrapper registry even if TTL cache entry expired.
1091
+ cached_voice = wrapper.get_builtin_voice_by_hash(request.voice_key)
1092
  if cached_voice is None:
1093
  raise HTTPException(409, "Voice key expired or not found")
1094
  voice_profile = cached_voice
 
1129
  raise HTTPException(403, "Forbidden")
1130
 
1131
  with _internal_cancel_lock:
1132
+ _purge_internal_stream_state_locked()
1133
+ _internal_cancelled_streams[stream_id] = (
1134
+ time.time() + max(1, Config.INTERNAL_CANCEL_TTL_SEC)
1135
+ )
1136
  _internal_stream_voice_keys.pop(stream_id, None)
1137
  return {"status": "cancelled", "stream_id": stream_id}
1138
 
1139
 
1140
@app.post("/internal/chunk/complete/{stream_id}")
async def internal_chunk_complete(stream_id: str, http_request: Request):
    """Best-effort immediate cleanup after a stream completes normally.

    Authenticated with the optional shared-secret header; purges any
    expired bookkeeping, then clears this stream's cancel/voice-key state.
    """
    secret = Config.INTERNAL_SHARED_SECRET
    if secret and http_request.headers.get("X-Internal-Secret", "") != secret:
        raise HTTPException(403, "Forbidden")

    with _internal_cancel_lock:
        _purge_internal_stream_state_locked()
        _clear_internal_stream_state_locked(stream_id)
    return {"status": "completed", "stream_id": stream_id}
1152
+
1153
+
1154
  @app.post("/v1/audio/speech")
1155
  async def openai_compatible_tts(request: TTSJsonRequest):
1156
  """OpenAI-compatible streaming endpoint (JSON body, no file upload).
1157
 
1158
+ Uses built-in voice selection via `voice`. For voice cloning, use /tts/stream with FormData.
1159
  """
1160
  wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
1161
  if not wrapper:
1162
  raise HTTPException(503, "Model not loaded")
1163
 
1164
+ try:
1165
+ selected_voice = wrapper.get_builtin_voice(request.voice)
1166
+ except ValueError as e:
1167
+ raise HTTPException(400, str(e))
1168
+
1169
  stream_id = uuid.uuid4().hex[:12]
1170
 
1171
  return StreamingResponse(
1172
  _pipeline_stream_generator(
1173
+ wrapper, request.text, selected_voice,
1174
  request.max_new_tokens, request.repetition_penalty, stream_id,
1175
  ),
1176
  media_type="audio/mpeg",
 
1195
  event = _active_streams.get(stream_id)
1196
  if event:
1197
  event.set()
1198
+ with _stream_routes_lock:
1199
+ helper_routes = set(_stream_helper_routes.pop(stream_id, set()))
1200
+ for helper_url in helper_routes:
1201
+ _helper_cancel_stream(helper_url, stream_id)
1202
  logger.info(f"Stream {stream_id} cancelled by client")
1203
  return {"status": "stopped", "stream_id": stream_id}
1204
  return {"status": "not_found", "stream_id": stream_id}
 
1207
  @app.post("/tts/stop")
1208
  async def stop_all_streams():
1209
  """Emergency stop: cancel ALL active TTS streams."""
1210
+ active_items = list(_active_streams.items())
1211
+ count = len(active_items)
1212
+ with _stream_routes_lock:
1213
+ stream_routes = {sid: set(urls) for sid, urls in _stream_helper_routes.items()}
1214
+ _stream_helper_routes.clear()
1215
+
1216
+ for sid, event in active_items:
1217
  event.set()
1218
+ for helper_url in stream_routes.get(sid, set()):
1219
+ _helper_cancel_stream(helper_url, sid)
1220
  _active_streams.clear()
1221
  logger.info(f"Stopped all streams ({count} active)")
1222
  return {"status": "stopped_all", "count": count}
chatterbox_wrapper.py CHANGED
@@ -27,6 +27,7 @@ import tempfile
27
  import time
28
  from collections import OrderedDict
29
  from dataclasses import dataclass
 
30
  from typing import Callable, Generator, Optional
31
 
32
  import librosa
@@ -48,6 +49,21 @@ _SUPPORTED_AUDIO_EXTENSIONS = {
48
  }
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # ═══════════════════════════════════════════════════════════════════
52
  # Data Structures
53
  # ═══════════════════════════════════════════════════════════════════
@@ -203,8 +219,15 @@ class ChatterboxWrapper:
203
  ttl_seconds=self.cfg.VOICE_CACHE_TTL_SEC,
204
  )
205
 
206
- logger.info("Encoding default reference voice …")
207
- self.default_voice = self._load_default_voice()
 
 
 
 
 
 
 
208
 
209
  logger.info("βœ… ChatterboxWrapper ready")
210
 
@@ -260,16 +283,190 @@ class ChatterboxWrapper:
260
  opts.enable_mem_reuse = True
261
  return opts
262
 
263
- # ─── Default voice ────────────────────────────────────────────
264
 
265
- def _load_default_voice(self) -> VoiceProfile:
266
  path = hf_hub_download(
267
  self.cfg.DEFAULT_VOICE_REPO,
268
  filename=self.cfg.DEFAULT_VOICE_FILE,
269
  cache_dir=self.cfg.MODELS_DIR,
270
  )
271
- audio, _ = librosa.load(path, sr=self.cfg.SAMPLE_RATE)
272
- return self._encode_audio_array(audio, audio_hash="__default__")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  # ─── Voice encoding ──────────────────────────────────────────
275
 
 
27
  import time
28
  from collections import OrderedDict
29
  from dataclasses import dataclass
30
+ from pathlib import Path
31
  from typing import Callable, Generator, Optional
32
 
33
  import librosa
 
49
  }
50
 
51
 
52
+ def _slugify(text: str) -> str:
53
+ buf = []
54
+ prev_underscore = False
55
+ for ch in text.strip().lower():
56
+ if ch.isalnum():
57
+ buf.append(ch)
58
+ prev_underscore = False
59
+ else:
60
+ if not prev_underscore:
61
+ buf.append("_")
62
+ prev_underscore = True
63
+ slug = "".join(buf).strip("_")
64
+ return slug or "voice"
65
+
66
+
67
  # ═══════════════════════════════════════════════════════════════════
68
  # Data Structures
69
  # ═══════════════════════════════════════════════════════════════════
 
219
  ttl_seconds=self.cfg.VOICE_CACHE_TTL_SEC,
220
  )
221
 
222
+ self._builtin_voice_profiles: dict[str, VoiceProfile] = {}
223
+ self._builtin_voice_bytes: dict[str, bytes] = {}
224
+ self._builtin_voice_by_hash: dict[str, VoiceProfile] = {}
225
+ self._voice_alias_to_id: dict[str, str] = {}
226
+ self._builtin_voice_catalog: list[dict] = []
227
+ self._default_voice_id: str = "default"
228
+
229
+ logger.info("Loading built-in voices (HF default + local samples) …")
230
+ self.default_voice = self._load_builtin_voices()
231
 
232
  logger.info("βœ… ChatterboxWrapper ready")
233
 
 
283
  opts.enable_mem_reuse = True
284
  return opts
285
 
286
+ # ─── Built-in voices (HF default + local samples) ────────────
287
 
288
    def _download_hf_default_voice_bytes(self) -> bytes:
        """Fetch the default reference-voice file from the configured HF repo.

        Downloads (or reuses a cached copy of) DEFAULT_VOICE_FILE from
        DEFAULT_VOICE_REPO into MODELS_DIR and returns its raw bytes.
        """
        path = hf_hub_download(
            self.cfg.DEFAULT_VOICE_REPO,
            filename=self.cfg.DEFAULT_VOICE_FILE,
            cache_dir=self.cfg.MODELS_DIR,
        )
        return Path(path).read_bytes()
295
+
296
+ def _list_local_voice_paths(self) -> list[Path]:
297
+ wrapper_dir = Path(__file__).resolve().parent
298
+
299
+ # Support both module-level and repo-root deployment layouts.
300
+ candidates = []
301
+ for d in (wrapper_dir, Path.cwd().resolve(), wrapper_dir.parent):
302
+ try:
303
+ resolved = d.resolve()
304
+ except Exception:
305
+ continue
306
+ if resolved.is_dir() and resolved not in candidates:
307
+ candidates.append(resolved)
308
+
309
+ voices: list[Path] = []
310
+ seen_real_paths: set[str] = set()
311
+ for root in candidates:
312
+ try:
313
+ entries = sorted(root.iterdir(), key=lambda x: x.name.lower())
314
+ except Exception:
315
+ continue
316
+
317
+ for p in entries:
318
+ if not p.is_file():
319
+ continue
320
+ if p.suffix.lower() not in _SUPPORTED_AUDIO_EXTENSIONS:
321
+ continue
322
+ real_path = str(p.resolve())
323
+ if real_path in seen_real_paths:
324
+ continue
325
+ seen_real_paths.add(real_path)
326
+ voices.append(p)
327
+
328
+ logger.info(
329
+ "Local voice scan complete: %s files across %s",
330
+ len(voices),
331
+ [str(x) for x in candidates],
332
+ )
333
+ return voices
334
+
335
+ def _make_unique_voice_id(self, preferred: str) -> str:
336
+ base = _slugify(preferred)
337
+ candidate = base
338
+ idx = 2
339
+ while candidate in self._builtin_voice_profiles:
340
+ candidate = f"{base}_{idx}"
341
+ idx += 1
342
+ return candidate
343
+
344
    def _register_builtin_voice(
        self,
        *,
        preferred_id: str,
        display_name: str,
        source: str,
        source_ref: str,
        audio_bytes: bytes,
        is_default: bool = False,
    ) -> str:
        """Encode and register one built-in voice; return its assigned voice id.

        Args:
            preferred_id: Desired id; slugified and de-duplicated on collision.
            display_name: Human-readable name (its stem is also added as an alias).
            source: Origin tag, e.g. "huggingface" or "local".
            source_ref: Origin detail (repo:file or filename) stored in the catalog.
            audio_bytes: Raw audio file content used as the cloning reference.
            is_default: When True, this voice becomes the default and the
                "default" alias points to it.

        Raises:
            ValueError: If audio_bytes is empty.
        """
        if not audio_bytes:
            raise ValueError("Voice file is empty")

        voice_id = self._make_unique_voice_id(preferred_id)
        # MD5 of the raw bytes doubles as the cache key and the public voice_key.
        audio_hash = hashlib.md5(audio_bytes).hexdigest()

        # Reuse an already-encoded profile when the same audio was seen before.
        profile = self._voice_cache.get(audio_hash)
        if profile is None:
            audio = _load_audio_bytes(audio_bytes, sr=self.cfg.SAMPLE_RATE)
            profile = self._encode_audio_array(audio, audio_hash=audio_hash)
            self._voice_cache.put(audio_hash, profile)
        else:
            # Keep hash attached to cached profile for metadata/voice-key usage.
            profile.audio_hash = audio_hash

        self._builtin_voice_profiles[voice_id] = profile
        self._builtin_voice_bytes[voice_id] = audio_bytes
        if audio_hash:
            self._builtin_voice_by_hash[audio_hash] = profile

        # Register aliases (first registration wins for each alias string).
        aliases: list[str] = []
        for alias in (voice_id, _slugify(Path(display_name).stem)):
            if alias not in self._voice_alias_to_id:
                self._voice_alias_to_id[alias] = voice_id
                aliases.append(alias)

        if is_default:
            self._default_voice_id = voice_id
            self._voice_alias_to_id["default"] = voice_id
            if "default" not in aliases:
                aliases.append("default")

        self._builtin_voice_catalog.append(
            {
                "id": voice_id,
                "display_name": display_name,
                "source": source,
                "source_ref": source_ref,
                "aliases": aliases,
                "voice_key": audio_hash,
            }
        )
        return voice_id
397
+
398
    def _load_builtin_voices(self) -> VoiceProfile:
        """Register the HF default voice plus local audio samples at startup.

        Returns the default VoiceProfile; raises RuntimeError if the default
        voice could not be registered. Local-sample failures are logged and
        skipped rather than aborting startup.
        """
        # 1) HF default voice (kept as true default fallback)
        hf_bytes = self._download_hf_default_voice_bytes()
        self._register_builtin_voice(
            preferred_id="default_hf_voice",
            display_name=self.cfg.DEFAULT_VOICE_FILE,
            source="huggingface",
            source_ref=f"{self.cfg.DEFAULT_VOICE_REPO}:{self.cfg.DEFAULT_VOICE_FILE}",
            audio_bytes=hf_bytes,
            is_default=True,
        )

        # 2) Local voice samples placed next to app files
        for path in self._list_local_voice_paths():
            # Avoid duplicate entry if someone also copied default_voice.wav locally.
            if path.name == self.cfg.DEFAULT_VOICE_FILE:
                continue
            try:
                self._register_builtin_voice(
                    preferred_id=path.stem,
                    display_name=path.name,
                    source="local",
                    source_ref=str(path.name),
                    audio_bytes=path.read_bytes(),
                    is_default=False,
                )
            except Exception as e:
                logger.warning(f"Skipping local voice {path.name}: {e}")

        default_profile = self._builtin_voice_profiles.get(self._default_voice_id)
        if default_profile is None:
            raise RuntimeError("Default built-in voice could not be initialized")

        logger.info(
            f"Built-in voices loaded: {len(self._builtin_voice_catalog)} "
            f"(default={self._default_voice_id})"
        )
        return default_profile
436
+
437
+ def list_builtin_voices(self) -> list[dict]:
438
+ """Return metadata for startup-preloaded voices."""
439
+ return [dict(v) for v in self._builtin_voice_catalog]
440
+
441
    @property
    def default_voice_name(self) -> str:
        # Id of the voice registered with is_default=True (also reachable
        # via the "default" alias in resolve_voice_id).
        return self._default_voice_id
444
+
445
+ def resolve_voice_id(self, voice_name: Optional[str]) -> str:
446
+ if voice_name is None:
447
+ return self._default_voice_id
448
+ key = _slugify(str(voice_name))
449
+ if not key:
450
+ return self._default_voice_id
451
+ voice_id = self._voice_alias_to_id.get(key)
452
+ if voice_id is None:
453
+ available = ", ".join(sorted(self._voice_alias_to_id.keys()))
454
+ raise ValueError(f"Unknown voice '{voice_name}'. Available: {available}")
455
+ return voice_id
456
+
457
+ def get_builtin_voice(self, voice_name: Optional[str]) -> VoiceProfile:
458
+ voice_id = self.resolve_voice_id(voice_name)
459
+ profile = self._builtin_voice_profiles[voice_id]
460
+ if profile.audio_hash:
461
+ self._voice_cache.put(profile.audio_hash, profile)
462
+ return profile
463
+
464
+ def get_builtin_voice_bytes(self, voice_name: Optional[str]) -> Optional[bytes]:
465
+ voice_id = self.resolve_voice_id(voice_name)
466
+ return self._builtin_voice_bytes.get(voice_id)
467
+
468
+ def get_builtin_voice_by_hash(self, audio_hash: str) -> Optional[VoiceProfile]:
469
+ return self._builtin_voice_by_hash.get((audio_hash or "").strip())
470
 
471
  # ─── Voice encoding ──────────────────────────────────────────
472
 
config.py CHANGED
@@ -77,11 +77,14 @@ class Config:
77
  # Smaller chunks = faster TTFB (first audio arrives sooner)
78
  # ~200 chars β‰ˆ 1–2 sentences β‰ˆ fastest first-chunk on 2 vCPU
79
  MAX_CHUNK_CHARS: int = int(os.getenv("CB_MAX_CHUNK_CHARS", "100"))
80
- # Additive parallel mode (odd/even split across primary/helper).
81
  ENABLE_PARALLEL_MODE: bool = _get_bool("CB_ENABLE_PARALLEL_MODE", True)
82
- HELPER_BASE_URL: str = os.getenv("CB_HELPER_BASE_URL", "https://shadowhunter222-hello2.hf.space").strip()
83
  HELPER_TIMEOUT_SEC: float = float(os.getenv("CB_HELPER_TIMEOUT_SEC", "45"))
84
  HELPER_RETRY_ONCE: bool = _get_bool("CB_HELPER_RETRY_ONCE", True)
 
 
 
85
  # Optional shared secret for internal chunk endpoints.
86
  INTERNAL_SHARED_SECRET: str = os.getenv("CB_INTERNAL_SHARED_SECRET", "").strip()
87
 
@@ -91,10 +94,13 @@ class Config:
91
 
92
  ALLOWED_ORIGINS: list = [
93
  "https://toolboxesai.com",
 
 
 
94
  "http://localhost:8788", "http://127.0.0.1:8788",
95
  "http://localhost:5502", "http://127.0.0.1:5502",
96
  "http://localhost:5501", "http://127.0.0.1:5501",
97
  "http://localhost:5500", "http://127.0.0.1:5500",
98
  "http://localhost:5173", "http://127.0.0.1:5173",
99
  "http://localhost:7860", "http://127.0.0.1:7860",
100
- ]
 
77
  # Smaller chunks = faster TTFB (first audio arrives sooner)
78
  # ~200 chars β‰ˆ 1–2 sentences β‰ˆ fastest first-chunk on 2 vCPU
79
  MAX_CHUNK_CHARS: int = int(os.getenv("CB_MAX_CHUNK_CHARS", "100"))
80
+ # Additive parallel mode (2-way split: primary + helper).
81
  ENABLE_PARALLEL_MODE: bool = _get_bool("CB_ENABLE_PARALLEL_MODE", True)
82
+ HELPER_BASE_URL: str = os.getenv("CB_HELPER_BASE_URL", "https://shadowhunter222-chab2.hf.space").strip()
83
  HELPER_TIMEOUT_SEC: float = float(os.getenv("CB_HELPER_TIMEOUT_SEC", "45"))
84
  HELPER_RETRY_ONCE: bool = _get_bool("CB_HELPER_RETRY_ONCE", True)
85
+ # Internal housekeeping TTLs to avoid retaining stream metadata indefinitely.
86
+ INTERNAL_CANCEL_TTL_SEC: int = int(os.getenv("CB_INTERNAL_CANCEL_TTL_SEC", "120"))
87
+ INTERNAL_STREAM_STATE_TTL_SEC: int = int(os.getenv("CB_INTERNAL_STREAM_STATE_TTL_SEC", "600"))
88
  # Optional shared secret for internal chunk endpoints.
89
  INTERNAL_SHARED_SECRET: str = os.getenv("CB_INTERNAL_SHARED_SECRET", "").strip()
90
 
 
94
 
95
  ALLOWED_ORIGINS: list = [
96
  "https://toolboxesai.com",
97
+ "https://www.toolboxesai.com",
98
+ "www.toolboxesai.com",
99
+ "toolboxesai.com",
100
  "http://localhost:8788", "http://127.0.0.1:8788",
101
  "http://localhost:5502", "http://127.0.0.1:5502",
102
  "http://localhost:5501", "http://127.0.0.1:5501",
103
  "http://localhost:5500", "http://127.0.0.1:5500",
104
  "http://localhost:5173", "http://127.0.0.1:5173",
105
  "http://localhost:7860", "http://127.0.0.1:7860",
106
+ ]
her_prompt.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8eaabbeafe26ad6f78b56dcc32608763eeb69485db074c7136c6818f04a93ced
3
+ size 725328
ivr_female_prompt.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64953bf94657c4334532319fd4f20e9859c31af4445940916b04f129ef1f89e6
3
+ size 2779278
text_processor.py CHANGED
@@ -4,6 +4,26 @@ Chatterbox Turbo TTS β€” Text Processor
4
  Sanitizes raw input text and splits it into sentence-level chunks
5
  for streaming TTS. Paralinguistic tags ([laugh], [cough], …) are
6
  explicitly preserved so the model can render them.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  """
8
  import re
9
  from typing import List
@@ -47,20 +67,64 @@ _RE_EMOJI = re.compile(
47
  r"]+", re.UNICODE,
48
  )
49
  _RE_HTML_ENTITY = re.compile(r"&(?:#x?[\da-fA-F]+|\w+);")
 
 
 
50
  _HTML_ENTITIES = {
51
  "&amp;": " and ", "&lt;": " less than ", "&gt;": " greater than ",
52
  "&nbsp;": " ", "&quot;": '"', "&apos;": "'",
53
- "&mdash;": ", ", "&ndash;": ", ", "&hellip;": ".",
54
  }
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # β€” Punctuation normalization
57
- _RE_REPEATED_DOT = re.compile(r"\.{2,}")
58
- _RE_REPEATED_EXCLAM = re.compile(r"!{2,}")
59
- _RE_REPEATED_QUEST = re.compile(r"\?{2,}")
60
- _RE_REPEATED_SEMI = re.compile(r";{2,}")
61
- _RE_REPEATED_COLON = re.compile(r":{2,}")
62
- _RE_REPEATED_COMMA = re.compile(r",{2,}")
63
- _RE_REPEATED_DASH = re.compile(r"-{3,}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  # β€” Whitespace
66
  _RE_MULTI_SPACE = re.compile(r"[ \t]+")
@@ -68,7 +132,14 @@ _RE_MULTI_NEWLINE = re.compile(r"\n{3,}")
68
  _RE_SPACE_BEFORE_PUN = re.compile(r"\s+([.!?,;:])")
69
 
70
  # β€” Sentence boundary (split point)
71
- _RE_SENTENCE_SPLIT = re.compile(r"(?<=[.!?;:])\s+")
 
 
 
 
 
 
 
72
 
73
  _MIN_MERGE_WORDS = 5
74
 
@@ -78,11 +149,22 @@ _MIN_MERGE_WORDS = 5
78
  # ═══════════════════════════════════════════════════════════════════
79
 
80
  def sanitize(text: str) -> str:
81
- """Clean raw input for TTS while preserving paralinguistic tags."""
 
 
 
 
82
  if not text:
83
  return text
84
 
85
- # 1. Protect paralinguistic tags by replacing with placeholders
 
 
 
 
 
 
 
86
  tags_found: list[tuple[int, str]] = []
87
  def _protect_tag(m):
88
  idx = len(tags_found)
@@ -90,7 +172,16 @@ def sanitize(text: str) -> str:
90
  return f"Β§TAG{idx}Β§"
91
  text = _RE_PARA_TAG.sub(_protect_tag, text)
92
 
93
- # 2. Strip non-speakable structures
 
 
 
 
 
 
 
 
 
94
  text = _RE_URL.sub("", text)
95
  text = _RE_CODE_BLOCK.sub("", text)
96
  text = _RE_IMAGE.sub(lambda m: m.group(1) if m.group(1) else "", text)
@@ -107,29 +198,36 @@ def sanitize(text: str) -> str:
107
  text = _RE_BULLET.sub("", text)
108
  text = _RE_ORDERED.sub("", text)
109
 
110
- # 3. Emojis, hashtags
111
  text = _RE_EMOJI.sub("", text)
112
  text = re.sub(r"#(\w+)", r"\1", text)
113
 
114
- # 4. HTML entities
115
  text = _RE_HTML_ENTITY.sub(lambda m: _HTML_ENTITIES.get(m.group(0), ""), text)
116
 
117
- # 5. Collapse repeated punctuation
118
- text = _RE_REPEATED_DOT.sub(".", text)
119
- text = _RE_REPEATED_EXCLAM.sub("!", text)
120
- text = _RE_REPEATED_QUEST.sub("?", text)
121
- text = _RE_REPEATED_SEMI.sub(";", text)
122
- text = _RE_REPEATED_COLON.sub(":", text)
123
- text = _RE_REPEATED_COMMA.sub(",", text)
124
- text = _RE_REPEATED_DASH.sub("β€”", text)
 
 
125
 
126
- # 6. Whitespace
127
  text = _RE_SPACE_BEFORE_PUN.sub(r"\1", text)
128
  text = _RE_MULTI_SPACE.sub(" ", text)
129
  text = _RE_MULTI_NEWLINE.sub("\n\n", text)
130
  text = text.strip()
131
 
132
- # 7. Restore paralinguistic tags
 
 
 
 
 
133
  for idx, original in tags_found:
134
  text = text.replace(f"Β§TAG{idx}Β§", original)
135
 
@@ -140,13 +238,25 @@ def split_for_streaming(text: str, max_chars: int = Config.MAX_CHUNK_CHARS) -> L
140
  """Split sanitized text into sentence-level chunks for streaming.
141
 
142
  Strategy:
143
- 1. Split on sentence-ending punctuation boundaries
144
- 2. Enforce max_chars per chunk (split long sentences on commas / spaces)
145
- 3. Merge short chunks (≀5 words) with the next to avoid tiny segments
 
 
 
146
  """
147
  if not text:
148
  return []
149
 
 
 
 
 
 
 
 
 
 
150
  # Step 1: sentence split
151
  raw_chunks = _RE_SENTENCE_SPLIT.split(text)
152
  raw_chunks = [c.strip() for c in raw_chunks if c.strip()]
@@ -161,23 +271,30 @@ def split_for_streaming(text: str, max_chars: int = Config.MAX_CHUNK_CHARS) -> L
161
 
162
  # Step 3: merge short chunks
163
  if len(sized) <= 1:
164
- return sized
165
-
166
- merged: List[str] = []
167
- carry = ""
168
- for i, chunk in enumerate(sized):
 
 
 
 
 
 
 
169
  if carry:
170
- chunk = carry + " " + chunk
171
- carry = ""
172
- if len(chunk.split()) <= _MIN_MERGE_WORDS and i < len(sized) - 1:
173
- carry = chunk
174
- else:
175
- merged.append(chunk)
176
- if carry:
177
- if merged:
178
- merged[-1] += " " + carry
179
- else:
180
- merged.append(carry)
181
 
182
  return merged
183
 
@@ -191,16 +308,41 @@ def _break_long_chunk(text: str, max_chars: int) -> List[str]:
191
  parts: List[str] = []
192
  remaining = text
193
  while len(remaining) > max_chars:
194
- # Try comma first
195
- pos = remaining.rfind(",", 0, max_chars)
196
- if pos == -1:
197
- pos = remaining.rfind(" ", 0, max_chars)
198
- if pos == -1:
199
- pos = max_chars # hard break
200
- segment = remaining[:pos].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  if segment:
202
  parts.append(segment)
203
- remaining = remaining[pos:].lstrip(", ")
204
  if remaining.strip():
205
  parts.append(remaining.strip())
206
  return parts
 
4
  Sanitizes raw input text and splits it into sentence-level chunks
5
  for streaming TTS. Paralinguistic tags ([laugh], [cough], …) are
6
  explicitly preserved so the model can render them.
7
+
8
+ Punctuation Philosophy (based on Resemble AI recommendations):
9
+ βœ… PRESERVE (benefits prosody):
10
+ β€’ Ellipsis ... β†’ dramatic pause, trailing thought, hesitation
11
+ β€’ Em dash β€” β†’ abrupt transition, dramatic break
12
+ β€’ Comma , β†’ short natural pause / breathing point
13
+ β€’ Period . β†’ full stop, pitch drop, sentence boundary
14
+ β€’ ! and ? β†’ exclamatory / interrogative inflection
15
+ β€’ Semicolon ; β†’ medium pause, clause bridge (NOT a split point)
16
+ β€’ Colon : β†’ medium pause, introduces explanation (NOT a split point)
17
+ β€’ Parentheses () β†’ quieter/explanatory tone shift
18
+ β€’ Quotes "" β†’ dialogue cue
19
+ β€’ Apostrophe ' β†’ contractions (don't, it's)
20
+ β€’ CAPS words β†’ emphasis / volume increase
21
+
22
+ ❌ FILTER (harms output):
23
+ β€’ Excessive repeated punctuation (!!!! β†’ !, ???? β†’ ?, ,,, β†’ ,)
24
+ β€’ 4+ dots (.... β†’ ...)
25
+ β€’ Emojis, URLs, markdown, HTML tags
26
+ β€’ Non-standard Unicode punctuation (guillemets, etc.)
27
  """
28
  import re
29
  from typing import List
 
67
  r"]+", re.UNICODE,
68
  )
69
  _RE_HTML_ENTITY = re.compile(r"&(?:#x?[\da-fA-F]+|\w+);")
70
+
71
+ # HTML entities β†’ speakable replacements
72
+ # NOTE: &hellip; β†’ "..." (preserves dramatic pause), &mdash;/&ndash; β†’ "β€”" (preserves dramatic break)
73
  _HTML_ENTITIES = {
74
  "&amp;": " and ", "&lt;": " less than ", "&gt;": " greater than ",
75
  "&nbsp;": " ", "&quot;": '"', "&apos;": "'",
76
+ "&mdash;": "β€”", "&ndash;": "β€”", "&hellip;": "...",
77
  }
78
 
79
+ # β€” Smart/curly quote normalization β†’ ASCII equivalents
80
+ # These Unicode variants may confuse the tokenizer; normalizing ensures clean input.
81
+ _SMART_QUOTE_MAP = str.maketrans({
82
+ "\u201c": '"', # " left double quotation mark
83
+ "\u201d": '"', # " right double quotation mark
84
+ "\u2018": "'", # ' left single quotation mark
85
+ "\u2019": "'", # ' right single quotation mark
86
+ "\u00ab": '"', # Β« left guillemet
87
+ "\u00bb": '"', # Β» right guillemet
88
+ "\u201e": '"', # β€ž double low quotation mark
89
+ "\u201f": '"', # β€Ÿ double high reversed quotation mark
90
+ "\u2032": "'", # β€² prime
91
+ "\u2033": '"', # β€³ double prime
92
+ "\u2013": "β€”", # – en dash β†’ em dash (dramatic pause)
93
+ "\u2014": "β€”", # β€” em dash (keep as-is after mapping)
94
+ "\u2026": "...", # … horizontal ellipsis β†’ three dots
95
+ })
96
+
97
+ # β€” ALL CAPS normalization
98
+ # Words entirely in caps (length >= 4) often get spelled out by the TTS engine (e.g. NOTHING).
99
+ # By converting them to Title Case, they'll be processed naturally as words.
100
+ _RE_ALL_CAPS = re.compile(r"\b[A-Z]{4,}\b")
101
+
102
  # β€” Punctuation normalization
103
+ # Ellipsis (... / ..) is PRESERVED β€” it creates dramatic pauses in Chatterbox.
104
+ # Only 4+ dots are excessive and get capped to standard ellipsis.
105
+ _RE_EXCESSIVE_DOTS = re.compile(r"\.{4,}") # ....+ β†’ ... (cap excessive)
106
+ _RE_NORMALIZE_DOTS = re.compile(r"\.{2,3}") # .. or ... β†’ ... (standardize)
107
+ _RE_REPEATED_EXCLAM = re.compile(r"!{2,}") # !! β†’ !
108
+ _RE_REPEATED_QUEST = re.compile(r"\?{2,}") # ?? β†’ ?
109
+ _RE_REPEATED_SEMI = re.compile(r";{2,}") # ;; β†’ ;
110
+ _RE_REPEATED_COLON = re.compile(r":{2,}") # :: β†’ :
111
+ _RE_REPEATED_COMMA = re.compile(r",{2,}") # ,, β†’ ,
112
+ _RE_REPEATED_DASH = re.compile(r"-{3,}") # --- β†’ β€” (em dash)
113
+
114
+ # β€” Abbreviation protection
115
+ # Common abbreviations ending in "." that should NOT trigger sentence splitting.
116
+ # These get a placeholder before splitting, then get restored.
117
+ _ABBREVIATIONS = (
118
+ "Mr", "Mrs", "Ms", "Dr", "Prof", "Sr", "Jr", "St", "Ave", "Blvd",
119
+ "vs", "etc", "approx", "dept", "est", "govt", "inc", "corp", "ltd",
120
+ "Jan", "Feb", "Mar", "Apr", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
121
+ "Gen", "Col", "Sgt", "Capt", "Lt", "Cmdr", "Adm",
122
+ "Fig", "Vol", "No", "Ref", "Rev", "Ph",
123
+ )
124
+ _RE_ABBREV = re.compile(
125
+ r"\b(" + "|".join(re.escape(a) for a in _ABBREVIATIONS) + r")\.",
126
+ re.IGNORECASE,
127
+ )
128
 
129
  # β€” Whitespace
130
  _RE_MULTI_SPACE = re.compile(r"[ \t]+")
 
132
  _RE_SPACE_BEFORE_PUN = re.compile(r"\s+([.!?,;:])")
133
 
134
  # β€” Sentence boundary (split point)
135
+ # Split ONLY on true sentence-ending punctuation: . ! ?
136
+ # Semicolons and colons are clause connectors β€” they bridge related thoughts
137
+ # and should NOT be used as split points (creates choppy, unnatural fragments).
138
+ # Ellipsis (...) is also intentionally excluded from splitting: letting it split the stream
139
+ # creates a compound lag between chunks, making the pause artificially excessive.
140
+ _RE_SENTENCE_SPLIT = re.compile(
141
+ r"""(?:(?<=[.!?])(?<!\.\.\.)|(?<=[.!?][)\]"'])(?<!\.\.\.\.))\s+"""
142
+ )
143
 
144
  _MIN_MERGE_WORDS = 5
145
 
 
149
  # ═══════════════════════════════════════════════════════════════════
150
 
151
  def sanitize(text: str) -> str:
152
+ """Clean raw input for TTS while preserving prosody-beneficial punctuation.
153
+
154
+ Preserves: ellipsis (...), em dashes (β€”), commas, periods, !, ?, ;, :, quotes.
155
+ Removes: emojis, URLs, markdown, HTML, excessive repeated punctuation.
156
+ """
157
  if not text:
158
  return text
159
 
160
+ # 0. Normalize smart/curly quotes and Unicode punctuation FIRST
161
+ # This ensures downstream regex works on clean ASCII-like punctuation.
162
+ text = text.translate(_SMART_QUOTE_MAP)
163
+
164
+ # 1. Normalize ALL CAPS words to Title Case to prevent spelling out
165
+ text = _RE_ALL_CAPS.sub(lambda m: m.group(0).capitalize(), text)
166
+
167
+ # 2. Protect paralinguistic tags by replacing with placeholders
168
  tags_found: list[tuple[int, str]] = []
169
  def _protect_tag(m):
170
  idx = len(tags_found)
 
172
  return f"Β§TAG{idx}Β§"
173
  text = _RE_PARA_TAG.sub(_protect_tag, text)
174
 
175
+ # 3. Protect abbreviations from sentence-boundary splitting
176
+ # "Dr. Smith" β†’ "DrΒ§ Smith" (restored later)
177
+ abbrevs_found: list[tuple[int, str]] = []
178
+ def _protect_abbrev(m):
179
+ idx = len(abbrevs_found)
180
+ abbrevs_found.append((idx, m.group(0)))
181
+ return f"{m.group(1)}Β§ABR{idx}Β§"
182
+ text = _RE_ABBREV.sub(_protect_abbrev, text)
183
+
184
+ # 4. Strip non-speakable structures
185
  text = _RE_URL.sub("", text)
186
  text = _RE_CODE_BLOCK.sub("", text)
187
  text = _RE_IMAGE.sub(lambda m: m.group(1) if m.group(1) else "", text)
 
198
  text = _RE_BULLET.sub("", text)
199
  text = _RE_ORDERED.sub("", text)
200
 
201
+ # 5. Emojis, hashtags
202
  text = _RE_EMOJI.sub("", text)
203
  text = re.sub(r"#(\w+)", r"\1", text)
204
 
205
+ # 6. HTML entities β†’ speakable text
206
  text = _RE_HTML_ENTITY.sub(lambda m: _HTML_ENTITIES.get(m.group(0), ""), text)
207
 
208
+ # 7. Normalize punctuation (PRESERVE prosody-beneficial marks)
209
+ # Order matters: handle excessive dots first, then standardize ellipsis.
210
+ text = _RE_EXCESSIVE_DOTS.sub("...", text) # ....+ β†’ ... (cap)
211
+ text = _RE_NORMALIZE_DOTS.sub("...", text) # .. or ... β†’ ... (standardize)
212
+ text = _RE_REPEATED_EXCLAM.sub("!", text) # !! β†’ !
213
+ text = _RE_REPEATED_QUEST.sub("?", text) # ?? β†’ ?
214
+ text = _RE_REPEATED_SEMI.sub(";", text) # ;; β†’ ;
215
+ text = _RE_REPEATED_COLON.sub(":", text) # :: β†’ :
216
+ text = _RE_REPEATED_COMMA.sub(",", text) # ,, β†’ ,
217
+ text = _RE_REPEATED_DASH.sub("β€”", text) # --- β†’ em dash
218
 
219
+ # 8. Whitespace cleanup
220
  text = _RE_SPACE_BEFORE_PUN.sub(r"\1", text)
221
  text = _RE_MULTI_SPACE.sub(" ", text)
222
  text = _RE_MULTI_NEWLINE.sub("\n\n", text)
223
  text = text.strip()
224
 
225
+ # 9. Restore abbreviations
226
+ for idx, original in abbrevs_found:
227
+ # Restore the full abbreviation with its period
228
+ text = text.replace(f"Β§ABR{idx}Β§", ".")
229
+
230
+ # 10. Restore paralinguistic tags
231
  for idx, original in tags_found:
232
  text = text.replace(f"Β§TAG{idx}Β§", original)
233
 
 
238
  """Split sanitized text into sentence-level chunks for streaming.
239
 
240
  Strategy:
241
+ 1. Protect abbreviation dots (Mr., Dr., etc.) from triggering splits
242
+ 2. Split on sentence-ending punctuation boundaries (. ! ?)
243
+ β€” NOT on semicolons, colons, or ellipsis (those are non-breaking boundaries)
244
+ 3. Enforce max_chars per chunk (split long sentences on commas / spaces)
245
+ 4. Merge short chunks (≀5 words) with the next to avoid tiny segments
246
+ 5. Restore abbreviation dots
247
  """
248
  if not text:
249
  return []
250
 
251
+ # Step 0: protect abbreviation dots from sentence-boundary splitting
252
+ # "Mr. Smith" β†’ "MrΒ§ABRSΒ§ Smith" (prevents split on that period)
253
+ abbrev_placeholders: list[tuple[int, str]] = []
254
+ def _protect_abbrev_split(m):
255
+ idx = len(abbrev_placeholders)
256
+ abbrev_placeholders.append((idx, m.group(0)))
257
+ return f"{m.group(1)}Β§ABRS{idx}Β§"
258
+ text = _RE_ABBREV.sub(_protect_abbrev_split, text)
259
+
260
  # Step 1: sentence split
261
  raw_chunks = _RE_SENTENCE_SPLIT.split(text)
262
  raw_chunks = [c.strip() for c in raw_chunks if c.strip()]
 
271
 
272
  # Step 3: merge short chunks
273
  if len(sized) <= 1:
274
+ merged = sized
275
+ else:
276
+ merged = []
277
+ carry = ""
278
+ for i, chunk in enumerate(sized):
279
+ if carry:
280
+ chunk = carry + " " + chunk
281
+ carry = ""
282
+ if len(chunk.split()) <= _MIN_MERGE_WORDS and i < len(sized) - 1:
283
+ carry = chunk
284
+ else:
285
+ merged.append(chunk)
286
  if carry:
287
+ if merged:
288
+ merged[-1] += " " + carry
289
+ else:
290
+ merged.append(carry)
291
+
292
+ # Step 4: restore abbreviation dots
293
+ if abbrev_placeholders:
294
+ for i, chunk in enumerate(merged):
295
+ for idx, original in abbrev_placeholders:
296
+ chunk = chunk.replace(f"Β§ABRS{idx}Β§", ".")
297
+ merged[i] = chunk
298
 
299
  return merged
300
 
 
308
  parts: List[str] = []
309
  remaining = text
310
  while len(remaining) > max_chars:
311
+ break_pos = -1
312
+ include_break_char = False
313
+
314
+ # Prefer punctuation/pauses first to keep prosody natural.
315
+ for marker in (",", ";", ":", "β€”", "-", "!", "?"):
316
+ pos = remaining.rfind(marker, 0, max_chars)
317
+ if pos > break_pos:
318
+ break_pos = pos
319
+ include_break_char = True
320
+
321
+ # Then prefer nearest space before limit.
322
+ space_pos = remaining.rfind(" ", 0, max_chars)
323
+ if space_pos > break_pos:
324
+ break_pos = space_pos
325
+ include_break_char = False
326
+
327
+ # If nothing before limit, look slightly ahead to avoid mid-word cuts.
328
+ if break_pos == -1:
329
+ forward_limit = min(len(remaining), max_chars + 24)
330
+ m = re.search(r"[\s,;:!?]", remaining[max_chars:forward_limit])
331
+ if m:
332
+ break_pos = max_chars + m.start()
333
+ include_break_char = remaining[break_pos] in ",;:!?"
334
+ else:
335
+ break_pos = max_chars
336
+ include_break_char = False
337
+
338
+ cut_at = break_pos + (1 if include_break_char else 0)
339
+ if cut_at <= 0:
340
+ cut_at = min(max_chars, len(remaining))
341
+
342
+ segment = remaining[:cut_at].strip()
343
  if segment:
344
  parts.append(segment)
345
+ remaining = remaining[cut_at:].lstrip()
346
  if remaining.strip():
347
  parts.append(remaining.strip())
348
  return parts