ShadowHunter222 committed on
Commit
d61edf1
Β·
verified Β·
1 Parent(s): 35e362a

Upload 6 files

Browse files
Files changed (6) hide show
  1. Dockerfile +36 -0
  2. app.py +915 -0
  3. chatterbox_wrapper.py +534 -0
  4. config.py +100 -0
  5. requirements.txt +24 -0
  6. text_processor.py +206 -0
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ══════════════════════════════════════════════════════════════
# Chatterbox Turbo TTS β€” CPU-Optimised Docker Image
# ══════════════════════════════════════════════════════════════
FROM python:3.11-slim

# Audio codec libraries for soundfile/librosa
RUN apt-get update && \
    apt-get install -y --no-install-recommends libsndfile1 ffmpeg && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install PyTorch CPU first (from dedicated index for smaller size)
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu

# Install remaining dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY config.py text_processor.py chatterbox_wrapper.py app.py ./

# Pre-download ONNX models + tokenizer at build time
RUN python -c "\
from chatterbox_wrapper import ChatterboxWrapper; \
ChatterboxWrapper(download_only=True); \
print('Models pre-downloaded successfully')"

# Prevent thread oversubscription in production.
# Single ENV instruction -> one image layer instead of three.
# PYTHONUNBUFFERED makes uvicorn/app logs reach `docker logs` immediately.
ENV OMP_NUM_THREADS=1 \
    MKL_NUM_THREADS=1 \
    OPENBLAS_NUM_THREADS=1 \
    PYTHONUNBUFFERED=1

EXPOSE 7860

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
app.py ADDED
@@ -0,0 +1,915 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatterbox Turbo TTS -- FastAPI Server
3
+ ======================================
4
+ Production-ready API with true real-time MP3 streaming,
5
+ in-memory voice cloning, and fully non-blocking inference.
6
+
7
+ Endpoints:
8
+ GET /health -> health check + optional warmup
9
+ GET /info -> model info, supported tags, parameters
10
+ POST /tts -> full audio response (WAV/MP3/FLAC)
11
+ POST /tts/stream -> chunked MP3 streaming (MediaSource-ready)
12
+ POST /tts/true-stream -> alias for /tts/stream (Kokoro compat)
13
+ POST /tts/stop/{stream_id}-> cancel a specific active stream
14
+ POST /tts/stop -> cancel ALL active streams
15
+ POST /v1/audio/speech -> OpenAI-compatible streaming
16
+ """
17
+ import asyncio
18
+ import io
19
+ import json
20
+ import logging
21
+ import queue as stdlib_queue
22
+ import threading
23
+ import time
24
+ import urllib.error
25
+ import urllib.parse
26
+ import urllib.request
27
+ import uuid
28
+ from concurrent.futures import ThreadPoolExecutor
29
+ from typing import Generator, Optional
30
+
31
+ import numpy as np
32
+ import soundfile as sf
33
+ from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
34
+ from fastapi.responses import Response, StreamingResponse
35
+ from contextlib import asynccontextmanager
36
+
37
+ from config import Config
38
+ from chatterbox_wrapper import ChatterboxWrapper, GenerationCancelled, VoiceProfile
39
+ import text_processor
40
+
41
+ # ── Logging ───────────────────────────────────────────────────────
42
+ logging.basicConfig(
43
+ level=logging.INFO,
44
+ format="%(asctime)s β”‚ %(levelname)-7s β”‚ %(name)s β”‚ %(message)s",
45
+ datefmt="%H:%M:%S",
46
+ )
47
+ logger = logging.getLogger(__name__)
48
+
49
+ # ── Thread pool for CPU-bound inference ───────────────────────────
50
+ tts_executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
51
+
52
+
53
+ # ── Lifespan ──────────────────────────────────────────────────────
54
+
55
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the TTS model before serving; stop the inference pool on shutdown."""
    try:
        app.state.wrapper = ChatterboxWrapper()
        logger.info("βœ… Model loaded, server ready")
    except Exception as e:
        # Fail hard: the server is useless without the model.
        logger.error(f"❌ Model loading failed: {e}")
        raise
    yield
    # Shutdown path: don't block on in-flight inference jobs.
    tts_executor.shutdown(wait=False)
66
+
67
+
68
# ASGI application instance; `lifespan` above handles model load/teardown.
app = FastAPI(
    title="Chatterbox Turbo TTS API",
    version="1.0.0",
    docs_url="/docs",
    lifespan=lifespan,
)
74
+
75
+
76
+ # ── CORS Middleware ───────────────────────────────────────────────
77
+
78
@app.middleware("http")
async def cors_middleware(request: Request, call_next):
    """Allow-list CORS: answer preflights, decorate allowed responses, 403 the rest."""
    origin = request.headers.get("origin")
    origin_allowed = origin in Config.ALLOWED_ORIGINS

    # Preflight from an allowed origin is answered without hitting a route.
    if request.method == "OPTIONS" and origin_allowed:
        preflight_headers = {
            "Access-Control-Allow-Origin": origin,
            "Access-Control-Allow-Methods": "*",
            "Access-Control-Allow-Headers": "*",
            "Access-Control-Allow-Credentials": "true",
        }
        return Response(status_code=200, headers=preflight_headers)

    # No Origin header (same-origin / curl) passes through untouched;
    # allowed cross-origin requests get the CORS headers appended.
    if not origin or origin_allowed:
        response = await call_next(request)
        if origin:
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Credentials"] = "true"
            response.headers["Access-Control-Allow-Methods"] = "*"
            response.headers["Access-Control-Allow-Headers"] = "*"
            # Let browsers read the stream id for /tts/stop cancellation.
            response.headers["Access-Control-Expose-Headers"] = "X-Stream-Id"
        return response

    logger.warning(f"🚫 Blocked origin: {origin}")
    return Response(status_code=403, content="Forbidden: Origin not allowed")
