Spaces:
Sleeping
Sleeping
Upload 6 files
Browse files- Dockerfile +36 -0
- app.py +915 -0
- chatterbox_wrapper.py +534 -0
- config.py +100 -0
- requirements.txt +24 -0
- text_processor.py +206 -0
Dockerfile
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 2 |
+
# Chatterbox Turbo TTS β CPU-Optimised Docker Image
|
| 3 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 4 |
+
FROM python:3.11-slim
|
| 5 |
+
|
| 6 |
+
# Audio codec libraries for soundfile/librosa
|
| 7 |
+
RUN apt-get update && \
|
| 8 |
+
apt-get install -y --no-install-recommends libsndfile1 ffmpeg && \
|
| 9 |
+
rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
WORKDIR /app
|
| 12 |
+
|
| 13 |
+
# Install PyTorch CPU first (from dedicated index for smaller size)
|
| 14 |
+
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
|
| 15 |
+
|
| 16 |
+
# Install remaining dependencies
|
| 17 |
+
COPY requirements.txt .
|
| 18 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 19 |
+
|
| 20 |
+
# Copy application code
|
| 21 |
+
COPY config.py text_processor.py chatterbox_wrapper.py app.py ./
|
| 22 |
+
|
| 23 |
+
# Pre-download ONNX models + tokenizer at build time
|
| 24 |
+
RUN python -c "\
|
| 25 |
+
from chatterbox_wrapper import ChatterboxWrapper; \
|
| 26 |
+
ChatterboxWrapper(download_only=True); \
|
| 27 |
+
print('Models pre-downloaded successfully')"
|
| 28 |
+
|
| 29 |
+
# Prevent thread oversubscription in production
|
| 30 |
+
ENV OMP_NUM_THREADS=1
|
| 31 |
+
ENV MKL_NUM_THREADS=1
|
| 32 |
+
ENV OPENBLAS_NUM_THREADS=1
|
| 33 |
+
|
| 34 |
+
EXPOSE 7860
|
| 35 |
+
|
| 36 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "1"]
|
app.py
ADDED
|
@@ -0,0 +1,915 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chatterbox Turbo TTS -- FastAPI Server
|
| 3 |
+
======================================
|
| 4 |
+
Production-ready API with true real-time MP3 streaming,
|
| 5 |
+
in-memory voice cloning, and fully non-blocking inference.
|
| 6 |
+
|
| 7 |
+
Endpoints:
|
| 8 |
+
GET /health -> health check + optional warmup
|
| 9 |
+
GET /info -> model info, supported tags, parameters
|
| 10 |
+
POST /tts -> full audio response (WAV/MP3/FLAC)
|
| 11 |
+
POST /tts/stream -> chunked MP3 streaming (MediaSource-ready)
|
| 12 |
+
POST /tts/true-stream -> alias for /tts/stream (Kokoro compat)
|
| 13 |
+
POST /tts/stop/{stream_id}-> cancel a specific active stream
|
| 14 |
+
POST /tts/stop -> cancel ALL active streams
|
| 15 |
+
POST /v1/audio/speech -> OpenAI-compatible streaming
|
| 16 |
+
"""
|
| 17 |
+
import asyncio
|
| 18 |
+
import io
|
| 19 |
+
import json
|
| 20 |
+
import logging
|
| 21 |
+
import queue as stdlib_queue
|
| 22 |
+
import threading
|
| 23 |
+
import time
|
| 24 |
+
import urllib.error
|
| 25 |
+
import urllib.parse
|
| 26 |
+
import urllib.request
|
| 27 |
+
import uuid
|
| 28 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 29 |
+
from typing import Generator, Optional
|
| 30 |
+
|
| 31 |
+
import numpy as np
|
| 32 |
+
import soundfile as sf
|
| 33 |
+
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
|
| 34 |
+
from fastapi.responses import Response, StreamingResponse
|
| 35 |
+
from contextlib import asynccontextmanager
|
| 36 |
+
|
| 37 |
+
from config import Config
|
| 38 |
+
from chatterbox_wrapper import ChatterboxWrapper, GenerationCancelled, VoiceProfile
|
| 39 |
+
import text_processor
|
| 40 |
+
|
| 41 |
+
# ββ Logging βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
logging.basicConfig(
|
| 43 |
+
level=logging.INFO,
|
| 44 |
+
format="%(asctime)s β %(levelname)-7s β %(name)s β %(message)s",
|
| 45 |
+
datefmt="%H:%M:%S",
|
| 46 |
+
)
|
| 47 |
+
logger = logging.getLogger(__name__)
|
| 48 |
+
|
| 49 |
+
# ββ Thread pool for CPU-bound inference βββββββββββββββββββββββββββ
|
| 50 |
+
tts_executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# ββ Lifespan ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
+
|
| 55 |
+
@asynccontextmanager
|
| 56 |
+
async def lifespan(app: FastAPI):
|
| 57 |
+
try:
|
| 58 |
+
wrapper = ChatterboxWrapper()
|
| 59 |
+
app.state.wrapper = wrapper
|
| 60 |
+
logger.info("β
Model loaded, server ready")
|
| 61 |
+
except Exception as e:
|
| 62 |
+
logger.error(f"β Model loading failed: {e}")
|
| 63 |
+
raise
|
| 64 |
+
yield
|
| 65 |
+
tts_executor.shutdown(wait=False)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
app = FastAPI(
|
| 69 |
+
title="Chatterbox Turbo TTS API",
|
| 70 |
+
version="1.0.0",
|
| 71 |
+
docs_url="/docs",
|
| 72 |
+
lifespan=lifespan,
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# ββ CORS Middleware βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
+
|
| 78 |
+
@app.middleware("http")
|
| 79 |
+
async def cors_middleware(request: Request, call_next):
|
| 80 |
+
origin = request.headers.get("origin")
|
| 81 |
+
|
| 82 |
+
# Preflight
|
| 83 |
+
if request.method == "OPTIONS" and origin in Config.ALLOWED_ORIGINS:
|
| 84 |
+
return Response(
|
| 85 |
+
status_code=200,
|
| 86 |
+
headers={
|
| 87 |
+
"Access-Control-Allow-Origin": origin,
|
| 88 |
+
"Access-Control-Allow-Methods": "*",
|
| 89 |
+
"Access-Control-Allow-Headers": "*",
|
| 90 |
+
"Access-Control-Allow-Credentials": "true",
|
| 91 |
+
},
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
if not origin or origin in Config.ALLOWED_ORIGINS:
|
| 95 |
+
response = await call_next(request)
|
| 96 |
+
if origin:
|
| 97 |
+
response.headers["Access-Control-Allow-Origin"] = origin
|
| 98 |
+
response.headers["Access-Control-Allow-Credentials"] = "true"
|
| 99 |
+
response.headers["Access-Control-Allow-Methods"] = "*"
|
| 100 |
+
response.headers["Access-Control-Allow-Headers"] = "*"
|
| 101 |
+
response.headers["Access-Control-Expose-Headers"] = "X-Stream-Id"
|
| 102 |
+
return response
|
| 103 |
+
|
| 104 |
+
logger.warning(f"π« Blocked origin: {origin}")
|
| 105 |
+
return Response(status_code=403, content="Forbidden: Origin not allowed")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 109 |
+
# Helper: resolve voice from optional upload
|
| 110 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 111 |
+
|
| 112 |
+
async def _resolve_voice(
|
| 113 |
+
voice_ref: Optional[UploadFile],
|
| 114 |
+
wrapper: ChatterboxWrapper,
|
| 115 |
+
) -> VoiceProfile:
|
| 116 |
+
"""Return a VoiceProfile from uploaded audio or the default voice."""
|
| 117 |
+
if voice_ref is None or voice_ref.filename == "":
|
| 118 |
+
return wrapper.default_voice
|
| 119 |
+
|
| 120 |
+
audio_bytes = await voice_ref.read()
|
| 121 |
+
if len(audio_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
|
| 122 |
+
raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
|
| 123 |
+
if len(audio_bytes) == 0:
|
| 124 |
+
raise HTTPException(status_code=400, detail="Empty voice file")
|
| 125 |
+
|
| 126 |
+
loop = asyncio.get_running_loop()
|
| 127 |
+
try:
|
| 128 |
+
return await loop.run_in_executor(
|
| 129 |
+
tts_executor, wrapper.encode_voice_from_bytes, audio_bytes
|
| 130 |
+
)
|
| 131 |
+
except ValueError as e:
|
| 132 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 133 |
+
except Exception as e:
|
| 134 |
+
logger.error(f"Voice encoding failed: {e}")
|
| 135 |
+
raise HTTPException(
|
| 136 |
+
status_code=400,
|
| 137 |
+
detail=f"Could not process voice file: {str(e)}. "
|
| 138 |
+
f"Supported formats: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM."
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 143 |
+
# Helper: encode numpy audio to bytes in given format
|
| 144 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 145 |
+
|
| 146 |
+
def _encode_audio(audio: np.ndarray, fmt: str = "wav") -> tuple[bytes, str]:
|
| 147 |
+
buf = io.BytesIO()
|
| 148 |
+
fmt_lower = fmt.lower()
|
| 149 |
+
if fmt_lower == "mp3":
|
| 150 |
+
sf.write(buf, audio, Config.SAMPLE_RATE, format="mp3")
|
| 151 |
+
media = "audio/mpeg"
|
| 152 |
+
elif fmt_lower == "flac":
|
| 153 |
+
sf.write(buf, audio, Config.SAMPLE_RATE, format="flac")
|
| 154 |
+
media = "audio/flac"
|
| 155 |
+
else:
|
| 156 |
+
sf.write(buf, audio, Config.SAMPLE_RATE, format="wav")
|
| 157 |
+
media = "audio/wav"
|
| 158 |
+
return buf.getvalue(), media
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _encode_mp3_chunk(audio: np.ndarray) -> bytes:
|
| 162 |
+
"""Encode one numpy chunk to MP3 bytes (same encoder path as current server)."""
|
| 163 |
+
data, _ = _encode_audio(audio, fmt="mp3")
|
| 164 |
+
return data
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _build_helper_endpoint(base_url: str, path: str) -> str:
|
| 168 |
+
return f"{base_url.rstrip('/')}{path}"
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _internal_headers() -> dict[str, str]:
|
| 172 |
+
headers = {"Content-Type": "application/json", "Accept": "audio/mpeg"}
|
| 173 |
+
if Config.INTERNAL_SHARED_SECRET:
|
| 174 |
+
headers["X-Internal-Secret"] = Config.INTERNAL_SHARED_SECRET
|
| 175 |
+
return headers
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _helper_request_chunk(
|
| 179 |
+
helper_base_url: str,
|
| 180 |
+
payload: dict,
|
| 181 |
+
timeout_sec: float,
|
| 182 |
+
) -> bytes:
|
| 183 |
+
url = _build_helper_endpoint(helper_base_url, "/internal/chunk/synthesize")
|
| 184 |
+
body = json.dumps(payload).encode("utf-8")
|
| 185 |
+
req = urllib.request.Request(
|
| 186 |
+
url=url,
|
| 187 |
+
data=body,
|
| 188 |
+
headers=_internal_headers(),
|
| 189 |
+
method="POST",
|
| 190 |
+
)
|
| 191 |
+
with urllib.request.urlopen(req, timeout=timeout_sec) as resp:
|
| 192 |
+
return resp.read()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def _helper_register_voice(
|
| 196 |
+
helper_base_url: str,
|
| 197 |
+
stream_id: str,
|
| 198 |
+
audio_bytes: bytes,
|
| 199 |
+
timeout_sec: float,
|
| 200 |
+
) -> str:
|
| 201 |
+
"""Register reference voice on helper once, return voice_key for chunk calls."""
|
| 202 |
+
query = urllib.parse.urlencode({"stream_id": stream_id})
|
| 203 |
+
url = _build_helper_endpoint(helper_base_url, f"/internal/voice/register?{query}")
|
| 204 |
+
headers = {"Content-Type": "application/octet-stream", "Accept": "application/json"}
|
| 205 |
+
if Config.INTERNAL_SHARED_SECRET:
|
| 206 |
+
headers["X-Internal-Secret"] = Config.INTERNAL_SHARED_SECRET
|
| 207 |
+
|
| 208 |
+
req = urllib.request.Request(
|
| 209 |
+
url=url,
|
| 210 |
+
data=audio_bytes,
|
| 211 |
+
headers=headers,
|
| 212 |
+
method="POST",
|
| 213 |
+
)
|
| 214 |
+
with urllib.request.urlopen(req, timeout=timeout_sec) as resp:
|
| 215 |
+
data = json.loads(resp.read().decode("utf-8"))
|
| 216 |
+
voice_key = (data.get("voice_key") or "").strip()
|
| 217 |
+
if not voice_key:
|
| 218 |
+
raise RuntimeError("helper voice registration returned no voice_key")
|
| 219 |
+
return voice_key
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def _helper_cancel_stream(helper_base_url: str, stream_id: str):
|
| 223 |
+
"""Best-effort cancellation signal to helper."""
|
| 224 |
+
try:
|
| 225 |
+
url = _build_helper_endpoint(helper_base_url, f"/internal/chunk/cancel/{stream_id}")
|
| 226 |
+
req = urllib.request.Request(
|
| 227 |
+
url=url,
|
| 228 |
+
data=b"",
|
| 229 |
+
headers=_internal_headers(),
|
| 230 |
+
method="POST",
|
| 231 |
+
)
|
| 232 |
+
with urllib.request.urlopen(req, timeout=3.0):
|
| 233 |
+
pass
|
| 234 |
+
except Exception:
|
| 235 |
+
pass
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 239 |
+
# Endpoints
|
| 240 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 241 |
+
|
| 242 |
+
@app.get("/health")
|
| 243 |
+
async def health(warm_up: bool = False):
|
| 244 |
+
wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
|
| 245 |
+
status = {
|
| 246 |
+
"status": "healthy" if wrapper else "loading",
|
| 247 |
+
"model_loaded": wrapper is not None,
|
| 248 |
+
"model_dtype": Config.MODEL_DTYPE,
|
| 249 |
+
"streaming_supported": True,
|
| 250 |
+
"voice_cache_entries": wrapper._voice_cache.size if wrapper else 0,
|
| 251 |
+
}
|
| 252 |
+
if warm_up and wrapper:
|
| 253 |
+
try:
|
| 254 |
+
loop = asyncio.get_running_loop()
|
| 255 |
+
await loop.run_in_executor(tts_executor, wrapper.warmup)
|
| 256 |
+
status["warm_up"] = "success"
|
| 257 |
+
except Exception as e:
|
| 258 |
+
status["warm_up"] = f"failed: {e}"
|
| 259 |
+
return status
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
@app.get("/info")
|
| 263 |
+
async def info():
|
| 264 |
+
return {
|
| 265 |
+
"model": Config.MODEL_ID,
|
| 266 |
+
"dtype": Config.MODEL_DTYPE,
|
| 267 |
+
"sample_rate": Config.SAMPLE_RATE,
|
| 268 |
+
"paralinguistic_tags": list(Config.PARALINGUISTIC_TAGS),
|
| 269 |
+
"tag_usage": "Insert tags directly in text, e.g. 'That is so funny! [laugh] Anywayβ¦'",
|
| 270 |
+
"parameters": {
|
| 271 |
+
"max_new_tokens": {"default": Config.MAX_NEW_TOKENS, "range": "64β2048"},
|
| 272 |
+
"repetition_penalty": {"default": Config.REPETITION_PENALTY, "range": "1.0β2.0"},
|
| 273 |
+
},
|
| 274 |
+
"voice_cloning": {
|
| 275 |
+
"description": "Upload 3β30s reference WAV/MP3 as 'voice_ref' field",
|
| 276 |
+
"max_upload_mb": Config.MAX_VOICE_UPLOAD_BYTES // (1024 * 1024),
|
| 277 |
+
},
|
| 278 |
+
"parallel_mode": {
|
| 279 |
+
"enabled": Config.ENABLE_PARALLEL_MODE,
|
| 280 |
+
"helper_configured": bool(Config.HELPER_BASE_URL),
|
| 281 |
+
"helper_base_url": Config.HELPER_BASE_URL or None,
|
| 282 |
+
"supports_voice_ref": True,
|
| 283 |
+
},
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# ββ POST /tts βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 288 |
+
|
| 289 |
+
@app.post("/tts", response_class=Response)
|
| 290 |
+
async def text_to_speech(
|
| 291 |
+
text: str = Form(...),
|
| 292 |
+
voice_ref: Optional[UploadFile] = File(None),
|
| 293 |
+
output_format: str = Form("wav"),
|
| 294 |
+
max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
|
| 295 |
+
repetition_penalty: float = Form(Config.REPETITION_PENALTY),
|
| 296 |
+
):
|
| 297 |
+
"""Generate complete audio for the given text."""
|
| 298 |
+
wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
|
| 299 |
+
if not wrapper:
|
| 300 |
+
raise HTTPException(503, "Model not loaded")
|
| 301 |
+
|
| 302 |
+
if not text or not text.strip():
|
| 303 |
+
raise HTTPException(400, "Text is required")
|
| 304 |
+
|
| 305 |
+
voice = await _resolve_voice(voice_ref, wrapper)
|
| 306 |
+
|
| 307 |
+
loop = asyncio.get_running_loop()
|
| 308 |
+
try:
|
| 309 |
+
audio = await loop.run_in_executor(
|
| 310 |
+
tts_executor,
|
| 311 |
+
wrapper.generate_speech,
|
| 312 |
+
text, voice, max_new_tokens, repetition_penalty,
|
| 313 |
+
)
|
| 314 |
+
except ValueError as e:
|
| 315 |
+
raise HTTPException(400, str(e))
|
| 316 |
+
except Exception as e:
|
| 317 |
+
logger.error(f"TTS error: {e}")
|
| 318 |
+
raise HTTPException(500, "Internal server error")
|
| 319 |
+
|
| 320 |
+
data, media_type = _encode_audio(audio, output_format)
|
| 321 |
+
return Response(
|
| 322 |
+
content=data,
|
| 323 |
+
media_type=media_type,
|
| 324 |
+
headers={"Content-Disposition": f"attachment; filename=tts_output.{output_format}"},
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 328 |
+
# Active Stream Registry (for cancellation)
|
| 329 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 330 |
+
|
| 331 |
+
_active_streams: dict[str, threading.Event] = {}
|
| 332 |
+
_internal_cancelled_streams: set[str] = set()
|
| 333 |
+
_internal_cancel_lock = threading.Lock()
|
| 334 |
+
_internal_stream_voice_keys: dict[str, set[str]] = {}
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 338 |
+
# Pipeline Streaming Generator
|
| 339 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 340 |
+
|
| 341 |
+
def _pipeline_stream_generator(
|
| 342 |
+
wrapper: ChatterboxWrapper,
|
| 343 |
+
text: str,
|
| 344 |
+
voice: VoiceProfile,
|
| 345 |
+
max_new_tokens: int,
|
| 346 |
+
repetition_penalty: float,
|
| 347 |
+
stream_id: str,
|
| 348 |
+
) -> Generator[bytes, None, None]:
|
| 349 |
+
"""Two-stage producer-consumer pipeline for minimal inter-chunk gaps.
|
| 350 |
+
|
| 351 |
+
Architecture:
|
| 352 |
+
Producer thread (heavyweight, ~80% CPU):
|
| 353 |
+
ONNX token generation β audio decoding β raw numpy arrays β queue
|
| 354 |
+
|
| 355 |
+
Consumer (this generator, lightweight, ~20% CPU):
|
| 356 |
+
queue β MP3 encode β yield to HTTP response
|
| 357 |
+
|
| 358 |
+
Why this helps:
|
| 359 |
+
- ONNX model runs CONTINUOUSLY without waiting for MP3 encode or HTTP
|
| 360 |
+
- MP3 encoding (libsndfile, C code) releases GIL β true parallelism
|
| 361 |
+
- ONNX inference (C++ code) also releases GIL β both run simultaneously
|
| 362 |
+
- Queue(maxsize=2) lets producer stay 1-2 chunks ahead
|
| 363 |
+
|
| 364 |
+
Cancellation:
|
| 365 |
+
- cancel_event checked between chunks + every 25 autoregressive steps
|
| 366 |
+
- Client disconnect triggers GeneratorExit β finally sets cancel
|
| 367 |
+
- /tts/stop endpoint sets cancel externally
|
| 368 |
+
"""
|
| 369 |
+
cancel_event = threading.Event()
|
| 370 |
+
_active_streams[stream_id] = cancel_event
|
| 371 |
+
|
| 372 |
+
# Raw audio buffer: producer puts numpy arrays, consumer takes them
|
| 373 |
+
audio_buffer: stdlib_queue.Queue = stdlib_queue.Queue(maxsize=2)
|
| 374 |
+
|
| 375 |
+
def _producer():
|
| 376 |
+
"""Heavyweight worker: runs ONNX model continuously."""
|
| 377 |
+
try:
|
| 378 |
+
for audio_chunk in wrapper.stream_speech(
|
| 379 |
+
text, voice,
|
| 380 |
+
max_new_tokens=max_new_tokens,
|
| 381 |
+
repetition_penalty=repetition_penalty,
|
| 382 |
+
is_cancelled=cancel_event.is_set,
|
| 383 |
+
):
|
| 384 |
+
if cancel_event.is_set():
|
| 385 |
+
break
|
| 386 |
+
while not cancel_event.is_set():
|
| 387 |
+
try:
|
| 388 |
+
audio_buffer.put(audio_chunk, timeout=0.1)
|
| 389 |
+
break
|
| 390 |
+
except stdlib_queue.Full:
|
| 391 |
+
continue
|
| 392 |
+
except GenerationCancelled:
|
| 393 |
+
logger.info(f"[{stream_id}] Generation cancelled")
|
| 394 |
+
except Exception as e:
|
| 395 |
+
while not cancel_event.is_set():
|
| 396 |
+
try:
|
| 397 |
+
audio_buffer.put(e, timeout=0.1)
|
| 398 |
+
break
|
| 399 |
+
except stdlib_queue.Full:
|
| 400 |
+
continue
|
| 401 |
+
finally:
|
| 402 |
+
while not cancel_event.is_set():
|
| 403 |
+
try:
|
| 404 |
+
audio_buffer.put(None, timeout=0.1)
|
| 405 |
+
break
|
| 406 |
+
except stdlib_queue.Full:
|
| 407 |
+
continue
|
| 408 |
+
|
| 409 |
+
producer = threading.Thread(target=_producer, daemon=True)
|
| 410 |
+
producer.start()
|
| 411 |
+
|
| 412 |
+
try:
|
| 413 |
+
# Consumer: lightweight MP3 encoding + yield
|
| 414 |
+
while True:
|
| 415 |
+
item = audio_buffer.get()
|
| 416 |
+
if item is None:
|
| 417 |
+
break
|
| 418 |
+
if isinstance(item, Exception):
|
| 419 |
+
logger.error(f"[{stream_id}] Stream error: {item}")
|
| 420 |
+
break
|
| 421 |
+
if cancel_event.is_set():
|
| 422 |
+
break
|
| 423 |
+
|
| 424 |
+
# MP3 encode (C code, releases GIL, runs parallel with next ONNX step)
|
| 425 |
+
buf = io.BytesIO()
|
| 426 |
+
sf.write(buf, item, Config.SAMPLE_RATE, format="mp3")
|
| 427 |
+
yield buf.getvalue()
|
| 428 |
+
finally:
|
| 429 |
+
# Cleanup: signal producer to stop + deregister
|
| 430 |
+
cancel_event.set()
|
| 431 |
+
_active_streams.pop(stream_id, None)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def _parallel_odd_even_stream_generator(
|
| 435 |
+
wrapper: ChatterboxWrapper,
|
| 436 |
+
text: str,
|
| 437 |
+
local_voice: VoiceProfile,
|
| 438 |
+
helper_voice_bytes: Optional[bytes],
|
| 439 |
+
max_new_tokens: int,
|
| 440 |
+
repetition_penalty: float,
|
| 441 |
+
stream_id: str,
|
| 442 |
+
helper_base_url: str,
|
| 443 |
+
) -> Generator[bytes, None, None]:
|
| 444 |
+
"""Additive odd/even split streamer (primary handles odd, helper handles even)."""
|
| 445 |
+
cancel_event = threading.Event()
|
| 446 |
+
_active_streams[stream_id] = cancel_event
|
| 447 |
+
|
| 448 |
+
clean_text = text_processor.sanitize(text.strip()[: Config.MAX_TEXT_LENGTH])
|
| 449 |
+
chunks = text_processor.split_for_streaming(clean_text)
|
| 450 |
+
total_chunks = len(chunks)
|
| 451 |
+
if total_chunks == 0:
|
| 452 |
+
_active_streams.pop(stream_id, None)
|
| 453 |
+
return
|
| 454 |
+
|
| 455 |
+
lock = threading.Lock()
|
| 456 |
+
cond = threading.Condition(lock)
|
| 457 |
+
ready: dict[int, bytes] = {}
|
| 458 |
+
first_error: Optional[Exception] = None
|
| 459 |
+
workers_done = 0
|
| 460 |
+
|
| 461 |
+
def _publish(idx: int, data: bytes):
|
| 462 |
+
with cond:
|
| 463 |
+
ready[idx] = data
|
| 464 |
+
cond.notify_all()
|
| 465 |
+
|
| 466 |
+
def _set_error(err: Exception):
|
| 467 |
+
nonlocal first_error
|
| 468 |
+
with cond:
|
| 469 |
+
if first_error is None:
|
| 470 |
+
first_error = err
|
| 471 |
+
cond.notify_all()
|
| 472 |
+
|
| 473 |
+
def _worker_done():
|
| 474 |
+
nonlocal workers_done
|
| 475 |
+
with cond:
|
| 476 |
+
workers_done += 1
|
| 477 |
+
cond.notify_all()
|
| 478 |
+
|
| 479 |
+
def _synth_local(chunk_text: str) -> bytes:
|
| 480 |
+
audio = wrapper.generate_speech(
|
| 481 |
+
chunk_text,
|
| 482 |
+
local_voice,
|
| 483 |
+
max_new_tokens=max_new_tokens,
|
| 484 |
+
repetition_penalty=repetition_penalty,
|
| 485 |
+
)
|
| 486 |
+
return _encode_mp3_chunk(audio)
|
| 487 |
+
|
| 488 |
+
def _odd_worker():
|
| 489 |
+
try:
|
| 490 |
+
for idx in range(0, total_chunks, 2):
|
| 491 |
+
if cancel_event.is_set():
|
| 492 |
+
break
|
| 493 |
+
data = _synth_local(chunks[idx])
|
| 494 |
+
_publish(idx, data)
|
| 495 |
+
except Exception as e:
|
| 496 |
+
_set_error(e)
|
| 497 |
+
finally:
|
| 498 |
+
_worker_done()
|
| 499 |
+
|
| 500 |
+
def _even_worker():
|
| 501 |
+
helper_available = True
|
| 502 |
+
helper_voice_key: Optional[str] = None
|
| 503 |
+
try:
|
| 504 |
+
if helper_voice_bytes:
|
| 505 |
+
attempts = 2 if Config.HELPER_RETRY_ONCE else 1
|
| 506 |
+
last_err: Optional[Exception] = None
|
| 507 |
+
for _ in range(attempts):
|
| 508 |
+
try:
|
| 509 |
+
helper_voice_key = _helper_register_voice(
|
| 510 |
+
helper_base_url=helper_base_url,
|
| 511 |
+
stream_id=stream_id,
|
| 512 |
+
audio_bytes=helper_voice_bytes,
|
| 513 |
+
timeout_sec=max(1.0, Config.HELPER_TIMEOUT_SEC),
|
| 514 |
+
)
|
| 515 |
+
last_err = None
|
| 516 |
+
break
|
| 517 |
+
except Exception as reg_err:
|
| 518 |
+
last_err = reg_err
|
| 519 |
+
continue
|
| 520 |
+
if last_err is not None:
|
| 521 |
+
helper_available = False
|
| 522 |
+
logger.warning(
|
| 523 |
+
f"[{stream_id}] Helper voice registration failed; "
|
| 524 |
+
"falling back to local synthesis for even chunks"
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
for idx in range(1, total_chunks, 2):
|
| 528 |
+
if cancel_event.is_set():
|
| 529 |
+
break
|
| 530 |
+
|
| 531 |
+
if helper_available:
|
| 532 |
+
payload = {
|
| 533 |
+
"stream_id": stream_id,
|
| 534 |
+
"chunk_index": idx,
|
| 535 |
+
"text": chunks[idx],
|
| 536 |
+
"max_new_tokens": max_new_tokens,
|
| 537 |
+
"repetition_penalty": repetition_penalty,
|
| 538 |
+
"output_format": "mp3",
|
| 539 |
+
}
|
| 540 |
+
if helper_voice_key:
|
| 541 |
+
payload["voice_key"] = helper_voice_key
|
| 542 |
+
|
| 543 |
+
attempts = 2 if Config.HELPER_RETRY_ONCE else 1
|
| 544 |
+
last_err: Optional[Exception] = None
|
| 545 |
+
for _ in range(attempts):
|
| 546 |
+
try:
|
| 547 |
+
helper_data = _helper_request_chunk(
|
| 548 |
+
helper_base_url=helper_base_url,
|
| 549 |
+
payload=payload,
|
| 550 |
+
timeout_sec=max(1.0, Config.HELPER_TIMEOUT_SEC),
|
| 551 |
+
)
|
| 552 |
+
_publish(idx, helper_data)
|
| 553 |
+
last_err = None
|
| 554 |
+
break
|
| 555 |
+
except Exception as helper_err:
|
| 556 |
+
last_err = helper_err
|
| 557 |
+
continue
|
| 558 |
+
|
| 559 |
+
if last_err is None:
|
| 560 |
+
continue
|
| 561 |
+
|
| 562 |
+
helper_available = False
|
| 563 |
+
logger.warning(
|
| 564 |
+
f"[{stream_id}] Helper failed at chunk {idx}; "
|
| 565 |
+
"falling back to local synthesis for remaining even chunks"
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
# Local fallback for even chunks
|
| 569 |
+
data = _synth_local(chunks[idx])
|
| 570 |
+
_publish(idx, data)
|
| 571 |
+
except Exception as e:
|
| 572 |
+
_set_error(e)
|
| 573 |
+
finally:
|
| 574 |
+
_worker_done()
|
| 575 |
+
|
| 576 |
+
odd_thread = threading.Thread(target=_odd_worker, daemon=True)
|
| 577 |
+
even_thread = threading.Thread(target=_even_worker, daemon=True)
|
| 578 |
+
odd_thread.start()
|
| 579 |
+
even_thread.start()
|
| 580 |
+
|
| 581 |
+
next_idx = 0
|
| 582 |
+
try:
|
| 583 |
+
while next_idx < total_chunks:
|
| 584 |
+
with cond:
|
| 585 |
+
while (
|
| 586 |
+
next_idx not in ready
|
| 587 |
+
and first_error is None
|
| 588 |
+
and not cancel_event.is_set()
|
| 589 |
+
and workers_done < 2
|
| 590 |
+
):
|
| 591 |
+
cond.wait(timeout=0.1)
|
| 592 |
+
|
| 593 |
+
if cancel_event.is_set():
|
| 594 |
+
break
|
| 595 |
+
|
| 596 |
+
if next_idx in ready:
|
| 597 |
+
data = ready.pop(next_idx)
|
| 598 |
+
elif first_error is not None:
|
| 599 |
+
logger.error(f"[{stream_id}] Parallel stream error: {first_error}")
|
| 600 |
+
break
|
| 601 |
+
elif workers_done >= 2:
|
| 602 |
+
logger.error(
|
| 603 |
+
f"[{stream_id}] Parallel stream ended with missing chunk index {next_idx}"
|
| 604 |
+
)
|
| 605 |
+
break
|
| 606 |
+
else:
|
| 607 |
+
continue
|
| 608 |
+
|
| 609 |
+
yield data
|
| 610 |
+
next_idx += 1
|
| 611 |
+
finally:
|
| 612 |
+
cancel_event.set()
|
| 613 |
+
_helper_cancel_stream(helper_base_url, stream_id)
|
| 614 |
+
odd_thread.join(timeout=1.0)
|
| 615 |
+
even_thread.join(timeout=1.0)
|
| 616 |
+
_active_streams.pop(stream_id, None)
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
# ββ POST /tts/stream & /tts/true-stream ββββββββββββββββββββββββββ
|
| 620 |
+
|
| 621 |
+
@app.post("/tts/stream")
|
| 622 |
+
@app.post("/tts/true-stream")
|
| 623 |
+
async def stream_text_to_speech(
|
| 624 |
+
text: str = Form(...),
|
| 625 |
+
voice_ref: Optional[UploadFile] = File(None),
|
| 626 |
+
max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
|
| 627 |
+
repetition_penalty: float = Form(Config.REPETITION_PENALTY),
|
| 628 |
+
):
|
| 629 |
+
"""True real-time streaming: yields MP3 chunks as each sentence finishes.
|
| 630 |
+
|
| 631 |
+
Response includes X-Stream-Id header for cancellation via /tts/stop.
|
| 632 |
+
Compatible with frontend's MediaSource + ReadableStream pattern.
|
| 633 |
+
"""
|
| 634 |
+
wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
|
| 635 |
+
if not wrapper:
|
| 636 |
+
raise HTTPException(503, "Model not loaded")
|
| 637 |
+
|
| 638 |
+
if not text or not text.strip():
|
| 639 |
+
raise HTTPException(400, "Text is required")
|
| 640 |
+
|
| 641 |
+
voice = await _resolve_voice(voice_ref, wrapper)
|
| 642 |
+
stream_id = uuid.uuid4().hex[:12]
|
| 643 |
+
|
| 644 |
+
return StreamingResponse(
|
| 645 |
+
_pipeline_stream_generator(
|
| 646 |
+
wrapper, text, voice, max_new_tokens, repetition_penalty, stream_id,
|
| 647 |
+
),
|
| 648 |
+
media_type="audio/mpeg",
|
| 649 |
+
headers={
|
| 650 |
+
"Content-Disposition": "attachment; filename=tts_stream.mp3",
|
| 651 |
+
"Transfer-Encoding": "chunked",
|
| 652 |
+
"X-Stream-Id": stream_id,
|
| 653 |
+
"X-Streaming-Type": "true-realtime",
|
| 654 |
+
"Cache-Control": "no-cache",
|
| 655 |
+
},
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
@app.post("/tts/parallel-stream")
|
| 660 |
+
async def parallel_stream_text_to_speech(
|
| 661 |
+
text: str = Form(...),
|
| 662 |
+
voice_ref: Optional[UploadFile] = File(None),
|
| 663 |
+
max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
|
| 664 |
+
repetition_penalty: float = Form(Config.REPETITION_PENALTY),
|
| 665 |
+
helper_url: Optional[str] = Form(None),
|
| 666 |
+
):
|
| 667 |
+
"""Additive odd/even split stream mode (primary + helper)."""
|
| 668 |
+
wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
|
| 669 |
+
if not wrapper:
|
| 670 |
+
raise HTTPException(503, "Model not loaded")
|
| 671 |
+
if not Config.ENABLE_PARALLEL_MODE:
|
| 672 |
+
raise HTTPException(503, "Parallel mode is disabled")
|
| 673 |
+
if not text or not text.strip():
|
| 674 |
+
raise HTTPException(400, "Text is required")
|
| 675 |
+
|
| 676 |
+
local_voice: VoiceProfile = wrapper.default_voice
|
| 677 |
+
helper_voice_bytes: Optional[bytes] = None
|
| 678 |
+
if voice_ref is not None and voice_ref.filename:
|
| 679 |
+
helper_voice_bytes = await voice_ref.read()
|
| 680 |
+
if len(helper_voice_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
|
| 681 |
+
raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
|
| 682 |
+
if len(helper_voice_bytes) == 0:
|
| 683 |
+
raise HTTPException(status_code=400, detail="Empty voice file")
|
| 684 |
+
loop = asyncio.get_running_loop()
|
| 685 |
+
try:
|
| 686 |
+
local_voice = await loop.run_in_executor(
|
| 687 |
+
tts_executor, wrapper.encode_voice_from_bytes, helper_voice_bytes
|
| 688 |
+
)
|
| 689 |
+
except Exception as e:
|
| 690 |
+
logger.error(f"Parallel voice encoding failed: {e}")
|
| 691 |
+
raise HTTPException(400, "Could not process voice file for parallel mode")
|
| 692 |
+
|
| 693 |
+
resolved_helper = (helper_url or Config.HELPER_BASE_URL).strip()
|
| 694 |
+
if not resolved_helper:
|
| 695 |
+
raise HTTPException(
|
| 696 |
+
400,
|
| 697 |
+
"Helper URL not configured. Set CB_HELPER_BASE_URL or pass helper_url.",
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
stream_id = uuid.uuid4().hex[:12]
|
| 701 |
+
return StreamingResponse(
|
| 702 |
+
_parallel_odd_even_stream_generator(
|
| 703 |
+
wrapper=wrapper,
|
| 704 |
+
text=text,
|
| 705 |
+
local_voice=local_voice,
|
| 706 |
+
helper_voice_bytes=helper_voice_bytes,
|
| 707 |
+
max_new_tokens=max_new_tokens,
|
| 708 |
+
repetition_penalty=repetition_penalty,
|
| 709 |
+
stream_id=stream_id,
|
| 710 |
+
helper_base_url=resolved_helper,
|
| 711 |
+
),
|
| 712 |
+
media_type="audio/mpeg",
|
| 713 |
+
headers={
|
| 714 |
+
"Content-Disposition": "attachment; filename=tts_parallel_stream.mp3",
|
| 715 |
+
"Transfer-Encoding": "chunked",
|
| 716 |
+
"X-Stream-Id": stream_id,
|
| 717 |
+
"X-Streaming-Type": "parallel-odd-even",
|
| 718 |
+
"Cache-Control": "no-cache",
|
| 719 |
+
},
|
| 720 |
+
)
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
# ββ JSON body variant (Kokoro/OpenAI compatibility) βββββββββββββββ
|
| 724 |
+
|
| 725 |
+
from pydantic import BaseModel, Field
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
class InternalChunkRequest(BaseModel):
|
| 729 |
+
stream_id: str = Field(..., min_length=1, max_length=64)
|
| 730 |
+
chunk_index: int = Field(..., ge=0)
|
| 731 |
+
text: str = Field(..., min_length=1, max_length=10000)
|
| 732 |
+
max_new_tokens: int = Field(default=Config.MAX_NEW_TOKENS, ge=64, le=2048)
|
| 733 |
+
repetition_penalty: float = Field(default=Config.REPETITION_PENALTY, ge=1.0, le=2.0)
|
| 734 |
+
output_format: str = Field(default="mp3")
|
| 735 |
+
voice_key: Optional[str] = Field(default=None, min_length=1, max_length=64)
|
| 736 |
+
|
| 737 |
+
|
| 738 |
+
class TTSJsonRequest(BaseModel):
|
| 739 |
+
text: str = Field(..., min_length=1, max_length=50000)
|
| 740 |
+
voice: str = Field(default="default")
|
| 741 |
+
speed: float = Field(default=1.0, ge=0.5, le=2.0) # reserved for future use
|
| 742 |
+
max_new_tokens: int = Field(default=Config.MAX_NEW_TOKENS, ge=64, le=2048)
|
| 743 |
+
repetition_penalty: float = Field(default=Config.REPETITION_PENALTY, ge=1.0, le=2.0)
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
@app.post("/internal/voice/register")
|
| 747 |
+
async def internal_voice_register(http_request: Request):
|
| 748 |
+
"""Register voice once for a stream; returns reusable voice_key."""
|
| 749 |
+
if Config.INTERNAL_SHARED_SECRET:
|
| 750 |
+
provided = http_request.headers.get("X-Internal-Secret", "")
|
| 751 |
+
if provided != Config.INTERNAL_SHARED_SECRET:
|
| 752 |
+
raise HTTPException(403, "Forbidden")
|
| 753 |
+
|
| 754 |
+
wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
|
| 755 |
+
if not wrapper:
|
| 756 |
+
raise HTTPException(503, "Model not loaded")
|
| 757 |
+
|
| 758 |
+
audio_bytes = await http_request.body()
|
| 759 |
+
if len(audio_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
|
| 760 |
+
raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
|
| 761 |
+
if len(audio_bytes) == 0:
|
| 762 |
+
raise HTTPException(status_code=400, detail="Empty voice file")
|
| 763 |
+
|
| 764 |
+
loop = asyncio.get_running_loop()
|
| 765 |
+
try:
|
| 766 |
+
voice = await loop.run_in_executor(
|
| 767 |
+
tts_executor, wrapper.encode_voice_from_bytes, audio_bytes
|
| 768 |
+
)
|
| 769 |
+
except Exception as e:
|
| 770 |
+
logger.error(f"[internal] voice register failed: {e}")
|
| 771 |
+
raise HTTPException(400, "Voice registration failed")
|
| 772 |
+
|
| 773 |
+
voice_key = (voice.audio_hash or "").strip()
|
| 774 |
+
if not voice_key:
|
| 775 |
+
raise HTTPException(500, "Voice key unavailable")
|
| 776 |
+
|
| 777 |
+
stream_id = (http_request.query_params.get("stream_id") or "").strip()
|
| 778 |
+
if stream_id:
|
| 779 |
+
with _internal_cancel_lock:
|
| 780 |
+
keys = _internal_stream_voice_keys.setdefault(stream_id, set())
|
| 781 |
+
keys.add(voice_key)
|
| 782 |
+
|
| 783 |
+
return {"status": "registered", "voice_key": voice_key}
|
| 784 |
+
|
| 785 |
+
|
| 786 |
+
@app.post("/internal/chunk/synthesize")
|
| 787 |
+
async def internal_chunk_synthesize(
|
| 788 |
+
request: InternalChunkRequest,
|
| 789 |
+
http_request: Request,
|
| 790 |
+
):
|
| 791 |
+
"""Internal endpoint used by primary/helper parallel routing."""
|
| 792 |
+
if Config.INTERNAL_SHARED_SECRET:
|
| 793 |
+
provided = http_request.headers.get("X-Internal-Secret", "")
|
| 794 |
+
if provided != Config.INTERNAL_SHARED_SECRET:
|
| 795 |
+
raise HTTPException(403, "Forbidden")
|
| 796 |
+
|
| 797 |
+
with _internal_cancel_lock:
|
| 798 |
+
if request.stream_id in _internal_cancelled_streams:
|
| 799 |
+
raise HTTPException(409, "Stream already cancelled")
|
| 800 |
+
|
| 801 |
+
wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
|
| 802 |
+
if not wrapper:
|
| 803 |
+
raise HTTPException(503, "Model not loaded")
|
| 804 |
+
|
| 805 |
+
voice_profile = wrapper.default_voice
|
| 806 |
+
if request.voice_key:
|
| 807 |
+
cached_voice = wrapper._voice_cache.get(request.voice_key)
|
| 808 |
+
if cached_voice is None:
|
| 809 |
+
raise HTTPException(409, "Voice key expired or not found")
|
| 810 |
+
voice_profile = cached_voice
|
| 811 |
+
|
| 812 |
+
loop = asyncio.get_running_loop()
|
| 813 |
+
try:
|
| 814 |
+
audio = await loop.run_in_executor(
|
| 815 |
+
tts_executor,
|
| 816 |
+
wrapper.generate_speech,
|
| 817 |
+
request.text,
|
| 818 |
+
voice_profile,
|
| 819 |
+
request.max_new_tokens,
|
| 820 |
+
request.repetition_penalty,
|
| 821 |
+
)
|
| 822 |
+
except Exception as e:
|
| 823 |
+
logger.error(f"[internal] chunk {request.chunk_index} failed: {e}")
|
| 824 |
+
raise HTTPException(500, "Chunk synthesis failed")
|
| 825 |
+
|
| 826 |
+
fmt = (request.output_format or "mp3").lower()
|
| 827 |
+
if fmt not in {"mp3", "wav", "flac"}:
|
| 828 |
+
fmt = "mp3"
|
| 829 |
+
data, media_type = _encode_audio(audio, fmt=fmt)
|
| 830 |
+
return Response(
|
| 831 |
+
content=data,
|
| 832 |
+
media_type=media_type,
|
| 833 |
+
headers={
|
| 834 |
+
"X-Stream-Id": request.stream_id,
|
| 835 |
+
"X-Chunk-Index": str(request.chunk_index),
|
| 836 |
+
},
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
@app.post("/internal/chunk/cancel/{stream_id}")
|
| 841 |
+
async def internal_chunk_cancel(stream_id: str, http_request: Request):
|
| 842 |
+
if Config.INTERNAL_SHARED_SECRET:
|
| 843 |
+
provided = http_request.headers.get("X-Internal-Secret", "")
|
| 844 |
+
if provided != Config.INTERNAL_SHARED_SECRET:
|
| 845 |
+
raise HTTPException(403, "Forbidden")
|
| 846 |
+
|
| 847 |
+
with _internal_cancel_lock:
|
| 848 |
+
_internal_cancelled_streams.add(stream_id)
|
| 849 |
+
_internal_stream_voice_keys.pop(stream_id, None)
|
| 850 |
+
return {"status": "cancelled", "stream_id": stream_id}
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
@app.post("/v1/audio/speech")
|
| 854 |
+
async def openai_compatible_tts(request: TTSJsonRequest):
|
| 855 |
+
"""OpenAI-compatible streaming endpoint (JSON body, no file upload).
|
| 856 |
+
|
| 857 |
+
Uses the default voice. For voice cloning, use /tts/stream with FormData.
|
| 858 |
+
"""
|
| 859 |
+
wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
|
| 860 |
+
if not wrapper:
|
| 861 |
+
raise HTTPException(503, "Model not loaded")
|
| 862 |
+
|
| 863 |
+
stream_id = uuid.uuid4().hex[:12]
|
| 864 |
+
|
| 865 |
+
return StreamingResponse(
|
| 866 |
+
_pipeline_stream_generator(
|
| 867 |
+
wrapper, request.text, wrapper.default_voice,
|
| 868 |
+
request.max_new_tokens, request.repetition_penalty, stream_id,
|
| 869 |
+
),
|
| 870 |
+
media_type="audio/mpeg",
|
| 871 |
+
headers={
|
| 872 |
+
"Transfer-Encoding": "chunked",
|
| 873 |
+
"X-Stream-Id": stream_id,
|
| 874 |
+
"Cache-Control": "no-cache",
|
| 875 |
+
},
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
|
| 879 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 880 |
+
# Stop / Cancel Endpoint
|
| 881 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 882 |
+
|
| 883 |
+
@app.post("/tts/stop/{stream_id}")
|
| 884 |
+
async def stop_stream(stream_id: str):
|
| 885 |
+
"""Stop an active TTS stream by its ID (from X-Stream-Id header).
|
| 886 |
+
|
| 887 |
+
Cancels the ONNX generation loop mid-token, freeing CPU immediately.
|
| 888 |
+
"""
|
| 889 |
+
event = _active_streams.get(stream_id)
|
| 890 |
+
if event:
|
| 891 |
+
event.set()
|
| 892 |
+
logger.info(f"Stream {stream_id} cancelled by client")
|
| 893 |
+
return {"status": "stopped", "stream_id": stream_id}
|
| 894 |
+
return {"status": "not_found", "stream_id": stream_id}
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
@app.post("/tts/stop")
|
| 898 |
+
async def stop_all_streams():
|
| 899 |
+
"""Emergency stop: cancel ALL active TTS streams."""
|
| 900 |
+
count = len(_active_streams)
|
| 901 |
+
for sid, event in list(_active_streams.items()):
|
| 902 |
+
event.set()
|
| 903 |
+
_active_streams.clear()
|
| 904 |
+
logger.info(f"Stopped all streams ({count} active)")
|
| 905 |
+
return {"status": "stopped_all", "count": count}
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 909 |
+
# Entrypoint
|
| 910 |
+
# ββββββββββββββββββββββββββββββββοΏ½οΏ½οΏ½ββββββββββββββββββββββββββββββββββ
|
| 911 |
+
|
| 912 |
+
if __name__ == "__main__":
|
| 913 |
+
import uvicorn
|
| 914 |
+
|
| 915 |
+
uvicorn.run(app, host=Config.HOST, port=Config.PORT)
|
chatterbox_wrapper.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chatterbox Turbo TTS β ONNX Inference Wrapper
|
| 3 |
+
βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 4 |
+
Orchestrates the 4-component ONNX pipeline:
|
| 5 |
+
embed_tokens β speech_encoder β language_model β conditional_decoder
|
| 6 |
+
|
| 7 |
+
Optimised for lowest-latency CPU inference on 2 vCPU:
|
| 8 |
+
β’ Sequential execution, thread count = physical cores, no spinning
|
| 9 |
+
β’ Token list pre-allocation (avoids O(nΒ²) np.concatenate in loop)
|
| 10 |
+
β’ In-memory voice caching (no disk writes for uploads)
|
| 11 |
+
β’ Robust audio loading: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM
|
| 12 |
+
β’ Sentence-level streaming for real-time playback
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
# ββ Suppress harmless transformers warnings BEFORE import βββββββββ
|
| 16 |
+
import os
|
| 17 |
+
import warnings
|
| 18 |
+
|
| 19 |
+
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
|
| 20 |
+
warnings.filterwarnings("ignore", message=".*model of type.*chatterbox.*")
|
| 21 |
+
|
| 22 |
+
import hashlib
|
| 23 |
+
import io
|
| 24 |
+
import logging
|
| 25 |
+
import subprocess
|
| 26 |
+
import tempfile
|
| 27 |
+
import time
|
| 28 |
+
from collections import OrderedDict
|
| 29 |
+
from dataclasses import dataclass
|
| 30 |
+
from typing import Callable, Generator, Optional
|
| 31 |
+
|
| 32 |
+
import librosa
|
| 33 |
+
import numpy as np
|
| 34 |
+
import onnxruntime as ort
|
| 35 |
+
import soundfile as soundfile_lib
|
| 36 |
+
from huggingface_hub import hf_hub_download
|
| 37 |
+
from transformers import AutoTokenizer
|
| 38 |
+
|
| 39 |
+
from config import Config
|
| 40 |
+
import text_processor
|
| 41 |
+
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
# ββ Supported audio MIME types for voice upload βββββββββββββββββββ
|
| 45 |
+
_SUPPORTED_AUDIO_EXTENSIONS = {
|
| 46 |
+
".wav", ".mp3", ".mpeg", ".mpga", ".m4a", ".mp4",
|
| 47 |
+
".ogg", ".oga", ".opus", ".flac", ".webm", ".aac", ".wma",
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
# Data Structures
|
| 53 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
+
|
| 55 |
+
@dataclass
|
| 56 |
+
class VoiceProfile:
|
| 57 |
+
"""Cached speaker embedding extracted from reference audio."""
|
| 58 |
+
cond_emb: np.ndarray
|
| 59 |
+
prompt_token: np.ndarray
|
| 60 |
+
speaker_embeddings: np.ndarray
|
| 61 |
+
speaker_features: np.ndarray
|
| 62 |
+
audio_hash: str = ""
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class GenerationCancelled(Exception):
|
| 66 |
+
"""Raised when inference is cancelled by the client."""
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 71 |
+
# LRU Voice Cache
|
| 72 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
+
|
| 74 |
+
class _VoiceCache:
|
| 75 |
+
"""LRU cache for VoiceProfile objects with TTL-based expiration.
|
| 76 |
+
|
| 77 |
+
Entries auto-expire after `ttl_seconds` (default: 1 hour).
|
| 78 |
+
Re-uploading the same voice file within the TTL window returns
|
| 79 |
+
the cached profile instantly β no re-encoding needed.
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
def __init__(self, maxsize: int, ttl_seconds: int = 3600):
|
| 83 |
+
self._cache: OrderedDict[str, tuple[VoiceProfile, float]] = OrderedDict()
|
| 84 |
+
self._maxsize = maxsize
|
| 85 |
+
self._ttl = ttl_seconds
|
| 86 |
+
|
| 87 |
+
def _evict_expired(self):
|
| 88 |
+
"""Remove all entries older than TTL."""
|
| 89 |
+
now = time.time()
|
| 90 |
+
expired = [k for k, (_, ts) in self._cache.items() if now - ts > self._ttl]
|
| 91 |
+
for k in expired:
|
| 92 |
+
del self._cache[k]
|
| 93 |
+
logger.debug(f"Voice cache expired: {k[:8]}β¦")
|
| 94 |
+
|
| 95 |
+
def get(self, key: str) -> Optional[VoiceProfile]:
|
| 96 |
+
self._evict_expired()
|
| 97 |
+
if key in self._cache:
|
| 98 |
+
profile, ts = self._cache[key]
|
| 99 |
+
remaining = self._ttl - (time.time() - ts)
|
| 100 |
+
self._cache.move_to_end(key)
|
| 101 |
+
logger.info(f"Voice cache HIT: {key[:8]}β¦ (expires in {remaining:.0f}s)")
|
| 102 |
+
return profile
|
| 103 |
+
return None
|
| 104 |
+
|
| 105 |
+
def put(self, key: str, profile: VoiceProfile):
|
| 106 |
+
self._evict_expired()
|
| 107 |
+
if key in self._cache:
|
| 108 |
+
self._cache.move_to_end(key)
|
| 109 |
+
else:
|
| 110 |
+
if len(self._cache) >= self._maxsize:
|
| 111 |
+
evicted_key, _ = self._cache.popitem(last=False)
|
| 112 |
+
logger.debug(f"Voice cache evicted (LRU): {evicted_key[:8]}β¦")
|
| 113 |
+
self._cache[key] = (profile, time.time())
|
| 114 |
+
logger.info(f"Voice cache STORED: {key[:8]}β¦ (TTL: {self._ttl}s, size: {len(self._cache)})")
|
| 115 |
+
|
| 116 |
+
@property
|
| 117 |
+
def size(self) -> int:
|
| 118 |
+
return len(self._cache)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 122 |
+
# Audio Loading (robust multi-format support)
|
| 123 |
+
# ββββββββββββββββοΏ½οΏ½οΏ½ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 124 |
+
|
| 125 |
+
def _load_audio_bytes(audio_bytes: bytes, sr: int = 24000) -> np.ndarray:
|
| 126 |
+
"""Load audio from raw bytes, supporting WAV/MP3/MPEG/M4A/OGG/FLAC/WebM.
|
| 127 |
+
|
| 128 |
+
Strategy: try soundfile (fast, native) β librosa (ffmpeg backend) β ffmpeg CLI.
|
| 129 |
+
"""
|
| 130 |
+
buf = io.BytesIO(audio_bytes)
|
| 131 |
+
|
| 132 |
+
# 1) Try soundfile (handles WAV, FLAC, OGG natively β fastest)
|
| 133 |
+
try:
|
| 134 |
+
audio, file_sr = soundfile_lib.read(buf)
|
| 135 |
+
if audio.ndim > 1:
|
| 136 |
+
audio = audio.mean(axis=1) # stereo β mono
|
| 137 |
+
if file_sr != sr:
|
| 138 |
+
audio = librosa.resample(audio.astype(np.float32), orig_sr=file_sr, target_sr=sr)
|
| 139 |
+
return audio.astype(np.float32)
|
| 140 |
+
except Exception:
|
| 141 |
+
buf.seek(0)
|
| 142 |
+
|
| 143 |
+
# 2) Try librosa (handles MP3 via audioread + ffmpeg backend)
|
| 144 |
+
try:
|
| 145 |
+
audio, _ = librosa.load(buf, sr=sr, mono=True)
|
| 146 |
+
return audio.astype(np.float32)
|
| 147 |
+
except Exception:
|
| 148 |
+
buf.seek(0)
|
| 149 |
+
|
| 150 |
+
# 3) Fallback: use ffmpeg CLI to convert to WAV in memory
|
| 151 |
+
try:
|
| 152 |
+
proc = subprocess.run(
|
| 153 |
+
["ffmpeg", "-i", "pipe:0", "-f", "wav", "-ac", "1", "-ar", str(sr), "pipe:1"],
|
| 154 |
+
input=audio_bytes, capture_output=True, timeout=30,
|
| 155 |
+
)
|
| 156 |
+
if proc.returncode == 0 and len(proc.stdout) > 44:
|
| 157 |
+
wav_buf = io.BytesIO(proc.stdout)
|
| 158 |
+
audio, _ = soundfile_lib.read(wav_buf)
|
| 159 |
+
return audio.astype(np.float32)
|
| 160 |
+
except Exception:
|
| 161 |
+
pass
|
| 162 |
+
|
| 163 |
+
raise ValueError(
|
| 164 |
+
"Could not decode audio file. Supported formats: "
|
| 165 |
+
"WAV, MP3, MPEG, M4A, OGG, FLAC, WebM, AAC. "
|
| 166 |
+
"Please upload a valid audio file."
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 171 |
+
# Main Wrapper
|
| 172 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 173 |
+
|
| 174 |
+
class ChatterboxWrapper:
|
| 175 |
+
|
| 176 |
+
def __init__(self, download_only: bool = False):
|
| 177 |
+
self.cfg = Config
|
| 178 |
+
os.makedirs(self.cfg.MODELS_DIR, exist_ok=True)
|
| 179 |
+
|
| 180 |
+
logger.info(f"Downloading ONNX models (dtype={self.cfg.MODEL_DTYPE}) β¦")
|
| 181 |
+
self._model_paths = self._download_models()
|
| 182 |
+
|
| 183 |
+
if download_only:
|
| 184 |
+
return
|
| 185 |
+
|
| 186 |
+
logger.info(
|
| 187 |
+
f"Creating ONNX Runtime sessions "
|
| 188 |
+
f"(intra_op_threads={self.cfg.CPU_THREADS}, workers={self.cfg.MAX_WORKERS}) β¦"
|
| 189 |
+
)
|
| 190 |
+
opts = self._make_session_options()
|
| 191 |
+
providers = ["CPUExecutionProvider"]
|
| 192 |
+
|
| 193 |
+
self.embed_session = ort.InferenceSession(self._model_paths["embed_tokens"], sess_options=opts, providers=providers)
|
| 194 |
+
self.encoder_session = ort.InferenceSession(self._model_paths["speech_encoder"], sess_options=opts, providers=providers)
|
| 195 |
+
self.lm_session = ort.InferenceSession(self._model_paths["language_model"], sess_options=opts, providers=providers)
|
| 196 |
+
self.decoder_session = ort.InferenceSession(self._model_paths["conditional_decoder"], sess_options=opts, providers=providers)
|
| 197 |
+
|
| 198 |
+
logger.info("Loading tokenizer β¦")
|
| 199 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.MODEL_ID)
|
| 200 |
+
|
| 201 |
+
self._voice_cache = _VoiceCache(
|
| 202 |
+
maxsize=self.cfg.VOICE_CACHE_SIZE,
|
| 203 |
+
ttl_seconds=self.cfg.VOICE_CACHE_TTL_SEC,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
logger.info("Encoding default reference voice β¦")
|
| 207 |
+
self.default_voice = self._load_default_voice()
|
| 208 |
+
|
| 209 |
+
logger.info("β
ChatterboxWrapper ready")
|
| 210 |
+
|
| 211 |
+
# βββ Model download ββββββββββββββββββββββββββββββββββββββββββ
|
| 212 |
+
|
| 213 |
+
def _download_models(self) -> dict:
|
| 214 |
+
"""Download all 4 ONNX components + weight files from HuggingFace."""
|
| 215 |
+
components = ("conditional_decoder", "speech_encoder", "embed_tokens", "language_model")
|
| 216 |
+
paths = {}
|
| 217 |
+
for name in components:
|
| 218 |
+
paths[name] = self._download_component(name, self.cfg.MODEL_DTYPE)
|
| 219 |
+
return paths
|
| 220 |
+
|
| 221 |
+
def _download_component(self, name: str, dtype: str) -> str:
|
| 222 |
+
if dtype == "fp32":
|
| 223 |
+
filename = f"{name}.onnx"
|
| 224 |
+
elif dtype == "q8":
|
| 225 |
+
filename = f"{name}_quantized.onnx"
|
| 226 |
+
else:
|
| 227 |
+
filename = f"{name}_{dtype}.onnx"
|
| 228 |
+
|
| 229 |
+
graph = hf_hub_download(
|
| 230 |
+
self.cfg.MODEL_ID, subfolder="onnx", filename=filename,
|
| 231 |
+
cache_dir=self.cfg.MODELS_DIR,
|
| 232 |
+
)
|
| 233 |
+
# Download companion weight file
|
| 234 |
+
try:
|
| 235 |
+
hf_hub_download(
|
| 236 |
+
self.cfg.MODEL_ID, subfolder="onnx", filename=f"{filename}_data",
|
| 237 |
+
cache_dir=self.cfg.MODELS_DIR,
|
| 238 |
+
)
|
| 239 |
+
except Exception:
|
| 240 |
+
pass # Some quantized variants embed weights in-graph
|
| 241 |
+
return graph
|
| 242 |
+
|
| 243 |
+
# βββ Session configuration (optimised for 2 vCPU) βββββββββββββ
|
| 244 |
+
|
| 245 |
+
def _make_session_options(self) -> ort.SessionOptions:
|
| 246 |
+
opts = ort.SessionOptions()
|
| 247 |
+
# Sequential execution: no parallel graph scheduling overhead
|
| 248 |
+
opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
| 249 |
+
# Match physical cores exactly (2 for HF Space free tier)
|
| 250 |
+
opts.intra_op_num_threads = self.cfg.CPU_THREADS
|
| 251 |
+
opts.inter_op_num_threads = 1
|
| 252 |
+
# Full graph optimisations (constant folding, fusion, etc.)
|
| 253 |
+
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
| 254 |
+
# Disable thread spinning β wastes CPU cycles on busy-wait
|
| 255 |
+
opts.add_session_config_entry("session.intra_op.allow_spinning", "0")
|
| 256 |
+
opts.add_session_config_entry("session.inter_op.allow_spinning", "0")
|
| 257 |
+
# Enable memory optimisations
|
| 258 |
+
opts.enable_cpu_mem_arena = True
|
| 259 |
+
opts.enable_mem_pattern = True
|
| 260 |
+
opts.enable_mem_reuse = True
|
| 261 |
+
return opts
|
| 262 |
+
|
| 263 |
+
# βββ Default voice ββββββββββββββββββββββββββββββββββββββββββββ
|
| 264 |
+
|
| 265 |
+
def _load_default_voice(self) -> VoiceProfile:
|
| 266 |
+
path = hf_hub_download(
|
| 267 |
+
self.cfg.DEFAULT_VOICE_REPO,
|
| 268 |
+
filename=self.cfg.DEFAULT_VOICE_FILE,
|
| 269 |
+
cache_dir=self.cfg.MODELS_DIR,
|
| 270 |
+
)
|
| 271 |
+
audio, _ = librosa.load(path, sr=self.cfg.SAMPLE_RATE)
|
| 272 |
+
return self._encode_audio_array(audio, audio_hash="__default__")
|
| 273 |
+
|
| 274 |
+
# βββ Voice encoding ββββββββββββββββββββββββββββββββββββββββββ
|
| 275 |
+
|
| 276 |
+
def encode_voice_from_bytes(self, audio_bytes: bytes) -> VoiceProfile:
|
| 277 |
+
"""Encode reference audio from raw bytes (in-memory, no disk write).
|
| 278 |
+
|
| 279 |
+
Accepts: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM, AAC, WMA, Opus.
|
| 280 |
+
"""
|
| 281 |
+
audio_hash = hashlib.md5(audio_bytes).hexdigest()
|
| 282 |
+
cached = self._voice_cache.get(audio_hash)
|
| 283 |
+
if cached is not None:
|
| 284 |
+
logger.info(f"Voice cache hit: {audio_hash[:8]}β¦")
|
| 285 |
+
return cached
|
| 286 |
+
|
| 287 |
+
# Robust multi-format audio loading
|
| 288 |
+
audio = _load_audio_bytes(audio_bytes, sr=self.cfg.SAMPLE_RATE)
|
| 289 |
+
|
| 290 |
+
# Validate duration
|
| 291 |
+
duration = len(audio) / self.cfg.SAMPLE_RATE
|
| 292 |
+
if duration < self.cfg.MIN_REF_DURATION_SEC:
|
| 293 |
+
raise ValueError(
|
| 294 |
+
f"Reference audio too short ({duration:.1f}s). "
|
| 295 |
+
f"Minimum: {self.cfg.MIN_REF_DURATION_SEC}s"
|
| 296 |
+
)
|
| 297 |
+
if duration > self.cfg.MAX_REF_DURATION_SEC:
|
| 298 |
+
audio = audio[: int(self.cfg.MAX_REF_DURATION_SEC * self.cfg.SAMPLE_RATE)]
|
| 299 |
+
|
| 300 |
+
profile = self._encode_audio_array(audio, audio_hash=audio_hash)
|
| 301 |
+
self._voice_cache.put(audio_hash, profile)
|
| 302 |
+
return profile
|
| 303 |
+
|
| 304 |
+
def _encode_audio_array(self, audio: np.ndarray, audio_hash: str = "") -> VoiceProfile:
|
| 305 |
+
"""Run speech_encoder on a float32 mono audio array."""
|
| 306 |
+
audio_input = audio[np.newaxis, :].astype(np.float32)
|
| 307 |
+
cond_emb, prompt_token, speaker_emb, speaker_feat = self.encoder_session.run(
|
| 308 |
+
None, {"audio_values": audio_input}
|
| 309 |
+
)
|
| 310 |
+
return VoiceProfile(
|
| 311 |
+
cond_emb=cond_emb,
|
| 312 |
+
prompt_token=prompt_token,
|
| 313 |
+
speaker_embeddings=speaker_emb,
|
| 314 |
+
speaker_features=speaker_feat,
|
| 315 |
+
audio_hash=audio_hash,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
# βββ Full generation (non-streaming) ββββββββββββββββββββββββββ
|
| 319 |
+
|
| 320 |
+
def generate_speech(
|
| 321 |
+
self,
|
| 322 |
+
text: str,
|
| 323 |
+
voice: Optional[VoiceProfile] = None,
|
| 324 |
+
max_new_tokens: Optional[int] = None,
|
| 325 |
+
repetition_penalty: Optional[float] = None,
|
| 326 |
+
) -> np.ndarray:
|
| 327 |
+
"""Generate complete audio for the given text."""
|
| 328 |
+
voice = voice or self.default_voice
|
| 329 |
+
text = text_processor.sanitize(text.strip()[: self.cfg.MAX_TEXT_LENGTH])
|
| 330 |
+
if not text:
|
| 331 |
+
raise ValueError("Text is empty after sanitization")
|
| 332 |
+
|
| 333 |
+
tokens = self._generate_tokens(
|
| 334 |
+
text, voice,
|
| 335 |
+
max_new_tokens or self.cfg.MAX_NEW_TOKENS,
|
| 336 |
+
repetition_penalty or self.cfg.REPETITION_PENALTY,
|
| 337 |
+
)
|
| 338 |
+
return self._decode_tokens(tokens, voice)
|
| 339 |
+
|
| 340 |
+
# βββ Streaming generation βββββββββββββββββββββββββββββββββββββ
|
| 341 |
+
|
| 342 |
+
def stream_speech(
|
| 343 |
+
self,
|
| 344 |
+
text: str,
|
| 345 |
+
voice: Optional[VoiceProfile] = None,
|
| 346 |
+
max_new_tokens: Optional[int] = None,
|
| 347 |
+
repetition_penalty: Optional[float] = None,
|
| 348 |
+
is_cancelled: Optional[Callable[[], bool]] = None,
|
| 349 |
+
) -> Generator[np.ndarray, None, None]:
|
| 350 |
+
"""Yield audio chunks sentence-by-sentence for real-time streaming.
|
| 351 |
+
|
| 352 |
+
Each sentence is independently processed through the full pipeline
|
| 353 |
+
so the first chunk arrives as fast as possible (low TTFB).
|
| 354 |
+
|
| 355 |
+
Args:
|
| 356 |
+
is_cancelled: Optional callable that returns True to abort generation.
|
| 357 |
+
Checked between chunks and every 25 autoregressive steps.
|
| 358 |
+
"""
|
| 359 |
+
voice = voice or self.default_voice
|
| 360 |
+
text = text_processor.sanitize(text.strip()[: self.cfg.MAX_TEXT_LENGTH])
|
| 361 |
+
if not text:
|
| 362 |
+
return
|
| 363 |
+
|
| 364 |
+
sentences = text_processor.split_for_streaming(text)
|
| 365 |
+
_max = max_new_tokens or self.cfg.MAX_NEW_TOKENS
|
| 366 |
+
_rep = repetition_penalty or self.cfg.REPETITION_PENALTY
|
| 367 |
+
_check = is_cancelled or (lambda: False)
|
| 368 |
+
|
| 369 |
+
for i, sentence in enumerate(sentences):
|
| 370 |
+
# Check cancellation between chunks
|
| 371 |
+
if _check():
|
| 372 |
+
logger.info("Generation cancelled by client (between chunks)")
|
| 373 |
+
return
|
| 374 |
+
if not sentence.strip():
|
| 375 |
+
continue
|
| 376 |
+
t0 = time.perf_counter()
|
| 377 |
+
try:
|
| 378 |
+
tokens = self._generate_tokens(sentence, voice, _max, _rep, _check)
|
| 379 |
+
if _check():
|
| 380 |
+
return
|
| 381 |
+
audio = self._decode_tokens(tokens, voice)
|
| 382 |
+
elapsed = time.perf_counter() - t0
|
| 383 |
+
audio_duration = len(audio) / self.cfg.SAMPLE_RATE
|
| 384 |
+
rtf = elapsed / audio_duration if audio_duration > 0 else 0
|
| 385 |
+
logger.info(
|
| 386 |
+
f"Chunk {i + 1}/{len(sentences)}: "
|
| 387 |
+
f"{len(sentence)} chars β {audio_duration:.1f}s audio "
|
| 388 |
+
f"in {elapsed:.2f}s (RTF: {rtf:.2f}x)"
|
| 389 |
+
)
|
| 390 |
+
yield audio
|
| 391 |
+
except GenerationCancelled:
|
| 392 |
+
logger.info(f"Generation cancelled mid-token at chunk {i + 1}")
|
| 393 |
+
return
|
| 394 |
+
except Exception as e:
|
| 395 |
+
logger.error(f"Error on chunk {i + 1}: {e}")
|
| 396 |
+
raise
|
| 397 |
+
|
| 398 |
+
# βββ Autoregressive token generation (OPTIMISED) ββββββββββββββ
|
| 399 |
+
|
| 400 |
+
def _generate_tokens(
|
| 401 |
+
self,
|
| 402 |
+
text: str,
|
| 403 |
+
voice: VoiceProfile,
|
| 404 |
+
max_new_tokens: int,
|
| 405 |
+
repetition_penalty: float,
|
| 406 |
+
is_cancelled: Callable[[], bool] = lambda: False,
|
| 407 |
+
) -> np.ndarray:
|
| 408 |
+
"""Run embed β LM autoregressive loop. Returns raw token array.
|
| 409 |
+
|
| 410 |
+
Optimisations:
|
| 411 |
+
β’ Token list instead of repeated np.concatenate (O(n) β O(1) append)
|
| 412 |
+
β’ Unique tokens set for inline repetition penalty (avoids exponential penalty bug)
|
| 413 |
+
β’ Pre-allocated attention mask for zero-copy slicing
|
| 414 |
+
β’ Correct dimensional routing for step 0 prompt processing
|
| 415 |
+
"""
|
| 416 |
+
input_ids = self.tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)
|
| 417 |
+
|
| 418 |
+
# Pre-allocate collections
|
| 419 |
+
token_list: list[int] = [self.cfg.START_SPEECH_TOKEN]
|
| 420 |
+
unique_tokens: set[int] = {self.cfg.START_SPEECH_TOKEN}
|
| 421 |
+
penalty = repetition_penalty
|
| 422 |
+
|
| 423 |
+
past_key_values = None
|
| 424 |
+
attention_mask_full = None
|
| 425 |
+
seq_len = 0
|
| 426 |
+
|
| 427 |
+
for step in range(max_new_tokens):
|
| 428 |
+
if step > 0 and step % 25 == 0 and is_cancelled():
|
| 429 |
+
raise GenerationCancelled()
|
| 430 |
+
|
| 431 |
+
embeds = self.embed_session.run(None, {"input_ids": input_ids})[0]
|
| 432 |
+
|
| 433 |
+
if step == 0:
|
| 434 |
+
# Prepend speaker conditioning
|
| 435 |
+
embeds = np.concatenate((voice.cond_emb, embeds), axis=1)
|
| 436 |
+
batch, seq_len, _ = embeds.shape
|
| 437 |
+
|
| 438 |
+
past_key_values = {
|
| 439 |
+
inp.name: np.zeros(
|
| 440 |
+
[batch, self.cfg.NUM_KV_HEADS, 0, self.cfg.HEAD_DIM],
|
| 441 |
+
dtype=np.float16 if inp.type == "tensor(float16)" else np.float32,
|
| 442 |
+
)
|
| 443 |
+
for inp in self.lm_session.get_inputs()
|
| 444 |
+
if "past_key_values" in inp.name
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
# Pre-allocate full attention mask
|
| 448 |
+
attention_mask_full = np.ones((batch, seq_len + max_new_tokens), dtype=np.int64)
|
| 449 |
+
attention_mask = attention_mask_full[:, :seq_len]
|
| 450 |
+
|
| 451 |
+
# Step 0 requires position_ids matching prompt sequence length
|
| 452 |
+
position_ids = np.arange(seq_len, dtype=np.int64).reshape(batch, -1)
|
| 453 |
+
else:
|
| 454 |
+
# O(1) zero-copy slice for subsequent steps
|
| 455 |
+
attention_mask = attention_mask_full[:, : seq_len + step]
|
| 456 |
+
# Single position ID for the single new token
|
| 457 |
+
position_ids = np.array([[seq_len + step - 1]], dtype=np.int64)
|
| 458 |
+
|
| 459 |
+
# Language model forward pass
|
| 460 |
+
logits, *present_kv = self.lm_session.run(
|
| 461 |
+
None,
|
| 462 |
+
dict(
|
| 463 |
+
inputs_embeds=embeds,
|
| 464 |
+
attention_mask=attention_mask,
|
| 465 |
+
position_ids=position_ids,
|
| 466 |
+
**past_key_values,
|
| 467 |
+
),
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
# ββ Inline repetition penalty + token selection βββββββ
|
| 471 |
+
last_logits = logits[0, -1, :].copy() # shape: (vocab_size,)
|
| 472 |
+
|
| 473 |
+
# Apply repetition penalty strictly to unique tokens to prevent over-penalization
|
| 474 |
+
for tok_id in unique_tokens:
|
| 475 |
+
if last_logits[tok_id] < 0:
|
| 476 |
+
last_logits[tok_id] *= penalty
|
| 477 |
+
else:
|
| 478 |
+
last_logits[tok_id] /= penalty
|
| 479 |
+
|
| 480 |
+
next_token = int(np.argmax(last_logits))
|
| 481 |
+
token_list.append(next_token)
|
| 482 |
+
unique_tokens.add(next_token)
|
| 483 |
+
|
| 484 |
+
if next_token == self.cfg.STOP_SPEECH_TOKEN:
|
| 485 |
+
break
|
| 486 |
+
|
| 487 |
+
# Update state for next step
|
| 488 |
+
input_ids = np.array([[next_token]], dtype=np.int64)
|
| 489 |
+
for j, key in enumerate(past_key_values):
|
| 490 |
+
past_key_values[key] = present_kv[j]
|
| 491 |
+
|
| 492 |
+
return np.array([token_list], dtype=np.int64)
|
| 493 |
+
|
| 494 |
+
# βββ Token β audio decoding βββββββββββββββββββββββββββββββββββ
|
| 495 |
+
|
| 496 |
+
def _decode_tokens(self, generated: np.ndarray, voice: VoiceProfile) -> np.ndarray:
|
| 497 |
+
"""Decode speech tokens to a float32 waveform at 24 kHz."""
|
| 498 |
+
# Strip START token; strip STOP token if present
|
| 499 |
+
tokens = generated[:, 1:]
|
| 500 |
+
if tokens.shape[1] > 0 and tokens[0, -1] == self.cfg.STOP_SPEECH_TOKEN:
|
| 501 |
+
tokens = tokens[:, :-1]
|
| 502 |
+
|
| 503 |
+
if tokens.shape[1] == 0:
|
| 504 |
+
return np.zeros(0, dtype=np.float32)
|
| 505 |
+
|
| 506 |
+
# Prepend prompt token + append silence
|
| 507 |
+
silence = np.full(
|
| 508 |
+
(tokens.shape[0], 3), self.cfg.SILENCE_TOKEN, dtype=np.int64
|
| 509 |
+
)
|
| 510 |
+
full_tokens = np.concatenate(
|
| 511 |
+
[voice.prompt_token, tokens, silence], axis=1
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
wav = self.decoder_session.run(
|
| 515 |
+
None,
|
| 516 |
+
{
|
| 517 |
+
"speech_tokens": full_tokens,
|
| 518 |
+
"speaker_embeddings": voice.speaker_embeddings,
|
| 519 |
+
"speaker_features": voice.speaker_features,
|
| 520 |
+
},
|
| 521 |
+
)[0].squeeze(axis=0)
|
| 522 |
+
|
| 523 |
+
return wav
|
| 524 |
+
|
| 525 |
+
# βββ Warmup βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 526 |
+
|
| 527 |
+
def warmup(self):
|
| 528 |
+
"""Run a short inference to warm up ONNX sessions and JIT paths."""
|
| 529 |
+
try:
|
| 530 |
+
t0 = time.perf_counter()
|
| 531 |
+
_ = self.generate_speech("Hello.", self.default_voice, max_new_tokens=32)
|
| 532 |
+
logger.info(f"Warmup done in {time.perf_counter() - t0:.2f}s")
|
| 533 |
+
except Exception as e:
|
| 534 |
+
logger.warning(f"Warmup failed (non-critical): {e}")
|
config.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chatterbox Turbo TTS β Centralized Configuration
|
| 3 |
+
βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 4 |
+
Optimised for HF Space free tier (2 vCPU).
|
| 5 |
+
Adjust MODEL_DTYPE to switch quantization (q8/q4/fp16/fp32).
|
| 6 |
+
All settings overridable via environment variables prefixed CB_.
|
| 7 |
+
"""
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
_HERE = os.path.dirname(os.path.abspath(__file__))
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _get_bool(name: str, default: bool) -> bool:
|
| 14 |
+
raw = os.getenv(name)
|
| 15 |
+
if raw is None:
|
| 16 |
+
return default
|
| 17 |
+
return raw.strip().lower() in {"1", "true", "yes", "on"}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Config:
|
| 21 |
+
# ββ Model ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 22 |
+
MODEL_ID: str = os.getenv("CB_MODEL_ID", "ResembleAI/chatterbox-turbo-ONNX")
|
| 23 |
+
|
| 24 |
+
# fp32 β highest quality, ~1.4 GB, slowest
|
| 25 |
+
# fp16 β good quality, ~0.7 GB
|
| 26 |
+
# q8 β β
recommended, ~0.35 GB, best balance
|
| 27 |
+
# q4 β smallest, ~0.17 GB, fastest, slight loss
|
| 28 |
+
# q4f16 β q4 weights + fp16 activations
|
| 29 |
+
MODEL_DTYPE: str = os.getenv("CB_MODEL_DTYPE", "q4")
|
| 30 |
+
|
| 31 |
+
MODELS_DIR: str = os.getenv("CB_MODELS_DIR", os.path.join(_HERE, "models"))
|
| 32 |
+
|
| 33 |
+
# ββ ONNX Runtime CPU tuning (optimised for 2 vCPU) βββββββββββ
|
| 34 |
+
#
|
| 35 |
+
# KEY RULE: intra_op threads MUST match physical cores.
|
| 36 |
+
# β 4 threads on 2 cores = oversubscription = SLOWER.
|
| 37 |
+
# β 2 threads on 2 cores = each op uses both cores perfectly.
|
| 38 |
+
#
|
| 39 |
+
# MAX_WORKERS = 1 ensures ONE inference gets both cores.
|
| 40 |
+
# β 2 workers would split 2 cores = both requests slow.
|
| 41 |
+
#
|
| 42 |
+
CPU_THREADS: int = int(os.getenv("CB_CPU_THREADS", "2"))
|
| 43 |
+
MAX_WORKERS: int = int(os.getenv("CB_MAX_WORKERS", "1"))
|
| 44 |
+
|
| 45 |
+
# ββ Generation defaults ββββββββββββββββββββββββββββββββββββββ
|
| 46 |
+
SAMPLE_RATE: int = 24000
|
| 47 |
+
MAX_NEW_TOKENS: int = int(os.getenv("CB_MAX_NEW_TOKENS", "768"))
|
| 48 |
+
REPETITION_PENALTY: float = float(os.getenv("CB_REPETITION_PENALTY", "1.2"))
|
| 49 |
+
MAX_TEXT_LENGTH: int = int(os.getenv("CB_MAX_TEXT_LENGTH", "50000"))
|
| 50 |
+
|
| 51 |
+
# ββ Model constants (official card β do not change) ββββββββββ
|
| 52 |
+
START_SPEECH_TOKEN: int = 6561
|
| 53 |
+
STOP_SPEECH_TOKEN: int = 6562
|
| 54 |
+
SILENCE_TOKEN: int = 4299
|
| 55 |
+
NUM_KV_HEADS: int = 16
|
| 56 |
+
HEAD_DIM: int = 64
|
| 57 |
+
|
| 58 |
+
# ββ Paralinguistic tags (Turbo native) βββββββββββββββββββββββ
|
| 59 |
+
PARALINGUISTIC_TAGS: tuple = (
|
| 60 |
+
"laugh", "chuckle", "cough", "sigh", "gasp",
|
| 61 |
+
"shush", "groan", "sniff", "clear throat",
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# ββ Voice / reference audio ββββββββββββββββββββββββββββββββββ
|
| 65 |
+
# NOTE: Official ResembleAI/chatterbox-turbo-ONNX has no bundled voice.
|
| 66 |
+
# The default_voice.wav is a plain audio sample from community repo
|
| 67 |
+
# (not a model β just a reference WAV, safe to use from any source).
|
| 68 |
+
DEFAULT_VOICE_REPO: str = "onnx-community/chatterbox-ONNX"
|
| 69 |
+
DEFAULT_VOICE_FILE: str = "default_voice.wav"
|
| 70 |
+
MAX_VOICE_UPLOAD_BYTES: int = 10 * 1024 * 1024 # 10 MB
|
| 71 |
+
MIN_REF_DURATION_SEC: float = 1.5
|
| 72 |
+
MAX_REF_DURATION_SEC: float = 30.0
|
| 73 |
+
VOICE_CACHE_SIZE: int = int(os.getenv("CB_VOICE_CACHE_SIZE", "20"))
|
| 74 |
+
VOICE_CACHE_TTL_SEC: int = int(os.getenv("CB_VOICE_CACHE_TTL", "3600")) # 1 hour
|
| 75 |
+
|
| 76 |
+
# ββ Streaming ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
+
# Smaller chunks = faster TTFB (first audio arrives sooner)
|
| 78 |
+
# ~200 chars β 1β2 sentences β fastest first-chunk on 2 vCPU
|
| 79 |
+
MAX_CHUNK_CHARS: int = int(os.getenv("CB_MAX_CHUNK_CHARS", "100"))
|
| 80 |
+
# Additive parallel mode (odd/even split across primary/helper).
|
| 81 |
+
ENABLE_PARALLEL_MODE: bool = _get_bool("CB_ENABLE_PARALLEL_MODE", True)
|
| 82 |
+
HELPER_BASE_URL: str = os.getenv("CB_HELPER_BASE_URL", "https://shadowhunter222-hello2.hf.space").strip()
|
| 83 |
+
HELPER_TIMEOUT_SEC: float = float(os.getenv("CB_HELPER_TIMEOUT_SEC", "45"))
|
| 84 |
+
HELPER_RETRY_ONCE: bool = _get_bool("CB_HELPER_RETRY_ONCE", True)
|
| 85 |
+
# Optional shared secret for internal chunk endpoints.
|
| 86 |
+
INTERNAL_SHARED_SECRET: str = os.getenv("CB_INTERNAL_SHARED_SECRET", "").strip()
|
| 87 |
+
|
| 88 |
+
# ββ Server βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 89 |
+
HOST: str = os.getenv("CB_HOST", "0.0.0.0")
|
| 90 |
+
PORT: int = int(os.getenv("CB_PORT", "7860"))
|
| 91 |
+
|
| 92 |
+
ALLOWED_ORIGINS: list = [
|
| 93 |
+
"https://toolboxesai.com",
|
| 94 |
+
"http://localhost:8788", "http://127.0.0.1:8788",
|
| 95 |
+
"http://localhost:5502", "http://127.0.0.1:5502",
|
| 96 |
+
"http://localhost:5501", "http://127.0.0.1:5501",
|
| 97 |
+
"http://localhost:5500", "http://127.0.0.1:5500",
|
| 98 |
+
"http://localhost:5173", "http://127.0.0.1:5173",
|
| 99 |
+
"http://localhost:7860", "http://127.0.0.1:7860",
|
| 100 |
+
]
|
requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =========================================================
|
| 2 |
+
# Chatterbox Turbo TTS - Dependencies (CPU-only)
|
| 3 |
+
# =========================================================
|
| 4 |
+
|
| 5 |
+
# PyTorch CPU (required by transformers tokenizer internals)
|
| 6 |
+
torch --index-url https://download.pytorch.org/whl/cpu
|
| 7 |
+
|
| 8 |
+
# Core API
|
| 9 |
+
fastapi>=0.104.1
|
| 10 |
+
uvicorn[standard]>=0.24.0
|
| 11 |
+
pydantic>=2.5.0
|
| 12 |
+
python-multipart>=0.0.6
|
| 13 |
+
|
| 14 |
+
# ONNX Runtime (CPU inference)
|
| 15 |
+
onnxruntime>=1.17.0
|
| 16 |
+
|
| 17 |
+
# Audio processing
|
| 18 |
+
numpy>=1.24.0
|
| 19 |
+
librosa>=0.10.0
|
| 20 |
+
soundfile>=0.12.0
|
| 21 |
+
|
| 22 |
+
# Tokenizer + model download
|
| 23 |
+
transformers>=4.46.0
|
| 24 |
+
huggingface-hub>=0.19.0
|
text_processor.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chatterbox Turbo TTS β Text Processor
|
| 3 |
+
βββββββββββββββββββββββββββββββββββββββ
|
| 4 |
+
Sanitizes raw input text and splits it into sentence-level chunks
|
| 5 |
+
for streaming TTS. Paralinguistic tags ([laugh], [cough], β¦) are
|
| 6 |
+
explicitly preserved so the model can render them.
|
| 7 |
+
"""
|
| 8 |
+
import re
|
| 9 |
+
from typing import List
|
| 10 |
+
|
| 11 |
+
from config import Config
|
| 12 |
+
|
| 13 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 14 |
+
# Pre-compiled regex patterns (compiled once at import β zero cost)
|
| 15 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
+
|
| 17 |
+
# β Paralinguistic tag protector (matches [laugh], [clear throat], etc.)
|
| 18 |
+
_TAG_NAMES = "|".join(re.escape(t) for t in Config.PARALINGUISTIC_TAGS)
|
| 19 |
+
_RE_PARA_TAG = re.compile(rf"\[(?:{_TAG_NAMES})\]", re.IGNORECASE)
|
| 20 |
+
|
| 21 |
+
# β Markdown / structural noise
|
| 22 |
+
_RE_CODE_BLOCK = re.compile(r"```[\s\S]*?```")
|
| 23 |
+
_RE_INLINE_CODE = re.compile(r"`([^`]+)`")
|
| 24 |
+
_RE_IMAGE = re.compile(r"!\[([^\]]*)\]\([^)]+\)")
|
| 25 |
+
_RE_LINK = re.compile(r"\[([^\]]+)\]\([^)]+\)")
|
| 26 |
+
_RE_BOLD_AST = re.compile(r"\*\*(.+?)\*\*")
|
| 27 |
+
_RE_BOLD_UND = re.compile(r"__(.+?)__")
|
| 28 |
+
_RE_STRIKE = re.compile(r"~~(.+?)~~")
|
| 29 |
+
_RE_ITALIC_AST = re.compile(r"\*(.+?)\*")
|
| 30 |
+
_RE_ITALIC_UND = re.compile(r"(?<!\w)_(.+?)_(?!\w)")
|
| 31 |
+
_RE_HEADER = re.compile(r"^#{1,6}\s+", re.MULTILINE)
|
| 32 |
+
_RE_BLOCKQUOTE = re.compile(r"^>+\s?", re.MULTILINE)
|
| 33 |
+
_RE_HR = re.compile(r"^[-*_]{3,}$", re.MULTILINE)
|
| 34 |
+
_RE_BULLET = re.compile(r"^\s*[-*+]\s+", re.MULTILINE)
|
| 35 |
+
_RE_ORDERED = re.compile(r"^\s*\d+\.\s+", re.MULTILINE)
|
| 36 |
+
|
| 37 |
+
# β URLs, emojis, HTML entities
|
| 38 |
+
_RE_URL = re.compile(r"https?://\S+")
|
| 39 |
+
_RE_EMOJI = re.compile(
|
| 40 |
+
r"["
|
| 41 |
+
r"\U0001F600-\U0001F64F\U0001F300-\U0001F5FF"
|
| 42 |
+
r"\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF"
|
| 43 |
+
r"\U00002702-\U000027B0\U0001F900-\U0001F9FF"
|
| 44 |
+
r"\U0001FA00-\U0001FA6F\U0001FA70-\U0001FAFF"
|
| 45 |
+
r"\U00002600-\U000026FF\U0000FE00-\U0000FE0F"
|
| 46 |
+
r"\U0000200D"
|
| 47 |
+
r"]+", re.UNICODE,
|
| 48 |
+
)
|
| 49 |
+
_RE_HTML_ENTITY = re.compile(r"&(?:#x?[\da-fA-F]+|\w+);")
|
| 50 |
+
_HTML_ENTITIES = {
|
| 51 |
+
"&": " and ", "<": " less than ", ">": " greater than ",
|
| 52 |
+
" ": " ", """: '"', "'": "'",
|
| 53 |
+
"—": ", ", "–": ", ", "…": ".",
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
# β Punctuation normalization
|
| 57 |
+
_RE_REPEATED_DOT = re.compile(r"\.{2,}")
|
| 58 |
+
_RE_REPEATED_EXCLAM = re.compile(r"!{2,}")
|
| 59 |
+
_RE_REPEATED_QUEST = re.compile(r"\?{2,}")
|
| 60 |
+
_RE_REPEATED_SEMI = re.compile(r";{2,}")
|
| 61 |
+
_RE_REPEATED_COLON = re.compile(r":{2,}")
|
| 62 |
+
_RE_REPEATED_COMMA = re.compile(r",{2,}")
|
| 63 |
+
_RE_REPEATED_DASH = re.compile(r"-{3,}")
|
| 64 |
+
|
| 65 |
+
# β Whitespace
|
| 66 |
+
_RE_MULTI_SPACE = re.compile(r"[ \t]+")
|
| 67 |
+
_RE_MULTI_NEWLINE = re.compile(r"\n{3,}")
|
| 68 |
+
_RE_SPACE_BEFORE_PUN = re.compile(r"\s+([.!?,;:])")
|
| 69 |
+
|
| 70 |
+
# β Sentence boundary (split point)
|
| 71 |
+
_RE_SENTENCE_SPLIT = re.compile(r"(?<=[.!?;:])\s+")
|
| 72 |
+
|
| 73 |
+
_MIN_MERGE_WORDS = 5
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 77 |
+
# Public API
|
| 78 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 79 |
+
|
| 80 |
+
def sanitize(text: str) -> str:
|
| 81 |
+
"""Clean raw input for TTS while preserving paralinguistic tags."""
|
| 82 |
+
if not text:
|
| 83 |
+
return text
|
| 84 |
+
|
| 85 |
+
# 1. Protect paralinguistic tags by replacing with placeholders
|
| 86 |
+
tags_found: list[tuple[int, str]] = []
|
| 87 |
+
def _protect_tag(m):
|
| 88 |
+
idx = len(tags_found)
|
| 89 |
+
tags_found.append((idx, m.group(0)))
|
| 90 |
+
return f"Β§TAG{idx}Β§"
|
| 91 |
+
text = _RE_PARA_TAG.sub(_protect_tag, text)
|
| 92 |
+
|
| 93 |
+
# 2. Strip non-speakable structures
|
| 94 |
+
text = _RE_URL.sub("", text)
|
| 95 |
+
text = _RE_CODE_BLOCK.sub("", text)
|
| 96 |
+
text = _RE_IMAGE.sub(lambda m: m.group(1) if m.group(1) else "", text)
|
| 97 |
+
text = _RE_LINK.sub(r"\1", text)
|
| 98 |
+
text = _RE_BOLD_AST.sub(r"\1", text)
|
| 99 |
+
text = _RE_BOLD_UND.sub(r"\1", text)
|
| 100 |
+
text = _RE_STRIKE.sub(r"\1", text)
|
| 101 |
+
text = _RE_ITALIC_AST.sub(r"\1", text)
|
| 102 |
+
text = _RE_ITALIC_UND.sub(r"\1", text)
|
| 103 |
+
text = _RE_INLINE_CODE.sub(r"\1", text)
|
| 104 |
+
text = _RE_HEADER.sub("", text)
|
| 105 |
+
text = _RE_BLOCKQUOTE.sub("", text)
|
| 106 |
+
text = _RE_HR.sub("", text)
|
| 107 |
+
text = _RE_BULLET.sub("", text)
|
| 108 |
+
text = _RE_ORDERED.sub("", text)
|
| 109 |
+
|
| 110 |
+
# 3. Emojis, hashtags
|
| 111 |
+
text = _RE_EMOJI.sub("", text)
|
| 112 |
+
text = re.sub(r"#(\w+)", r"\1", text)
|
| 113 |
+
|
| 114 |
+
# 4. HTML entities
|
| 115 |
+
text = _RE_HTML_ENTITY.sub(lambda m: _HTML_ENTITIES.get(m.group(0), ""), text)
|
| 116 |
+
|
| 117 |
+
# 5. Collapse repeated punctuation
|
| 118 |
+
text = _RE_REPEATED_DOT.sub(".", text)
|
| 119 |
+
text = _RE_REPEATED_EXCLAM.sub("!", text)
|
| 120 |
+
text = _RE_REPEATED_QUEST.sub("?", text)
|
| 121 |
+
text = _RE_REPEATED_SEMI.sub(";", text)
|
| 122 |
+
text = _RE_REPEATED_COLON.sub(":", text)
|
| 123 |
+
text = _RE_REPEATED_COMMA.sub(",", text)
|
| 124 |
+
text = _RE_REPEATED_DASH.sub("β", text)
|
| 125 |
+
|
| 126 |
+
# 6. Whitespace
|
| 127 |
+
text = _RE_SPACE_BEFORE_PUN.sub(r"\1", text)
|
| 128 |
+
text = _RE_MULTI_SPACE.sub(" ", text)
|
| 129 |
+
text = _RE_MULTI_NEWLINE.sub("\n\n", text)
|
| 130 |
+
text = text.strip()
|
| 131 |
+
|
| 132 |
+
# 7. Restore paralinguistic tags
|
| 133 |
+
for idx, original in tags_found:
|
| 134 |
+
text = text.replace(f"Β§TAG{idx}Β§", original)
|
| 135 |
+
|
| 136 |
+
return text
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def split_for_streaming(text: str, max_chars: int = Config.MAX_CHUNK_CHARS) -> List[str]:
|
| 140 |
+
"""Split sanitized text into sentence-level chunks for streaming.
|
| 141 |
+
|
| 142 |
+
Strategy:
|
| 143 |
+
1. Split on sentence-ending punctuation boundaries
|
| 144 |
+
2. Enforce max_chars per chunk (split long sentences on commas / spaces)
|
| 145 |
+
3. Merge short chunks (β€5 words) with the next to avoid tiny segments
|
| 146 |
+
"""
|
| 147 |
+
if not text:
|
| 148 |
+
return []
|
| 149 |
+
|
| 150 |
+
# Step 1: sentence split
|
| 151 |
+
raw_chunks = _RE_SENTENCE_SPLIT.split(text)
|
| 152 |
+
raw_chunks = [c.strip() for c in raw_chunks if c.strip()]
|
| 153 |
+
|
| 154 |
+
# Step 2: enforce max length per chunk
|
| 155 |
+
sized: List[str] = []
|
| 156 |
+
for chunk in raw_chunks:
|
| 157 |
+
if len(chunk) <= max_chars:
|
| 158 |
+
sized.append(chunk)
|
| 159 |
+
else:
|
| 160 |
+
sized.extend(_break_long_chunk(chunk, max_chars))
|
| 161 |
+
|
| 162 |
+
# Step 3: merge short chunks
|
| 163 |
+
if len(sized) <= 1:
|
| 164 |
+
return sized
|
| 165 |
+
|
| 166 |
+
merged: List[str] = []
|
| 167 |
+
carry = ""
|
| 168 |
+
for i, chunk in enumerate(sized):
|
| 169 |
+
if carry:
|
| 170 |
+
chunk = carry + " " + chunk
|
| 171 |
+
carry = ""
|
| 172 |
+
if len(chunk.split()) <= _MIN_MERGE_WORDS and i < len(sized) - 1:
|
| 173 |
+
carry = chunk
|
| 174 |
+
else:
|
| 175 |
+
merged.append(chunk)
|
| 176 |
+
if carry:
|
| 177 |
+
if merged:
|
| 178 |
+
merged[-1] += " " + carry
|
| 179 |
+
else:
|
| 180 |
+
merged.append(carry)
|
| 181 |
+
|
| 182 |
+
return merged
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 186 |
+
# Internal helpers
|
| 187 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 188 |
+
|
| 189 |
+
def _break_long_chunk(text: str, max_chars: int) -> List[str]:
|
| 190 |
+
"""Break a chunk longer than max_chars on commas or word boundaries."""
|
| 191 |
+
parts: List[str] = []
|
| 192 |
+
remaining = text
|
| 193 |
+
while len(remaining) > max_chars:
|
| 194 |
+
# Try comma first
|
| 195 |
+
pos = remaining.rfind(",", 0, max_chars)
|
| 196 |
+
if pos == -1:
|
| 197 |
+
pos = remaining.rfind(" ", 0, max_chars)
|
| 198 |
+
if pos == -1:
|
| 199 |
+
pos = max_chars # hard break
|
| 200 |
+
segment = remaining[:pos].strip()
|
| 201 |
+
if segment:
|
| 202 |
+
parts.append(segment)
|
| 203 |
+
remaining = remaining[pos:].lstrip(", ")
|
| 204 |
+
if remaining.strip():
|
| 205 |
+
parts.append(remaining.strip())
|
| 206 |
+
return parts
|