106
+
107
+
108
+ # ═══════════════════════════════════════════════════════════════════
109
+ # Helper: resolve voice from optional upload
110
+ # ═══════════════════════════════════════════════════════════════════
111
+
112
async def _resolve_voice(
    voice_ref: Optional[UploadFile],
    wrapper: ChatterboxWrapper,
) -> VoiceProfile:
    """Return a VoiceProfile from uploaded audio or the default voice."""
    if voice_ref is None or voice_ref.filename == "":
        return wrapper.default_voice

    audio_bytes = await voice_ref.read()
    upload_size = len(audio_bytes)
    if upload_size > Config.MAX_VOICE_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
    if upload_size == 0:
        raise HTTPException(status_code=400, detail="Empty voice file")

    # Voice encoding is CPU-bound; run it on the inference pool so the
    # event loop stays responsive.
    try:
        return await asyncio.get_running_loop().run_in_executor(
            tts_executor, wrapper.encode_voice_from_bytes, audio_bytes
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Voice encoding failed: {e}")
        raise HTTPException(
            status_code=400,
            detail=f"Could not process voice file: {str(e)}. "
            f"Supported formats: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM."
        )
140
+
141
+
142
+ # ═══════════════════════════════════════════════════════════════════
143
+ # Helper: encode numpy audio to bytes in given format
144
+ # ═══════════════════════════════════════════════════════════════════
145
+
146
def _encode_audio(audio: np.ndarray, fmt: str = "wav") -> tuple[bytes, str]:
    """Serialise audio to (bytes, media-type); unknown formats fall back to WAV."""
    # format-name -> (soundfile format, HTTP media type)
    codecs = {
        "mp3": ("mp3", "audio/mpeg"),
        "flac": ("flac", "audio/flac"),
    }
    sf_format, media = codecs.get(fmt.lower(), ("wav", "audio/wav"))
    buf = io.BytesIO()
    sf.write(buf, audio, Config.SAMPLE_RATE, format=sf_format)
    return buf.getvalue(), media
159
+
160
+
161
def _encode_mp3_chunk(audio: np.ndarray) -> bytes:
    """Encode one numpy chunk to MP3 bytes (same encoder path as current server)."""
    mp3_bytes, _media = _encode_audio(audio, fmt="mp3")
    return mp3_bytes
165
+
166
+
167
+ def _build_helper_endpoint(base_url: str, path: str) -> str:
168
+ return f"{base_url.rstrip('/')}{path}"
169
+
170
+
171
def _internal_headers() -> dict[str, str]:
    """Headers for primary->helper internal calls, with shared secret when configured."""
    headers = {
        "Content-Type": "application/json",
        "Accept": "audio/mpeg",
    }
    secret = Config.INTERNAL_SHARED_SECRET
    if secret:
        headers["X-Internal-Secret"] = secret
    return headers
176
+
177
+
178
def _helper_request_chunk(
    helper_base_url: str,
    payload: dict,
    timeout_sec: float,
) -> bytes:
    """POST one chunk-synthesis request to the helper; return raw audio bytes."""
    request = urllib.request.Request(
        url=_build_helper_endpoint(helper_base_url, "/internal/chunk/synthesize"),
        data=json.dumps(payload).encode("utf-8"),
        headers=_internal_headers(),
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout_sec) as resp:
        return resp.read()
193
+
194
+
195
def _helper_register_voice(
    helper_base_url: str,
    stream_id: str,
    audio_bytes: bytes,
    timeout_sec: float,
) -> str:
    """Register reference voice on helper once, return voice_key for chunk calls."""
    query = urllib.parse.urlencode({"stream_id": stream_id})
    endpoint = _build_helper_endpoint(helper_base_url, f"/internal/voice/register?{query}")

    # Raw audio body, so Content-Type differs from the JSON internal headers.
    headers = {"Content-Type": "application/octet-stream", "Accept": "application/json"}
    if Config.INTERNAL_SHARED_SECRET:
        headers["X-Internal-Secret"] = Config.INTERNAL_SHARED_SECRET

    request = urllib.request.Request(
        url=endpoint,
        data=audio_bytes,
        headers=headers,
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout_sec) as resp:
        reply = json.loads(resp.read().decode("utf-8"))

    voice_key = (reply.get("voice_key") or "").strip()
    if not voice_key:
        raise RuntimeError("helper voice registration returned no voice_key")
    return voice_key
220
+
221
+
222
def _helper_cancel_stream(helper_base_url: str, stream_id: str):
    """Best-effort cancellation signal to helper."""
    try:
        request = urllib.request.Request(
            url=_build_helper_endpoint(helper_base_url, f"/internal/chunk/cancel/{stream_id}"),
            data=b"",
            headers=_internal_headers(),
            method="POST",
        )
        with urllib.request.urlopen(request, timeout=3.0):
            pass
    except Exception:
        # Advisory only: the helper also expires work on its own timeouts.
        pass
236
+
237
+
238
+ # ═══════════════════════════════════════════════════════════════════
239
+ # Endpoints
240
+ # ═══════════════════════════════════════════════════════════════════
241
+
242
@app.get("/health")
async def health(warm_up: bool = False):
    """Liveness probe; pass ?warm_up=true to also run a warmup inference."""
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    model_ready = wrapper is not None
    status = {
        "status": "healthy" if model_ready else "loading",
        "model_loaded": model_ready,
        "model_dtype": Config.MODEL_DTYPE,
        "streaming_supported": True,
        "voice_cache_entries": wrapper._voice_cache.size if model_ready else 0,
    }
    if warm_up and model_ready:
        # Warmup is CPU-bound; push it onto the inference pool.
        try:
            await asyncio.get_running_loop().run_in_executor(tts_executor, wrapper.warmup)
            status["warm_up"] = "success"
        except Exception as e:
            status["warm_up"] = f"failed: {e}"
    return status
260
+
261
+
262
@app.get("/info")
async def info():
    """Static capability/configuration report for API clients."""
    parameter_info = {
        "max_new_tokens": {"default": Config.MAX_NEW_TOKENS, "range": "64–2048"},
        "repetition_penalty": {"default": Config.REPETITION_PENALTY, "range": "1.0–2.0"},
    }
    cloning_info = {
        "description": "Upload 3–30s reference WAV/MP3 as 'voice_ref' field",
        "max_upload_mb": Config.MAX_VOICE_UPLOAD_BYTES // (1024 * 1024),
    }
    parallel_info = {
        "enabled": Config.ENABLE_PARALLEL_MODE,
        "helper_configured": bool(Config.HELPER_BASE_URL),
        "helper_base_url": Config.HELPER_BASE_URL or None,
        "supports_voice_ref": True,
    }
    return {
        "model": Config.MODEL_ID,
        "dtype": Config.MODEL_DTYPE,
        "sample_rate": Config.SAMPLE_RATE,
        "paralinguistic_tags": list(Config.PARALINGUISTIC_TAGS),
        "tag_usage": "Insert tags directly in text, e.g. 'That is so funny! [laugh] Anyway…'",
        "parameters": parameter_info,
        "voice_cloning": cloning_info,
        "parallel_mode": parallel_info,
    }
285
+
286
+
287
+ # ── POST /tts ─────────────────────────────────────────────────────
288
+
289
@app.post("/tts", response_class=Response)
async def text_to_speech(
    text: str = Form(...),
    voice_ref: Optional[UploadFile] = File(None),
    output_format: str = Form("wav"),
    max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
    repetition_penalty: float = Form(Config.REPETITION_PENALTY),
):
    """Generate complete audio for the given text."""
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    if not text or not text.strip():
        raise HTTPException(400, "Text is required")

    voice = await _resolve_voice(voice_ref, wrapper)

    # Inference is CPU-bound; keep it off the event loop.
    try:
        audio = await asyncio.get_running_loop().run_in_executor(
            tts_executor,
            wrapper.generate_speech,
            text, voice, max_new_tokens, repetition_penalty,
        )
    except ValueError as e:
        raise HTTPException(400, str(e))
    except Exception as e:
        logger.error(f"TTS error: {e}")
        raise HTTPException(500, "Internal server error")

    data, media_type = _encode_audio(audio, output_format)
    disposition = f"attachment; filename=tts_output.{output_format}"
    return Response(
        content=data,
        media_type=media_type,
        headers={"Content-Disposition": disposition},
    )
326
+
327
+ # ═══════════════════════════════════════════════════════════════════
328
+ # Active Stream Registry (for cancellation)
329
+ # ═══════════════════════════════════════════════════════════════════
330
+
331
# stream_id β†’ cancel Event for streams served by THIS process (/tts/stop uses it).
_active_streams: dict[str, threading.Event] = {}
# Stream ids cancelled via /internal/chunk/cancel (helper-side bookkeeping).
_internal_cancelled_streams: set[str] = set()
# Protects _internal_cancelled_streams and _internal_stream_voice_keys.
_internal_cancel_lock = threading.Lock()
# stream_id β†’ voice keys registered for it via /internal/voice/register.
_internal_stream_voice_keys: dict[str, set[str]] = {}
335
+
336
+
337
+ # ═══════════════════════════════════════════════════════════════════
338
+ # Pipeline Streaming Generator
339
+ # ═══════════════════════════════════════════════════════════════════
340
+
341
def _pipeline_stream_generator(
    wrapper: ChatterboxWrapper,
    text: str,
    voice: VoiceProfile,
    max_new_tokens: int,
    repetition_penalty: float,
    stream_id: str,
) -> Generator[bytes, None, None]:
    """Two-stage producer-consumer pipeline for minimal inter-chunk gaps.

    Architecture:
      Producer thread (heavyweight, ~80% CPU):
        ONNX token generation β†’ audio decoding β†’ raw numpy arrays β†’ queue

      Consumer (this generator, lightweight, ~20% CPU):
        queue β†’ MP3 encode β†’ yield to HTTP response

    Why this helps:
      - ONNX model runs CONTINUOUSLY without waiting for MP3 encode or HTTP
      - MP3 encoding (libsndfile, C code) releases GIL β†’ true parallelism
      - ONNX inference (C++ code) also releases GIL β†’ both run simultaneously
      - Queue(maxsize=2) lets producer stay 1-2 chunks ahead

    Cancellation:
      - cancel_event checked between chunks + every 25 autoregressive steps
      - Client disconnect triggers GeneratorExit β†’ finally sets cancel
      - /tts/stop endpoint sets cancel externally
    """
    cancel_event = threading.Event()
    _active_streams[stream_id] = cancel_event

    # Raw audio buffer: producer puts numpy arrays, consumer takes them
    audio_buffer: stdlib_queue.Queue = stdlib_queue.Queue(maxsize=2)

    def _producer():
        """Heavyweight worker: runs ONNX model continuously."""
        try:
            for audio_chunk in wrapper.stream_speech(
                text, voice,
                max_new_tokens=max_new_tokens,
                repetition_penalty=repetition_penalty,
                is_cancelled=cancel_event.is_set,
            ):
                if cancel_event.is_set():
                    break
                # Bounded put: retry with short timeout so a cancel can
                # interrupt a producer blocked on a full queue.
                while not cancel_event.is_set():
                    try:
                        audio_buffer.put(audio_chunk, timeout=0.1)
                        break
                    except stdlib_queue.Full:
                        continue
        except GenerationCancelled:
            logger.info(f"[{stream_id}] Generation cancelled")
        except Exception as e:
            # Hand the exception object to the consumer so it can log and stop.
            while not cancel_event.is_set():
                try:
                    audio_buffer.put(e, timeout=0.1)
                    break
                except stdlib_queue.Full:
                    continue
        finally:
            # None is the end-of-stream sentinel for the consumer.
            while not cancel_event.is_set():
                try:
                    audio_buffer.put(None, timeout=0.1)
                    break
                except stdlib_queue.Full:
                    continue

    producer = threading.Thread(target=_producer, daemon=True)
    producer.start()

    try:
        # Consumer: lightweight MP3 encoding + yield
        # NOTE(review): audio_buffer.get() has no timeout; if cancel is set
        # externally while the consumer is blocked here, the producer's finally
        # loop skips the sentinel and this get() could wait indefinitely —
        # consider get(timeout=...) with a cancel check. Confirm before changing.
        while True:
            item = audio_buffer.get()
            if item is None:
                break
            if isinstance(item, Exception):
                logger.error(f"[{stream_id}] Stream error: {item}")
                break
            if cancel_event.is_set():
                break

            # MP3 encode (C code, releases GIL, runs parallel with next ONNX step)
            buf = io.BytesIO()
            sf.write(buf, item, Config.SAMPLE_RATE, format="mp3")
            yield buf.getvalue()
    finally:
        # Cleanup: signal producer to stop + deregister
        cancel_event.set()
        _active_streams.pop(stream_id, None)
432
+
433
+
434
def _parallel_odd_even_stream_generator(
    wrapper: ChatterboxWrapper,
    text: str,
    local_voice: VoiceProfile,
    helper_voice_bytes: Optional[bytes],
    max_new_tokens: int,
    repetition_penalty: float,
    stream_id: str,
    helper_base_url: str,
) -> Generator[bytes, None, None]:
    """Additive odd/even split streamer (primary handles odd, helper handles even).

    Even-indexed chunks are delegated to a remote helper instance over HTTP;
    any helper failure downgrades to local synthesis for the remaining even
    chunks. Chunks are yielded strictly in order via a condition variable.
    """
    cancel_event = threading.Event()
    _active_streams[stream_id] = cancel_event

    clean_text = text_processor.sanitize(text.strip()[: Config.MAX_TEXT_LENGTH])
    chunks = text_processor.split_for_streaming(clean_text)
    total_chunks = len(chunks)
    if total_chunks == 0:
        _active_streams.pop(stream_id, None)
        return

    # Shared state between the two workers and this (consumer) generator;
    # all of it is guarded by `cond`.
    lock = threading.Lock()
    cond = threading.Condition(lock)
    ready: dict[int, bytes] = {}          # finished chunk_index β†’ MP3 bytes
    first_error: Optional[Exception] = None
    workers_done = 0                      # how many of the 2 workers exited

    def _publish(idx: int, data: bytes):
        # Hand a finished chunk to the consumer.
        with cond:
            ready[idx] = data
            cond.notify_all()

    def _set_error(err: Exception):
        # Record only the first worker error; later ones are dropped.
        nonlocal first_error
        with cond:
            if first_error is None:
                first_error = err
            cond.notify_all()

    def _worker_done():
        nonlocal workers_done
        with cond:
            workers_done += 1
            cond.notify_all()

    def _synth_local(chunk_text: str) -> bytes:
        # Synthesize one chunk on this process and MP3-encode it.
        audio = wrapper.generate_speech(
            chunk_text,
            local_voice,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
        )
        return _encode_mp3_chunk(audio)

    def _odd_worker():
        # Indices 0, 2, 4, … are always synthesized locally.
        try:
            for idx in range(0, total_chunks, 2):
                if cancel_event.is_set():
                    break
                data = _synth_local(chunks[idx])
                _publish(idx, data)
        except Exception as e:
            _set_error(e)
        finally:
            _worker_done()

    def _even_worker():
        # Indices 1, 3, 5, … go to the helper while it stays healthy.
        helper_available = True
        helper_voice_key: Optional[str] = None
        try:
            if helper_voice_bytes:
                # Register the reference voice on the helper once; retries
                # are controlled by HELPER_RETRY_ONCE.
                attempts = 2 if Config.HELPER_RETRY_ONCE else 1
                last_err: Optional[Exception] = None
                for _ in range(attempts):
                    try:
                        helper_voice_key = _helper_register_voice(
                            helper_base_url=helper_base_url,
                            stream_id=stream_id,
                            audio_bytes=helper_voice_bytes,
                            timeout_sec=max(1.0, Config.HELPER_TIMEOUT_SEC),
                        )
                        last_err = None
                        break
                    except Exception as reg_err:
                        last_err = reg_err
                        continue
                if last_err is not None:
                    helper_available = False
                    logger.warning(
                        f"[{stream_id}] Helper voice registration failed; "
                        "falling back to local synthesis for even chunks"
                    )

            for idx in range(1, total_chunks, 2):
                if cancel_event.is_set():
                    break

                if helper_available:
                    payload = {
                        "stream_id": stream_id,
                        "chunk_index": idx,
                        "text": chunks[idx],
                        "max_new_tokens": max_new_tokens,
                        "repetition_penalty": repetition_penalty,
                        "output_format": "mp3",
                    }
                    if helper_voice_key:
                        payload["voice_key"] = helper_voice_key

                    attempts = 2 if Config.HELPER_RETRY_ONCE else 1
                    last_err: Optional[Exception] = None
                    for _ in range(attempts):
                        try:
                            helper_data = _helper_request_chunk(
                                helper_base_url=helper_base_url,
                                payload=payload,
                                timeout_sec=max(1.0, Config.HELPER_TIMEOUT_SEC),
                            )
                            _publish(idx, helper_data)
                            last_err = None
                            break
                        except Exception as helper_err:
                            last_err = helper_err
                            continue

                    if last_err is None:
                        continue

                    # Helper failed for this chunk: disable it permanently for
                    # this stream and fall through to local synthesis below.
                    helper_available = False
                    logger.warning(
                        f"[{stream_id}] Helper failed at chunk {idx}; "
                        "falling back to local synthesis for remaining even chunks"
                    )

                # Local fallback for even chunks
                data = _synth_local(chunks[idx])
                _publish(idx, data)
        except Exception as e:
            _set_error(e)
        finally:
            _worker_done()

    odd_thread = threading.Thread(target=_odd_worker, daemon=True)
    even_thread = threading.Thread(target=_even_worker, daemon=True)
    odd_thread.start()
    even_thread.start()

    # Consumer: emit chunks in index order, waiting on the condition variable.
    next_idx = 0
    try:
        while next_idx < total_chunks:
            with cond:
                while (
                    next_idx not in ready
                    and first_error is None
                    and not cancel_event.is_set()
                    and workers_done < 2
                ):
                    cond.wait(timeout=0.1)

                if cancel_event.is_set():
                    break

                if next_idx in ready:
                    data = ready.pop(next_idx)
                elif first_error is not None:
                    logger.error(f"[{stream_id}] Parallel stream error: {first_error}")
                    break
                elif workers_done >= 2:
                    # Both workers exited without producing this index (e.g.
                    # cancelled mid-chunk): stop rather than wait forever.
                    logger.error(
                        f"[{stream_id}] Parallel stream ended with missing chunk index {next_idx}"
                    )
                    break
                else:
                    continue

            # Yield outside the lock so slow HTTP writes don't block workers.
            yield data
            next_idx += 1
    finally:
        cancel_event.set()
        _helper_cancel_stream(helper_base_url, stream_id)
        odd_thread.join(timeout=1.0)
        even_thread.join(timeout=1.0)
        _active_streams.pop(stream_id, None)
617
+
618
+
619
+ # ── POST /tts/stream & /tts/true-stream ──────────────────────────
620
+
621
@app.post("/tts/stream")
@app.post("/tts/true-stream")
async def stream_text_to_speech(
    text: str = Form(...),
    voice_ref: Optional[UploadFile] = File(None),
    max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
    repetition_penalty: float = Form(Config.REPETITION_PENALTY),
):
    """True real-time streaming: yields MP3 chunks as each sentence finishes.

    Response includes X-Stream-Id header for cancellation via /tts/stop.
    Compatible with frontend's MediaSource + ReadableStream pattern.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    if not text or not text.strip():
        raise HTTPException(400, "Text is required")

    voice = await _resolve_voice(voice_ref, wrapper)
    stream_id = uuid.uuid4().hex[:12]

    chunk_generator = _pipeline_stream_generator(
        wrapper, text, voice, max_new_tokens, repetition_penalty, stream_id,
    )
    response_headers = {
        "Content-Disposition": "attachment; filename=tts_stream.mp3",
        "Transfer-Encoding": "chunked",
        "X-Stream-Id": stream_id,
        "X-Streaming-Type": "true-realtime",
        "Cache-Control": "no-cache",
    }
    return StreamingResponse(
        chunk_generator,
        media_type="audio/mpeg",
        headers=response_headers,
    )
657
+
658
+
659
@app.post("/tts/parallel-stream")
async def parallel_stream_text_to_speech(
    text: str = Form(...),
    voice_ref: Optional[UploadFile] = File(None),
    max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
    repetition_penalty: float = Form(Config.REPETITION_PENALTY),
    helper_url: Optional[str] = Form(None),
):
    """Additive odd/even split stream mode (primary + helper).

    Splits the text into sentence chunks: odd indices are synthesized locally,
    even indices by the helper instance. Optional `voice_ref` is encoded
    locally AND forwarded (as raw bytes) to the helper. `helper_url` overrides
    the configured CB_HELPER_BASE_URL for this request.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    if not Config.ENABLE_PARALLEL_MODE:
        raise HTTPException(503, "Parallel mode is disabled")
    if not text or not text.strip():
        raise HTTPException(400, "Text is required")

    local_voice: VoiceProfile = wrapper.default_voice
    helper_voice_bytes: Optional[bytes] = None
    if voice_ref is not None and voice_ref.filename:
        helper_voice_bytes = await voice_ref.read()
        if len(helper_voice_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
            raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
        if len(helper_voice_bytes) == 0:
            raise HTTPException(status_code=400, detail="Empty voice file")
        loop = asyncio.get_running_loop()
        try:
            # CPU-bound voice encoding runs on the inference pool.
            local_voice = await loop.run_in_executor(
                tts_executor, wrapper.encode_voice_from_bytes, helper_voice_bytes
            )
        except Exception as e:
            logger.error(f"Parallel voice encoding failed: {e}")
            raise HTTPException(400, "Could not process voice file for parallel mode")

    # FIX: guard with `or ""` — if neither helper_url nor CB_HELPER_BASE_URL
    # is set, Config.HELPER_BASE_URL may be None and .strip() would raise
    # AttributeError (a 500) instead of the intended 400 below.
    resolved_helper = (helper_url or Config.HELPER_BASE_URL or "").strip()
    if not resolved_helper:
        raise HTTPException(
            400,
            "Helper URL not configured. Set CB_HELPER_BASE_URL or pass helper_url.",
        )

    stream_id = uuid.uuid4().hex[:12]
    return StreamingResponse(
        _parallel_odd_even_stream_generator(
            wrapper=wrapper,
            text=text,
            local_voice=local_voice,
            helper_voice_bytes=helper_voice_bytes,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            stream_id=stream_id,
            helper_base_url=resolved_helper,
        ),
        media_type="audio/mpeg",
        headers={
            "Content-Disposition": "attachment; filename=tts_parallel_stream.mp3",
            "Transfer-Encoding": "chunked",
            "X-Stream-Id": stream_id,
            "X-Streaming-Type": "parallel-odd-even",
            "Cache-Control": "no-cache",
        },
    )
721
+
722
+
723
+ # ── JSON body variant (Kokoro/OpenAI compatibility) ───────────────
724
+
725
+ from pydantic import BaseModel, Field
726
+
727
+
728
class InternalChunkRequest(BaseModel):
    """Payload for /internal/chunk/synthesize (primary β†’ helper chunk routing)."""
    # Stream this chunk belongs to (primary-generated 12-hex id).
    stream_id: str = Field(..., min_length=1, max_length=64)
    # Position of the chunk in the full text; echoed back in X-Chunk-Index.
    chunk_index: int = Field(..., ge=0)
    # Sentence-sized text to synthesize.
    text: str = Field(..., min_length=1, max_length=10000)
    max_new_tokens: int = Field(default=Config.MAX_NEW_TOKENS, ge=64, le=2048)
    repetition_penalty: float = Field(default=Config.REPETITION_PENALTY, ge=1.0, le=2.0)
    # "mp3" | "wav" | "flac"; anything else is coerced to mp3 by the endpoint.
    output_format: str = Field(default="mp3")
    # Key returned by /internal/voice/register; omit to use the default voice.
    voice_key: Optional[str] = Field(default=None, min_length=1, max_length=64)
736
+
737
+
738
class TTSJsonRequest(BaseModel):
    """JSON body for /v1/audio/speech (OpenAI-compatible; default voice only)."""
    text: str = Field(..., min_length=1, max_length=50000)
    # Accepted for API compatibility; only "default" has an effect here.
    voice: str = Field(default="default")
    speed: float = Field(default=1.0, ge=0.5, le=2.0)  # reserved for future use
    max_new_tokens: int = Field(default=Config.MAX_NEW_TOKENS, ge=64, le=2048)
    repetition_penalty: float = Field(default=Config.REPETITION_PENALTY, ge=1.0, le=2.0)
744
+
745
+
746
@app.post("/internal/voice/register")
async def internal_voice_register(http_request: Request):
    """Register voice once for a stream; returns reusable voice_key."""
    secret = Config.INTERNAL_SHARED_SECRET
    if secret and http_request.headers.get("X-Internal-Secret", "") != secret:
        raise HTTPException(403, "Forbidden")

    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")

    audio_bytes = await http_request.body()
    if len(audio_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty voice file")

    # Encoding is CPU-bound; run on the inference pool.
    try:
        voice = await asyncio.get_running_loop().run_in_executor(
            tts_executor, wrapper.encode_voice_from_bytes, audio_bytes
        )
    except Exception as e:
        logger.error(f"[internal] voice register failed: {e}")
        raise HTTPException(400, "Voice registration failed")

    voice_key = (voice.audio_hash or "").strip()
    if not voice_key:
        raise HTTPException(500, "Voice key unavailable")

    # Track which keys belong to the stream so cancellation can drop them.
    stream_id = (http_request.query_params.get("stream_id") or "").strip()
    if stream_id:
        with _internal_cancel_lock:
            _internal_stream_voice_keys.setdefault(stream_id, set()).add(voice_key)

    return {"status": "registered", "voice_key": voice_key}
784
+
785
+
786
@app.post("/internal/chunk/synthesize")
async def internal_chunk_synthesize(
    request: InternalChunkRequest,
    http_request: Request,
):
    """Internal endpoint used by primary/helper parallel routing."""
    secret = Config.INTERNAL_SHARED_SECRET
    if secret and http_request.headers.get("X-Internal-Secret", "") != secret:
        raise HTTPException(403, "Forbidden")

    with _internal_cancel_lock:
        if request.stream_id in _internal_cancelled_streams:
            raise HTTPException(409, "Stream already cancelled")

    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")

    # Default voice unless the caller registered one and passed its key.
    voice_profile = wrapper.default_voice
    if request.voice_key:
        cached_voice = wrapper._voice_cache.get(request.voice_key)
        if cached_voice is None:
            raise HTTPException(409, "Voice key expired or not found")
        voice_profile = cached_voice

    try:
        audio = await asyncio.get_running_loop().run_in_executor(
            tts_executor,
            wrapper.generate_speech,
            request.text,
            voice_profile,
            request.max_new_tokens,
            request.repetition_penalty,
        )
    except Exception as e:
        logger.error(f"[internal] chunk {request.chunk_index} failed: {e}")
        raise HTTPException(500, "Chunk synthesis failed")

    # Unknown formats are coerced to mp3 rather than rejected.
    fmt = (request.output_format or "mp3").lower()
    if fmt not in {"mp3", "wav", "flac"}:
        fmt = "mp3"
    data, media_type = _encode_audio(audio, fmt=fmt)
    return Response(
        content=data,
        media_type=media_type,
        headers={
            "X-Stream-Id": request.stream_id,
            "X-Chunk-Index": str(request.chunk_index),
        },
    )
838
+
839
+
840
@app.post("/internal/chunk/cancel/{stream_id}")
async def internal_chunk_cancel(stream_id: str, http_request: Request):
    """Mark a parallel stream cancelled and drop its registered voice keys."""
    secret = Config.INTERNAL_SHARED_SECRET
    if secret and http_request.headers.get("X-Internal-Secret", "") != secret:
        raise HTTPException(403, "Forbidden")

    with _internal_cancel_lock:
        _internal_cancelled_streams.add(stream_id)
        _internal_stream_voice_keys.pop(stream_id, None)
    return {"status": "cancelled", "stream_id": stream_id}
851
+
852
+
853
@app.post("/v1/audio/speech")
async def openai_compatible_tts(request: TTSJsonRequest):
    """OpenAI-compatible streaming endpoint (JSON body, no file upload).

    Uses the default voice. For voice cloning, use /tts/stream with FormData.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")

    stream_id = uuid.uuid4().hex[:12]
    chunk_generator = _pipeline_stream_generator(
        wrapper, request.text, wrapper.default_voice,
        request.max_new_tokens, request.repetition_penalty, stream_id,
    )
    return StreamingResponse(
        chunk_generator,
        media_type="audio/mpeg",
        headers={
            "Transfer-Encoding": "chunked",
            "X-Stream-Id": stream_id,
            "Cache-Control": "no-cache",
        },
    )
877
+
878
+
879
+ # ═══════════════════════════════════════════════════════════════════
880
+ # Stop / Cancel Endpoint
881
+ # ═══════════════════════════════════════════════════════════════════
882
+
883
@app.post("/tts/stop/{stream_id}")
async def stop_stream(stream_id: str):
    """Stop an active TTS stream by its ID (from X-Stream-Id header).

    Cancels the ONNX generation loop mid-token, freeing CPU immediately.
    """
    event = _active_streams.get(stream_id)
    if event is None:
        return {"status": "not_found", "stream_id": stream_id}
    event.set()
    logger.info(f"Stream {stream_id} cancelled by client")
    return {"status": "stopped", "stream_id": stream_id}
895
+
896
+
897
@app.post("/tts/stop")
async def stop_all_streams():
    """Emergency stop: cancel ALL active TTS streams."""
    # Snapshot first: generators also mutate _active_streams as they wind down.
    snapshot = list(_active_streams.items())
    for _sid, cancel_event in snapshot:
        cancel_event.set()
    _active_streams.clear()
    count = len(snapshot)
    logger.info(f"Stopped all streams ({count} active)")
    return {"status": "stopped_all", "count": count}
906
+
907
+
908
+ # ═══════════════════════════════════════════════════════════════════
909
+ # Entrypoint
910
+ # ═══════════════════════════════════════════════════════════════════
911
+
912
# Local development entrypoint; production uses the Dockerfile's uvicorn CMD.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=Config.HOST, port=Config.PORT)
chatterbox_wrapper.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatterbox Turbo TTS β€” ONNX Inference Wrapper
3
+ ═══════════════════════════════════════════════
4
+ Orchestrates the 4-component ONNX pipeline:
5
+ embed_tokens β†’ speech_encoder β†’ language_model β†’ conditional_decoder
6
+
7
+ Optimised for lowest-latency CPU inference on 2 vCPU:
8
+ β€’ Sequential execution, thread count = physical cores, no spinning
9
+ β€’ Token list pre-allocation (avoids O(nΒ²) np.concatenate in loop)
10
+ β€’ In-memory voice caching (no disk writes for uploads)
11
+ β€’ Robust audio loading: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM
12
+ β€’ Sentence-level streaming for real-time playback
13
+ """
14
+
15
+ # ── Suppress harmless transformers warnings BEFORE import ─────────
16
+ import os
17
+ import warnings
18
+
19
+ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
20
+ warnings.filterwarnings("ignore", message=".*model of type.*chatterbox.*")
21
+
22
+ import hashlib
23
+ import io
24
+ import logging
25
+ import subprocess
26
+ import tempfile
27
+ import time
28
+ from collections import OrderedDict
29
+ from dataclasses import dataclass
30
+ from typing import Callable, Generator, Optional
31
+
32
+ import librosa
33
+ import numpy as np
34
+ import onnxruntime as ort
35
+ import soundfile as soundfile_lib
36
+ from huggingface_hub import hf_hub_download
37
+ from transformers import AutoTokenizer
38
+
39
+ from config import Config
40
+ import text_processor
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+ # ── Supported audio MIME types for voice upload ───────────────────
45
+ _SUPPORTED_AUDIO_EXTENSIONS = {
46
+ ".wav", ".mp3", ".mpeg", ".mpga", ".m4a", ".mp4",
47
+ ".ogg", ".oga", ".opus", ".flac", ".webm", ".aac", ".wma",
48
+ }
49
+
50
+
51
+ # ═══════════════════════════════════════════════════════════════════
52
+ # Data Structures
53
+ # ═══════════════════════════════════════════════════════════════════
54
+
55
@dataclass
class VoiceProfile:
    """Cached speaker embedding extracted from reference audio.

    All four arrays are the raw outputs of one speech_encoder run and are
    fed back into later pipeline stages (see _generate_tokens and
    _decode_tokens).
    """
    cond_emb: np.ndarray            # speaker conditioning embeddings, prepended to text embeds at LM step 0
    prompt_token: np.ndarray        # speech-token prompt, prepended before waveform decoding
    speaker_embeddings: np.ndarray  # passed to the conditional decoder as "speaker_embeddings"
    speaker_features: np.ndarray    # passed to the conditional decoder as "speaker_features"
    audio_hash: str = ""            # md5 of the source audio bytes; "__default__" for the bundled voice
63
+
64
+
65
class GenerationCancelled(Exception):
    """Signals that the client aborted an in-flight inference."""
68
+
69
+
70
+ # ═══════════════════════════════════════════════════════════════════
71
+ # LRU Voice Cache
72
+ # ═══════════════════════════════════════════════════════════════════
73
+
74
class _VoiceCache:
    """LRU cache of VoiceProfile objects with per-entry TTL expiry.

    An entry older than `ttl_seconds` (default: 1 hour) is dropped on the
    next access. Re-uploading the same voice file inside the TTL window
    returns the cached profile instantly — no re-encoding needed.
    """

    def __init__(self, maxsize: int, ttl_seconds: int = 3600):
        # key -> (profile, insertion timestamp); OrderedDict gives LRU order
        self._cache: OrderedDict[str, tuple[VoiceProfile, float]] = OrderedDict()
        self._maxsize = maxsize
        self._ttl = ttl_seconds

    def _evict_expired(self):
        """Drop every entry whose age exceeds the TTL."""
        cutoff = time.time() - self._ttl
        stale = [k for k, (_, stamp) in self._cache.items() if stamp < cutoff]
        for k in stale:
            del self._cache[k]
            logger.debug(f"Voice cache expired: {k[:8]}…")

    def get(self, key: str) -> Optional[VoiceProfile]:
        """Return the cached profile for *key* (refreshing its LRU rank), or None."""
        self._evict_expired()
        entry = self._cache.get(key)
        if entry is None:
            return None
        profile, stamp = entry
        remaining = self._ttl - (time.time() - stamp)
        self._cache.move_to_end(key)
        logger.info(f"Voice cache HIT: {key[:8]}… (expires in {remaining:.0f}s)")
        return profile

    def put(self, key: str, profile: VoiceProfile):
        """Store *profile* under *key*, evicting the LRU entry when full."""
        self._evict_expired()
        if key in self._cache:
            self._cache.move_to_end(key)
        elif len(self._cache) >= self._maxsize:
            evicted_key, _ = self._cache.popitem(last=False)
            logger.debug(f"Voice cache evicted (LRU): {evicted_key[:8]}…")
        self._cache[key] = (profile, time.time())
        logger.info(f"Voice cache STORED: {key[:8]}… (TTL: {self._ttl}s, size: {len(self._cache)})")

    @property
    def size(self) -> int:
        """Current number of cached entries."""
        return len(self._cache)
119
+
120
+
121
+ # ═══════════════════════════════════════════════════════════════════
122
+ # Audio Loading (robust multi-format support)
123
+ # ════════════════���══════════════════════════════════════════════════
124
+
125
def _load_audio_bytes(audio_bytes: bytes, sr: int = 24000) -> np.ndarray:
    """Load audio from raw bytes, supporting WAV/MP3/MPEG/M4A/OGG/FLAC/WebM.

    Strategy: try soundfile (fast, native) β†’ librosa (ffmpeg backend) β†’ ffmpeg CLI.
    Each decoder is tried in order; the first success wins.

    Args:
        audio_bytes: Raw encoded audio file contents.
        sr: Target sample rate; decoded audio is resampled to this rate.

    Returns:
        Mono float32 waveform at `sr` Hz.

    Raises:
        ValueError: If none of the three decoders can read the bytes.
    """
    buf = io.BytesIO(audio_bytes)

    # 1) Try soundfile (handles WAV, FLAC, OGG natively β€” fastest)
    try:
        audio, file_sr = soundfile_lib.read(buf)
        if audio.ndim > 1:
            audio = audio.mean(axis=1)  # stereo β†’ mono
        if file_sr != sr:
            audio = librosa.resample(audio.astype(np.float32), orig_sr=file_sr, target_sr=sr)
        return audio.astype(np.float32)
    except Exception:
        buf.seek(0)  # rewind so the next decoder sees the full stream

    # 2) Try librosa (handles MP3 via audioread + ffmpeg backend)
    try:
        audio, _ = librosa.load(buf, sr=sr, mono=True)
        return audio.astype(np.float32)
    except Exception:
        buf.seek(0)

    # 3) Fallback: use ffmpeg CLI to convert to WAV in memory.
    # "-ac 1 -ar <sr>" makes ffmpeg itself do the mono downmix + resample,
    # so the decoded output needs no further processing here.
    try:
        proc = subprocess.run(
            ["ffmpeg", "-i", "pipe:0", "-f", "wav", "-ac", "1", "-ar", str(sr), "pipe:1"],
            input=audio_bytes, capture_output=True, timeout=30,
        )
        # 44 bytes is a bare RIFF/WAV header; > 44 means actual samples exist.
        if proc.returncode == 0 and len(proc.stdout) > 44:
            wav_buf = io.BytesIO(proc.stdout)
            audio, _ = soundfile_lib.read(wav_buf)
            return audio.astype(np.float32)
    except Exception:
        pass

    raise ValueError(
        "Could not decode audio file. Supported formats: "
        "WAV, MP3, MPEG, M4A, OGG, FLAC, WebM, AAC. "
        "Please upload a valid audio file."
    )
168
+
169
+
170
+ # ═══════════════════════════════════════════════════════════════════
171
+ # Main Wrapper
172
+ # ═══════════════════════════════════════════════════════════════════
173
+
174
class ChatterboxWrapper:
    """End-to-end ONNX TTS pipeline:
    embed_tokens β†’ speech_encoder β†’ language_model β†’ conditional_decoder.

    Sessions are created once at startup and reused for every request;
    with Config.MAX_WORKERS = 1 a single inference owns all CPU cores.
    """

    def __init__(self, download_only: bool = False):
        """Download models and (unless download_only) build the runtime.

        Args:
            download_only: Fetch model files and return early β€” used at
                Docker build time to bake the HF cache into the image.
        """
        self.cfg = Config
        os.makedirs(self.cfg.MODELS_DIR, exist_ok=True)

        logger.info(f"Downloading ONNX models (dtype={self.cfg.MODEL_DTYPE}) …")
        self._model_paths = self._download_models()

        if download_only:
            return

        logger.info(
            f"Creating ONNX Runtime sessions "
            f"(intra_op_threads={self.cfg.CPU_THREADS}, workers={self.cfg.MAX_WORKERS}) …"
        )
        opts = self._make_session_options()
        providers = ["CPUExecutionProvider"]

        self.embed_session = ort.InferenceSession(self._model_paths["embed_tokens"], sess_options=opts, providers=providers)
        self.encoder_session = ort.InferenceSession(self._model_paths["speech_encoder"], sess_options=opts, providers=providers)
        self.lm_session = ort.InferenceSession(self._model_paths["language_model"], sess_options=opts, providers=providers)
        self.decoder_session = ort.InferenceSession(self._model_paths["conditional_decoder"], sess_options=opts, providers=providers)

        logger.info("Loading tokenizer …")
        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.MODEL_ID)

        self._voice_cache = _VoiceCache(
            maxsize=self.cfg.VOICE_CACHE_SIZE,
            ttl_seconds=self.cfg.VOICE_CACHE_TTL_SEC,
        )

        logger.info("Encoding default reference voice …")
        self.default_voice = self._load_default_voice()

        logger.info("βœ… ChatterboxWrapper ready")

    # ─── Model download ──────────────────────────────────────────

    def _download_models(self) -> dict:
        """Download all 4 ONNX components + weight files from HuggingFace."""
        components = ("conditional_decoder", "speech_encoder", "embed_tokens", "language_model")
        paths = {}
        for name in components:
            paths[name] = self._download_component(name, self.cfg.MODEL_DTYPE)
        return paths

    def _download_component(self, name: str, dtype: str) -> str:
        """Download one ONNX graph (and its external-weights companion, if any).

        Returns the local path of the graph file.
        """
        if dtype == "fp32":
            filename = f"{name}.onnx"
        elif dtype == "q8":
            filename = f"{name}_quantized.onnx"
        else:
            filename = f"{name}_{dtype}.onnx"

        graph = hf_hub_download(
            self.cfg.MODEL_ID, subfolder="onnx", filename=filename,
            cache_dir=self.cfg.MODELS_DIR,
        )
        # Download companion external-weights file (ONNX external data is
        # stored next to the graph as "<graph filename>_data").
        # BUG FIX: the companion filename must be derived from the graph
        # filename; previously a literal placeholder string was requested,
        # so external weights were never fetched.
        try:
            hf_hub_download(
                self.cfg.MODEL_ID, subfolder="onnx", filename=f"{filename}_data",
                cache_dir=self.cfg.MODELS_DIR,
            )
        except Exception:
            pass  # Some quantized variants embed weights in-graph
        return graph

    # ─── Session configuration (optimised for 2 vCPU) ─────────────

    def _make_session_options(self) -> ort.SessionOptions:
        """Build ORT session options tuned for a small fixed-core CPU box."""
        opts = ort.SessionOptions()
        # Sequential execution: no parallel graph scheduling overhead
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        # Match physical cores exactly (2 for HF Space free tier)
        opts.intra_op_num_threads = self.cfg.CPU_THREADS
        opts.inter_op_num_threads = 1
        # Full graph optimisations (constant folding, fusion, etc.)
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Disable thread spinning β€” wastes CPU cycles on busy-wait
        opts.add_session_config_entry("session.intra_op.allow_spinning", "0")
        opts.add_session_config_entry("session.inter_op.allow_spinning", "0")
        # Enable memory optimisations
        opts.enable_cpu_mem_arena = True
        opts.enable_mem_pattern = True
        opts.enable_mem_reuse = True
        return opts

    # ─── Default voice ────────────────────────────────────────────

    def _load_default_voice(self) -> VoiceProfile:
        """Fetch and encode the bundled reference voice."""
        path = hf_hub_download(
            self.cfg.DEFAULT_VOICE_REPO,
            filename=self.cfg.DEFAULT_VOICE_FILE,
            cache_dir=self.cfg.MODELS_DIR,
        )
        audio, _ = librosa.load(path, sr=self.cfg.SAMPLE_RATE)
        return self._encode_audio_array(audio, audio_hash="__default__")

    # ─── Voice encoding ──────────────────────────────────────────

    def encode_voice_from_bytes(self, audio_bytes: bytes) -> VoiceProfile:
        """Encode reference audio from raw bytes (in-memory, no disk write).

        Accepts: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM, AAC, WMA, Opus.

        Raises:
            ValueError: If the audio cannot be decoded or is shorter than
                MIN_REF_DURATION_SEC. Audio longer than MAX_REF_DURATION_SEC
                is silently truncated.
        """
        audio_hash = hashlib.md5(audio_bytes).hexdigest()
        cached = self._voice_cache.get(audio_hash)
        if cached is not None:
            logger.info(f"Voice cache hit: {audio_hash[:8]}…")
            return cached

        # Robust multi-format audio loading
        audio = _load_audio_bytes(audio_bytes, sr=self.cfg.SAMPLE_RATE)

        # Validate duration
        duration = len(audio) / self.cfg.SAMPLE_RATE
        if duration < self.cfg.MIN_REF_DURATION_SEC:
            raise ValueError(
                f"Reference audio too short ({duration:.1f}s). "
                f"Minimum: {self.cfg.MIN_REF_DURATION_SEC}s"
            )
        if duration > self.cfg.MAX_REF_DURATION_SEC:
            audio = audio[: int(self.cfg.MAX_REF_DURATION_SEC * self.cfg.SAMPLE_RATE)]

        profile = self._encode_audio_array(audio, audio_hash=audio_hash)
        self._voice_cache.put(audio_hash, profile)
        return profile

    def _encode_audio_array(self, audio: np.ndarray, audio_hash: str = "") -> VoiceProfile:
        """Run speech_encoder on a float32 mono audio array."""
        audio_input = audio[np.newaxis, :].astype(np.float32)
        cond_emb, prompt_token, speaker_emb, speaker_feat = self.encoder_session.run(
            None, {"audio_values": audio_input}
        )
        return VoiceProfile(
            cond_emb=cond_emb,
            prompt_token=prompt_token,
            speaker_embeddings=speaker_emb,
            speaker_features=speaker_feat,
            audio_hash=audio_hash,
        )

    # ─── Full generation (non-streaming) ──────────────────────────

    def generate_speech(
        self,
        text: str,
        voice: Optional[VoiceProfile] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
    ) -> np.ndarray:
        """Generate complete audio for the given text.

        Returns:
            Mono float32 waveform at Config.SAMPLE_RATE.

        Raises:
            ValueError: If the text is empty after sanitization.
        """
        voice = voice or self.default_voice
        text = text_processor.sanitize(text.strip()[: self.cfg.MAX_TEXT_LENGTH])
        if not text:
            raise ValueError("Text is empty after sanitization")

        tokens = self._generate_tokens(
            text, voice,
            max_new_tokens or self.cfg.MAX_NEW_TOKENS,
            repetition_penalty or self.cfg.REPETITION_PENALTY,
        )
        return self._decode_tokens(tokens, voice)

    # ─── Streaming generation ─────────────────────────────────────

    def stream_speech(
        self,
        text: str,
        voice: Optional[VoiceProfile] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        is_cancelled: Optional[Callable[[], bool]] = None,
    ) -> Generator[np.ndarray, None, None]:
        """Yield audio chunks sentence-by-sentence for real-time streaming.

        Each sentence is independently processed through the full pipeline
        so the first chunk arrives as fast as possible (low TTFB).

        Args:
            is_cancelled: Optional callable that returns True to abort generation.
                Checked between chunks and every 25 autoregressive steps.
        """
        voice = voice or self.default_voice
        text = text_processor.sanitize(text.strip()[: self.cfg.MAX_TEXT_LENGTH])
        if not text:
            return

        sentences = text_processor.split_for_streaming(text)
        _max = max_new_tokens or self.cfg.MAX_NEW_TOKENS
        _rep = repetition_penalty or self.cfg.REPETITION_PENALTY
        _check = is_cancelled or (lambda: False)

        for i, sentence in enumerate(sentences):
            # Check cancellation between chunks
            if _check():
                logger.info("Generation cancelled by client (between chunks)")
                return
            if not sentence.strip():
                continue
            t0 = time.perf_counter()
            try:
                tokens = self._generate_tokens(sentence, voice, _max, _rep, _check)
                if _check():
                    return
                audio = self._decode_tokens(tokens, voice)
                elapsed = time.perf_counter() - t0
                audio_duration = len(audio) / self.cfg.SAMPLE_RATE
                rtf = elapsed / audio_duration if audio_duration > 0 else 0
                logger.info(
                    f"Chunk {i + 1}/{len(sentences)}: "
                    f"{len(sentence)} chars β†’ {audio_duration:.1f}s audio "
                    f"in {elapsed:.2f}s (RTF: {rtf:.2f}x)"
                )
                yield audio
            except GenerationCancelled:
                logger.info(f"Generation cancelled mid-token at chunk {i + 1}")
                return
            except Exception as e:
                logger.error(f"Error on chunk {i + 1}: {e}")
                raise

    # ─── Autoregressive token generation (OPTIMISED) ──────────────

    def _generate_tokens(
        self,
        text: str,
        voice: VoiceProfile,
        max_new_tokens: int,
        repetition_penalty: float,
        is_cancelled: Callable[[], bool] = lambda: False,
    ) -> np.ndarray:
        """Run embed β†’ LM autoregressive loop. Returns raw token array.

        Optimisations:
          β€’ Token list instead of repeated np.concatenate (O(n) β†’ O(1) append)
          β€’ Unique tokens set for inline repetition penalty (avoids exponential penalty bug)
          β€’ Pre-allocated attention mask for zero-copy slicing
          β€’ Correct dimensional routing for step 0 prompt processing

        Raises:
            GenerationCancelled: If is_cancelled() returns True mid-loop
                (checked every 25 steps).
        """
        input_ids = self.tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)

        # Pre-allocate collections
        token_list: list[int] = [self.cfg.START_SPEECH_TOKEN]
        unique_tokens: set[int] = {self.cfg.START_SPEECH_TOKEN}
        penalty = repetition_penalty

        past_key_values = None
        attention_mask_full = None
        seq_len = 0

        for step in range(max_new_tokens):
            if step > 0 and step % 25 == 0 and is_cancelled():
                raise GenerationCancelled()

            embeds = self.embed_session.run(None, {"input_ids": input_ids})[0]

            if step == 0:
                # Prepend speaker conditioning
                embeds = np.concatenate((voice.cond_emb, embeds), axis=1)
                batch, seq_len, _ = embeds.shape

                # Empty KV cache, dtype matched per-input to the LM graph
                past_key_values = {
                    inp.name: np.zeros(
                        [batch, self.cfg.NUM_KV_HEADS, 0, self.cfg.HEAD_DIM],
                        dtype=np.float16 if inp.type == "tensor(float16)" else np.float32,
                    )
                    for inp in self.lm_session.get_inputs()
                    if "past_key_values" in inp.name
                }

                # Pre-allocate full attention mask
                attention_mask_full = np.ones((batch, seq_len + max_new_tokens), dtype=np.int64)
                attention_mask = attention_mask_full[:, :seq_len]

                # Step 0 requires position_ids matching prompt sequence length
                position_ids = np.arange(seq_len, dtype=np.int64).reshape(batch, -1)
            else:
                # O(1) zero-copy slice for subsequent steps
                attention_mask = attention_mask_full[:, : seq_len + step]
                # Single position ID for the single new token
                position_ids = np.array([[seq_len + step - 1]], dtype=np.int64)

            # Language model forward pass
            logits, *present_kv = self.lm_session.run(
                None,
                dict(
                    inputs_embeds=embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    **past_key_values,
                ),
            )

            # ── Inline repetition penalty + token selection ───────
            last_logits = logits[0, -1, :].copy()  # shape: (vocab_size,)

            # Apply repetition penalty strictly to unique tokens to prevent over-penalization
            for tok_id in unique_tokens:
                if last_logits[tok_id] < 0:
                    last_logits[tok_id] *= penalty
                else:
                    last_logits[tok_id] /= penalty

            next_token = int(np.argmax(last_logits))
            token_list.append(next_token)
            unique_tokens.add(next_token)

            if next_token == self.cfg.STOP_SPEECH_TOKEN:
                break

            # Update state for next step
            input_ids = np.array([[next_token]], dtype=np.int64)
            for j, key in enumerate(past_key_values):
                past_key_values[key] = present_kv[j]

        return np.array([token_list], dtype=np.int64)

    # ─── Token β†’ audio decoding ───────────────────────────────────

    def _decode_tokens(self, generated: np.ndarray, voice: VoiceProfile) -> np.ndarray:
        """Decode speech tokens to a float32 waveform at 24 kHz."""
        # Strip START token; strip STOP token if present
        tokens = generated[:, 1:]
        if tokens.shape[1] > 0 and tokens[0, -1] == self.cfg.STOP_SPEECH_TOKEN:
            tokens = tokens[:, :-1]

        if tokens.shape[1] == 0:
            return np.zeros(0, dtype=np.float32)

        # Prepend prompt token + append silence
        silence = np.full(
            (tokens.shape[0], 3), self.cfg.SILENCE_TOKEN, dtype=np.int64
        )
        full_tokens = np.concatenate(
            [voice.prompt_token, tokens, silence], axis=1
        )

        wav = self.decoder_session.run(
            None,
            {
                "speech_tokens": full_tokens,
                "speaker_embeddings": voice.speaker_embeddings,
                "speaker_features": voice.speaker_features,
            },
        )[0].squeeze(axis=0)

        return wav

    # ─── Warmup ───────────────────────────────────────────────────

    def warmup(self):
        """Run a short inference to warm up ONNX sessions and JIT paths."""
        try:
            t0 = time.perf_counter()
            _ = self.generate_speech("Hello.", self.default_voice, max_new_tokens=32)
            logger.info(f"Warmup done in {time.perf_counter() - t0:.2f}s")
        except Exception as e:
            logger.warning(f"Warmup failed (non-critical): {e}")
config.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chatterbox Turbo TTS β€” Centralized Configuration
3
+ ═══════════════════════════════════════════════════
4
+ Optimised for HF Space free tier (2 vCPU).
5
+ Adjust MODEL_DTYPE to switch quantization (q8/q4/fp16/fp32).
6
+ All settings overridable via environment variables prefixed CB_.
7
+ """
8
+ import os
9
+
10
+ _HERE = os.path.dirname(os.path.abspath(__file__))
11
+
12
+
13
+ def _get_bool(name: str, default: bool) -> bool:
14
+ raw = os.getenv(name)
15
+ if raw is None:
16
+ return default
17
+ return raw.strip().lower() in {"1", "true", "yes", "on"}
18
+
19
+
20
class Config:
    # ── Model ────────────────────────────────────────────────────
    MODEL_ID: str = os.getenv("CB_MODEL_ID", "ResembleAI/chatterbox-turbo-ONNX")

    # fp32  β†’ highest quality, ~1.4 GB, slowest
    # fp16  β†’ good quality, ~0.7 GB
    # q8    β†’ β˜… recommended, ~0.35 GB, best balance
    # q4    β†’ smallest, ~0.17 GB, fastest, slight loss
    # q4f16 β†’ q4 weights + fp16 activations
    MODEL_DTYPE: str = os.getenv("CB_MODEL_DTYPE", "q4")

    MODELS_DIR: str = os.getenv("CB_MODELS_DIR", os.path.join(_HERE, "models"))

    # ── ONNX Runtime CPU tuning (optimised for 2 vCPU) ───────────
    #
    # KEY RULE: intra_op threads MUST match physical cores.
    #   β†’ 4 threads on 2 cores = oversubscription = SLOWER.
    #   β†’ 2 threads on 2 cores = each op uses both cores perfectly.
    #
    # MAX_WORKERS = 1 ensures ONE inference gets both cores.
    #   β†’ 2 workers would split 2 cores = both requests slow.
    #
    CPU_THREADS: int = int(os.getenv("CB_CPU_THREADS", "2"))
    MAX_WORKERS: int = int(os.getenv("CB_MAX_WORKERS", "1"))

    # ── Generation defaults ──────────────────────────────────────
    SAMPLE_RATE: int = 24000
    MAX_NEW_TOKENS: int = int(os.getenv("CB_MAX_NEW_TOKENS", "768"))
    REPETITION_PENALTY: float = float(os.getenv("CB_REPETITION_PENALTY", "1.2"))
    MAX_TEXT_LENGTH: int = int(os.getenv("CB_MAX_TEXT_LENGTH", "50000"))

    # ── Model constants (official card β€” do not change) ──────────
    START_SPEECH_TOKEN: int = 6561
    STOP_SPEECH_TOKEN: int = 6562
    SILENCE_TOKEN: int = 4299
    NUM_KV_HEADS: int = 16
    HEAD_DIM: int = 64

    # ── Paralinguistic tags (Turbo native) ───────────────────────
    # Rendered by the model when written inline as e.g. "[laugh]".
    PARALINGUISTIC_TAGS: tuple = (
        "laugh", "chuckle", "cough", "sigh", "gasp",
        "shush", "groan", "sniff", "clear throat",
    )

    # ── Voice / reference audio ──────────────────────────────────
    # NOTE: Official ResembleAI/chatterbox-turbo-ONNX has no bundled voice.
    # The default_voice.wav is a plain audio sample from community repo
    # (not a model β€” just a reference WAV, safe to use from any source).
    DEFAULT_VOICE_REPO: str = "onnx-community/chatterbox-ONNX"
    DEFAULT_VOICE_FILE: str = "default_voice.wav"
    MAX_VOICE_UPLOAD_BYTES: int = 10 * 1024 * 1024  # 10 MB
    MIN_REF_DURATION_SEC: float = 1.5
    MAX_REF_DURATION_SEC: float = 30.0
    VOICE_CACHE_SIZE: int = int(os.getenv("CB_VOICE_CACHE_SIZE", "20"))
    VOICE_CACHE_TTL_SEC: int = int(os.getenv("CB_VOICE_CACHE_TTL", "3600"))  # 1 hour

    # ── Streaming ────────────────────────────────────────────────
    # Smaller chunks = faster TTFB (first audio arrives sooner).
    # Default 100 chars β‰ˆ one short sentence β‰ˆ fastest first-chunk on 2 vCPU.
    MAX_CHUNK_CHARS: int = int(os.getenv("CB_MAX_CHUNK_CHARS", "100"))
    # Additive parallel mode (odd/even split across primary/helper).
    ENABLE_PARALLEL_MODE: bool = _get_bool("CB_ENABLE_PARALLEL_MODE", True)
    HELPER_BASE_URL: str = os.getenv("CB_HELPER_BASE_URL", "https://shadowhunter222-hello2.hf.space").strip()
    HELPER_TIMEOUT_SEC: float = float(os.getenv("CB_HELPER_TIMEOUT_SEC", "45"))
    HELPER_RETRY_ONCE: bool = _get_bool("CB_HELPER_RETRY_ONCE", True)
    # Optional shared secret for internal chunk endpoints.
    INTERNAL_SHARED_SECRET: str = os.getenv("CB_INTERNAL_SHARED_SECRET", "").strip()

    # ── Server ───────────────────────────────────────────────────
    HOST: str = os.getenv("CB_HOST", "0.0.0.0")
    PORT: int = int(os.getenv("CB_PORT", "7860"))

    # CORS allow-list: production domain plus common local dev ports.
    ALLOWED_ORIGINS: list[str] = [
        "https://toolboxesai.com",
        "http://localhost:8788", "http://127.0.0.1:8788",
        "http://localhost:5502", "http://127.0.0.1:5502",
        "http://localhost:5501", "http://127.0.0.1:5501",
        "http://localhost:5500", "http://127.0.0.1:5500",
        "http://localhost:5173", "http://127.0.0.1:5173",
        "http://localhost:7860", "http://127.0.0.1:7860",
    ]
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========================================================
2
+ # Chatterbox Turbo TTS - Dependencies (CPU-only)
3
+ # =========================================================
4
+
5
+ # PyTorch CPU (required by transformers tokenizer internals)
6
+ torch --index-url https://download.pytorch.org/whl/cpu
7
+
8
+ # Core API
9
+ fastapi>=0.104.1
10
+ uvicorn[standard]>=0.24.0
11
+ pydantic>=2.5.0
12
+ python-multipart>=0.0.6
13
+
14
+ # ONNX Runtime (CPU inference)
15
+ onnxruntime>=1.17.0
16
+
17
+ # Audio processing
18
+ numpy>=1.24.0
19
+ librosa>=0.10.0
20
+ soundfile>=0.12.0
21
+
22
+ # Tokenizer + model download
23
+ transformers>=4.46.0
24
+ huggingface-hub>=0.19.0
text_processor.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""
Chatterbox Turbo TTS β€” Text Processor
═══════════════════════════════════════
Sanitizes raw input text and splits it into sentence-level chunks
for streaming TTS. Paralinguistic tags ([laugh], [cough], …) are
explicitly preserved so the model can render them.
"""
import re
from typing import List

from config import Config

# ═══════════════════════════════════════════════════════════════════
# Pre-compiled regex patterns (compiled once at import β†’ zero cost)
# ═══════════════════════════════════════════════════════════════════

# β€” Paralinguistic tag protector (matches [laugh], [clear throat], etc.)
# Tag names come from Config.PARALINGUISTIC_TAGS; each is re.escape()d so
# the alternation stays safe even if a tag contains regex metacharacters.
_TAG_NAMES = "|".join(re.escape(t) for t in Config.PARALINGUISTIC_TAGS)
_RE_PARA_TAG = re.compile(rf"\[(?:{_TAG_NAMES})\]", re.IGNORECASE)

# β€” Markdown / structural noise
# Capture groups keep the speakable text (link label, bold/italic body)
# while the surrounding markup is stripped by sanitize().
_RE_CODE_BLOCK = re.compile(r"```[\s\S]*?```")
_RE_INLINE_CODE = re.compile(r"`([^`]+)`")
_RE_IMAGE = re.compile(r"!\[([^\]]*)\]\([^)]+\)")
_RE_LINK = re.compile(r"\[([^\]]+)\]\([^)]+\)")
_RE_BOLD_AST = re.compile(r"\*\*(.+?)\*\*")
_RE_BOLD_UND = re.compile(r"__(.+?)__")
_RE_STRIKE = re.compile(r"~~(.+?)~~")
_RE_ITALIC_AST = re.compile(r"\*(.+?)\*")
# Word-boundary lookarounds so snake_case identifiers are not treated
# as underscore-italics.
_RE_ITALIC_UND = re.compile(r"(?<!\w)_(.+?)_(?!\w)")
_RE_HEADER = re.compile(r"^#{1,6}\s+", re.MULTILINE)
_RE_BLOCKQUOTE = re.compile(r"^>+\s?", re.MULTILINE)
_RE_HR = re.compile(r"^[-*_]{3,}$", re.MULTILINE)  # horizontal rules
_RE_BULLET = re.compile(r"^\s*[-*+]\s+", re.MULTILINE)
_RE_ORDERED = re.compile(r"^\s*\d+\.\s+", re.MULTILINE)

# β€” URLs, emojis, HTML entities
_RE_URL = re.compile(r"https?://\S+")
# Covers the common emoji/pictograph blocks plus variation selectors and
# the ZWJ used in compound emoji sequences.
_RE_EMOJI = re.compile(
    r"["
    r"\U0001F600-\U0001F64F\U0001F300-\U0001F5FF"
    r"\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF"
    r"\U00002702-\U000027B0\U0001F900-\U0001F9FF"
    r"\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF"
    r"\U00002600-\U000026FF\U0000FE00-\U0000FE0F"
    r"\U0000200D"
    r"]+", re.UNICODE,
)
# Matches both named (&amp;) and numeric (&#x27;) entities; only the
# names in _HTML_ENTITIES get spoken replacements, the rest are dropped.
_RE_HTML_ENTITY = re.compile(r"&(?:#x?[\da-fA-F]+|\w+);")
_HTML_ENTITIES = {
    "&amp;": " and ", "&lt;": " less than ", "&gt;": " greater than ",
    "&nbsp;": " ", "&quot;": '"', "&apos;": "'",
    "&mdash;": ", ", "&ndash;": ", ", "&hellip;": ".",
}

# β€” Punctuation normalization (runs of repeats collapse to one mark)
_RE_REPEATED_DOT = re.compile(r"\.{2,}")
_RE_REPEATED_EXCLAM = re.compile(r"!{2,}")
_RE_REPEATED_QUEST = re.compile(r"\?{2,}")
_RE_REPEATED_SEMI = re.compile(r";{2,}")
_RE_REPEATED_COLON = re.compile(r":{2,}")
_RE_REPEATED_COMMA = re.compile(r",{2,}")
_RE_REPEATED_DASH = re.compile(r"-{3,}")

# β€” Whitespace
_RE_MULTI_SPACE = re.compile(r"[ \t]+")
_RE_MULTI_NEWLINE = re.compile(r"\n{3,}")
_RE_SPACE_BEFORE_PUN = re.compile(r"\s+([.!?,;:])")

# β€” Sentence boundary (split point)
# Lookbehind keeps the terminating punctuation attached to its sentence.
_RE_SENTENCE_SPLIT = re.compile(r"(?<=[.!?;:])\s+")

# Chunks of ≀ this many words get merged with a neighbour by
# split_for_streaming() to avoid tiny TTS segments.
_MIN_MERGE_WORDS = 5


# ═══════════════════════════════════════════════════════════════════
# Public API
# ═══════════════════════════════════════════════════════════════════

80
def sanitize(text: str) -> str:
    """Clean raw input for TTS while preserving paralinguistic tags.

    Pipeline (order matters):
      1. Replace paralinguistic tags with placeholders so later passes
         cannot mangle them.
      2. Strip URLs, markdown structure, and code.
      3. Remove emojis; keep hashtag words without the '#'.
      4. Translate known HTML entities, drop unknown ones.
      5. Collapse repeated punctuation to a single mark.
      6. Normalize whitespace.
      7. Restore the protected tags verbatim.
    """
    if not text:
        return text

    # 1. Shield paralinguistic tags behind numbered placeholders.
    stashed: list[str] = []

    def _stash(match):
        token = f"Β§TAG{len(stashed)}Β§"
        stashed.append(match.group(0))
        return token

    text = _RE_PARA_TAG.sub(_stash, text)

    # 2. Drop non-speakable structures (applied in this exact order:
    #    URLs and code blocks first, then inline markdown, then
    #    line-level structure).
    for pattern, repl in (
        (_RE_URL, ""),
        (_RE_CODE_BLOCK, ""),
        (_RE_IMAGE, lambda m: m.group(1) or ""),  # keep alt text if any
        (_RE_LINK, r"\1"),
        (_RE_BOLD_AST, r"\1"),
        (_RE_BOLD_UND, r"\1"),
        (_RE_STRIKE, r"\1"),
        (_RE_ITALIC_AST, r"\1"),
        (_RE_ITALIC_UND, r"\1"),
        (_RE_INLINE_CODE, r"\1"),
        (_RE_HEADER, ""),
        (_RE_BLOCKQUOTE, ""),
        (_RE_HR, ""),
        (_RE_BULLET, ""),
        (_RE_ORDERED, ""),
    ):
        text = pattern.sub(repl, text)

    # 3. Emojis vanish; hashtags keep their word.
    text = _RE_EMOJI.sub("", text)
    text = re.sub(r"#(\w+)", r"\1", text)

    # 4. HTML entities: known names get a spoken equivalent, everything
    #    else (including numeric entities) is removed.
    text = _RE_HTML_ENTITY.sub(
        lambda m: _HTML_ENTITIES.get(m.group(0), ""), text
    )

    # 5. Collapse runs of repeated punctuation.
    for pattern, repl in (
        (_RE_REPEATED_DOT, "."),
        (_RE_REPEATED_EXCLAM, "!"),
        (_RE_REPEATED_QUEST, "?"),
        (_RE_REPEATED_SEMI, ";"),
        (_RE_REPEATED_COLON, ":"),
        (_RE_REPEATED_COMMA, ","),
        (_RE_REPEATED_DASH, "β€”"),
    ):
        text = pattern.sub(repl, text)

    # 6. Whitespace normalization.
    text = _RE_SPACE_BEFORE_PUN.sub(r"\1", text)
    text = _RE_MULTI_SPACE.sub(" ", text)
    text = _RE_MULTI_NEWLINE.sub("\n\n", text)
    text = text.strip()

    # 7. Put the original tags back, byte-for-byte.
    for idx, original in enumerate(stashed):
        text = text.replace(f"Β§TAG{idx}Β§", original)

    return text
139
def split_for_streaming(text: str, max_chars: int = Config.MAX_CHUNK_CHARS) -> List[str]:
    """Split sanitized text into sentence-level chunks for streaming.

    Strategy:
      1. Split on sentence-ending punctuation boundaries.
      2. Cap each chunk at ``max_chars`` (long sentences are broken on
         commas or word boundaries).
      3. Merge short chunks (≀5 words) into a neighbour so the model
         never receives a tiny fragment.
    """
    if not text:
        return []

    # Step 1: sentence-boundary split, discarding empty fragments.
    sentences = [s.strip() for s in _RE_SENTENCE_SPLIT.split(text) if s.strip()]

    # Step 2: enforce the per-chunk character budget.
    sized: List[str] = []
    for sentence in sentences:
        if len(sentence) > max_chars:
            sized.extend(_break_long_chunk(sentence, max_chars))
        else:
            sized.append(sentence)

    if len(sized) <= 1:
        return sized

    # Step 3: fold short chunks forward into their successor.
    merged: List[str] = []
    pending = ""
    last = len(sized) - 1
    for idx, piece in enumerate(sized):
        if pending:
            piece = pending + " " + piece
            pending = ""
        if idx != last and len(piece.split()) <= _MIN_MERGE_WORDS:
            pending = piece  # too short β€” defer and prepend to the next one
        else:
            merged.append(piece)
    if pending:
        # A short trailing chunk folds backward into the previous one.
        if merged:
            merged[-1] += " " + pending
        else:
            merged.append(pending)

    return merged
185
+ # ═══════════════════════════════════════════════════════════════════
186
+ # Internal helpers
187
+ # ═══════════════════════════════════════════════════════════════════
188
+
189
+ def _break_long_chunk(text: str, max_chars: int) -> List[str]:
190
+ """Break a chunk longer than max_chars on commas or word boundaries."""
191
+ parts: List[str] = []
192
+ remaining = text
193
+ while len(remaining) > max_chars:
194
+ # Try comma first
195
+ pos = remaining.rfind(",", 0, max_chars)
196
+ if pos == -1:
197
+ pos = remaining.rfind(" ", 0, max_chars)
198
+ if pos == -1:
199
+ pos = max_chars # hard break
200
+ segment = remaining[:pos].strip()
201
+ if segment:
202
+ parts.append(segment)
203
+ remaining = remaining[pos:].lstrip(", ")
204
+ if remaining.strip():
205
+ parts.append(remaining.strip())
206
+ return parts