Spaces:
Sleeping
Sleeping
| import asyncio | |
| import http.client | |
| import io | |
| import json | |
| import logging | |
| import queue as stdlib_queue | |
| import threading | |
| import time | |
| import urllib.parse | |
| import uuid | |
| from concurrent.futures import ThreadPoolExecutor | |
| from dataclasses import dataclass | |
| from typing import Any, Generator, Optional | |
| import numpy as np | |
| import soundfile as sf | |
| from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile | |
| from fastapi.responses import Response, StreamingResponse | |
| from contextlib import asynccontextmanager | |
| from config import Config | |
| from chatterbox_wrapper import ChatterboxWrapper, GenerationCancelled, VoiceProfile | |
| import text_processor | |
# ── Logging ────────────────────────────────────────────────────────
# Root logging configuration shared by every module logger in the process.
# NOTE(review): the "β" separators in the format string are mojibake of a
# box-drawing character — left byte-identical since log output is runtime text.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s β %(levelname)-7s β %(name)s β %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger(__name__)
# ── Thread pool for CPU-bound inference ────────────────────────────
# All blocking synthesis/encoding work is pushed here via run_in_executor
# so the asyncio event loop stays responsive. Size comes from Config.
tts_executor = ThreadPoolExecutor(max_workers=Config.MAX_WORKERS)
# ── Lifespan ───────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the TTS model at startup; tear down the worker pool on shutdown.

    FastAPI's ``lifespan=`` parameter expects an async context manager, so
    the generator must be wrapped with ``@asynccontextmanager`` (the import
    at the top of the file was present but unused; a bare async generator
    only works through a deprecated Starlette compatibility shim).

    Raises:
        Exception: re-raised if model construction fails, aborting startup.
    """
    try:
        wrapper = ChatterboxWrapper()
        app.state.wrapper = wrapper
        logger.info("β Model loaded, server ready")
    except Exception as e:
        logger.error(f"β Model loading failed: {e}")
        raise
    yield
    # Process is exiting; don't block on in-flight inference jobs.
    tts_executor.shutdown(wait=False)
# Application instance; `lifespan` loads the model before requests are served.
app = FastAPI(
    title="Chatterbox Turbo TTS API",
    version="1.0.0",
    docs_url="/docs",
    lifespan=lifespan,
)
# ── CORS Middleware ────────────────────────────────────────────────
async def cors_middleware(request: Request, call_next):
    """Allow-list CORS handler.

    Only origins in Config.ALLOWED_ORIGINS get CORS headers; requests with
    no Origin header (same-origin, curl, internal calls) pass through
    untouched; any other origin is rejected with 403.
    """
    origin = request.headers.get("origin")
    # Preflight: answer OPTIONS directly for allowed origins, skipping the app.
    if request.method == "OPTIONS" and origin in Config.ALLOWED_ORIGINS:
        return Response(
            status_code=200,
            headers={
                "Access-Control-Allow-Origin": origin,
                "Access-Control-Allow-Methods": "*",
                "Access-Control-Allow-Headers": "*",
                "Access-Control-Allow-Credentials": "true",
            },
        )
    # No Origin header at all is treated as trusted (non-browser client).
    if not origin or origin in Config.ALLOWED_ORIGINS:
        response = await call_next(request)
        if origin:
            response.headers["Access-Control-Allow-Origin"] = origin
            response.headers["Access-Control-Allow-Credentials"] = "true"
            response.headers["Access-Control-Allow-Methods"] = "*"
            response.headers["Access-Control-Allow-Headers"] = "*"
            # Expose the stream id so browser JS can call /tts/stop.
            response.headers["Access-Control-Expose-Headers"] = "X-Stream-Id"
        return response
    logger.warning(f"π« Blocked origin: {origin}")
    return Response(status_code=403, content="Forbidden: Origin not allowed")
# ───────────────────────────────────────────────────────────────────
# Helper: resolve voice from optional upload
# ───────────────────────────────────────────────────────────────────
async def _resolve_voice(
    voice_ref: Optional[UploadFile],
    voice_name: Optional[str],
    wrapper: ChatterboxWrapper,
) -> VoiceProfile:
    """Return a VoiceProfile from uploaded audio or built-in voice selection.

    Precedence: a non-empty ``voice_ref`` upload wins; otherwise
    ``voice_name`` selects a built-in voice.

    Raises:
        HTTPException 400: unknown built-in voice, empty file, or undecodable audio.
        HTTPException 413: upload exceeds Config.MAX_VOICE_UPLOAD_BYTES.
    """
    if voice_ref is None or voice_ref.filename == "":
        try:
            return wrapper.get_builtin_voice(voice_name)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
    audio_bytes = await voice_ref.read()
    if len(audio_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
        # NOTE(review): "10 MB" is hard-coded — confirm it matches
        # Config.MAX_VOICE_UPLOAD_BYTES.
        raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
    if len(audio_bytes) == 0:
        raise HTTPException(status_code=400, detail="Empty voice file")
    loop = asyncio.get_running_loop()
    try:
        # Voice encoding is CPU-bound; keep it off the event loop.
        return await loop.run_in_executor(
            tts_executor, wrapper.encode_voice_from_bytes, audio_bytes
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Voice encoding failed: {e}")
        raise HTTPException(
            status_code=400,
            detail=f"Could not process voice file: {str(e)}. "
            f"Supported formats: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM."
        )
# ───────────────────────────────────────────────────────────────────
# Helper: encode numpy audio to bytes in given format
# ───────────────────────────────────────────────────────────────────
def _encode_audio(audio: np.ndarray, fmt: str = "wav") -> tuple[bytes, str]:
    """Serialize a numpy waveform to (encoded_bytes, media_type).

    Supported formats are mp3 and flac; anything else falls back to WAV.
    """
    # Dispatch table: soundfile format name + HTTP media type per request format.
    known_formats = {
        "mp3": ("mp3", "audio/mpeg"),
        "flac": ("flac", "audio/flac"),
    }
    sf_format, media = known_formats.get(fmt.lower(), ("wav", "audio/wav"))
    out = io.BytesIO()
    sf.write(out, audio, Config.SAMPLE_RATE, format=sf_format)
    return out.getvalue(), media
def _encode_mp3_chunk(audio: np.ndarray) -> bytes:
    """Encode one numpy chunk to MP3 bytes via the shared encoder path."""
    encoded, _media = _encode_audio(audio, fmt="mp3")
    return encoded
| class _ChunkPacket: | |
| index: int | |
| data: bytes | |
| lane: str | |
| produced_at: float | |
def _internal_headers(
    *,
    content_type: Optional[str] = "application/json",
    accept: str = "audio/mpeg",
) -> dict[str, str]:
    """Build headers for server-to-helper internal calls.

    Always requests keep-alive; attaches Content-Type only when given and
    the shared secret only when configured.
    """
    built: dict[str, str] = {"Accept": accept, "Connection": "keep-alive"}
    if content_type:
        built["Content-Type"] = content_type
    secret = Config.INTERNAL_SHARED_SECRET
    if secret:
        built["X-Internal-Secret"] = secret
    return built
class _HelperHttpClient:
    """Small persistent HTTP client for helper server keep-alive calls.

    Wraps http.client with a single cached connection so successive chunk
    and voice calls to the same helper avoid repeated TCP/TLS setup.
    Usable as a context manager; not thread-safe — each worker thread
    creates its own instance.
    """

    def __init__(self, base_url: str, default_timeout: float):
        # Validate and decompose the helper URL once, up front.
        parsed = urllib.parse.urlparse((base_url or "").strip())
        if parsed.scheme not in {"http", "https"} or not parsed.hostname:
            raise ValueError(f"Invalid helper URL: {base_url!r}")
        self._scheme = parsed.scheme
        self._host = parsed.hostname
        self._port = parsed.port  # None -> http.client picks the default port
        self._base_path = (parsed.path or "").rstrip("/")
        self._default_timeout = max(1.0, float(default_timeout))
        # Lazily-opened keep-alive connection; None until first request.
        self._conn: Optional[http.client.HTTPConnection] = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def close(self):
        """Close the cached socket, swallowing teardown errors."""
        if self._conn is not None:
            try:
                self._conn.close()
            except Exception:
                pass
            self._conn = None

    def _target(self, path: str, query: Optional[str] = None) -> str:
        """Join the URL base path with a request path (+ optional query)."""
        normalized = path if path.startswith("/") else f"/{path}"
        target = f"{self._base_path}{normalized}"
        if query:
            target = f"{target}?{query}"
        return target

    def _make_connection(self, timeout_sec: float) -> http.client.HTTPConnection:
        # HTTPSConnection subclasses HTTPConnection, so one return type covers both.
        if self._scheme == "https":
            return http.client.HTTPSConnection(self._host, self._port, timeout=timeout_sec)
        return http.client.HTTPConnection(self._host, self._port, timeout=timeout_sec)

    def _ensure_connection(self, timeout_sec: float) -> http.client.HTTPConnection:
        """Return the cached connection, creating it on first use."""
        if self._conn is None:
            self._conn = self._make_connection(timeout_sec)
        else:
            # Reuse the socket but honor the per-request timeout.
            self._conn.timeout = timeout_sec
        return self._conn

    def _request(
        self,
        method: str,
        path: str,
        *,
        body: Optional[bytes] = None,
        headers: Optional[dict[str, str]] = None,
        timeout_sec: Optional[float] = None,
        query: Optional[str] = None,
    ) -> tuple[int, bytes, dict[str, str]]:
        """Issue one request; return (status, body, lower-cased headers).

        Raises:
            RuntimeError: helper answered with HTTP status >= 400.
            OSError/http.client errors: transport failure — the cached
                connection is closed first so the next call reconnects.
        """
        timeout = max(1.0, float(timeout_sec or self._default_timeout))
        target = self._target(path, query=query)
        req_headers = headers or {}
        conn = self._ensure_connection(timeout)
        try:
            conn.request(method=method, url=target, body=body, headers=req_headers)
            resp = conn.getresponse()
            # Fully drain the response so the keep-alive socket stays reusable.
            payload = resp.read()
            resp_headers = {k.lower(): v for k, v in resp.getheaders()}
        except Exception:
            # Force reconnect on next attempt if socket is stale/reset.
            self.close()
            raise
        if resp.status >= 400:
            snippet = payload[:256].decode("utf-8", errors="replace")
            raise RuntimeError(
                f"helper {method} {target} returned {resp.status}: {snippet}"
            )
        return resp.status, payload, resp_headers

    def request_chunk(self, payload: dict[str, Any], timeout_sec: float) -> bytes:
        """Synthesize one text chunk on the helper; returns encoded audio bytes."""
        _, data, _ = self._request(
            "POST",
            "/internal/chunk/synthesize",
            body=json.dumps(payload).encode("utf-8"),
            headers=_internal_headers(content_type="application/json", accept="audio/mpeg"),
            timeout_sec=timeout_sec,
        )
        return data

    def register_voice(self, stream_id: str, audio_bytes: bytes, timeout_sec: float) -> str:
        """Upload reference audio once for a stream; returns the helper voice_key.

        Raises:
            RuntimeError: helper response contained no usable voice_key.
        """
        query = urllib.parse.urlencode({"stream_id": stream_id})
        _, data, _ = self._request(
            "POST",
            "/internal/voice/register",
            query=query,
            body=audio_bytes,
            headers=_internal_headers(
                content_type="application/octet-stream",
                accept="application/json",
            ),
            timeout_sec=timeout_sec,
        )
        payload = json.loads(data.decode("utf-8"))
        voice_key = (payload.get("voice_key") or "").strip()
        if not voice_key:
            raise RuntimeError("helper voice registration returned no voice_key")
        return voice_key

    def cancel_stream(self, stream_id: str, timeout_sec: float = 3.0):
        """Tell the helper to abandon in-flight work for this stream."""
        self._request(
            "POST",
            f"/internal/chunk/cancel/{stream_id}",
            body=b"",
            headers=_internal_headers(),
            timeout_sec=timeout_sec,
        )

    def complete_stream(self, stream_id: str, timeout_sec: float = 3.0):
        """Tell the helper the stream finished normally so it can drop state."""
        self._request(
            "POST",
            f"/internal/chunk/complete/{stream_id}",
            body=b"",
            headers=_internal_headers(),
            timeout_sec=timeout_sec,
        )
def _helper_request_chunk(
    helper_base_url: str,
    payload: dict,
    timeout_sec: float,
    helper_client: Optional[_HelperHttpClient] = None,
) -> bytes:
    """POST one chunk-synthesis request to the helper.

    Reuses `helper_client` (keep-alive) when supplied; otherwise opens a
    one-shot connection for this single call.
    """
    if helper_client is None:
        with _HelperHttpClient(helper_base_url, default_timeout=timeout_sec) as one_shot:
            return one_shot.request_chunk(payload, timeout_sec=timeout_sec)
    return helper_client.request_chunk(payload, timeout_sec=timeout_sec)
def _helper_register_voice(
    helper_base_url: str,
    stream_id: str,
    audio_bytes: bytes,
    timeout_sec: float,
    helper_client: Optional[_HelperHttpClient] = None,
) -> str:
    """Register reference voice on helper once, return voice_key for chunk calls.

    Reuses `helper_client` when supplied; otherwise performs the call over a
    short-lived connection.
    """
    if helper_client is None:
        with _HelperHttpClient(helper_base_url, default_timeout=timeout_sec) as one_shot:
            return one_shot.register_voice(
                stream_id=stream_id,
                audio_bytes=audio_bytes,
                timeout_sec=timeout_sec,
            )
    return helper_client.register_voice(
        stream_id=stream_id,
        audio_bytes=audio_bytes,
        timeout_sec=timeout_sec,
    )
def _helper_cancel_stream(helper_base_url: str, stream_id: str):
    """Best-effort cancellation signal to helper."""
    try:
        with _HelperHttpClient(helper_base_url, default_timeout=3.0) as client:
            client.cancel_stream(stream_id=stream_id, timeout_sec=3.0)
    except Exception:
        # Advisory only: an unreachable helper has nothing left to cancel.
        pass
def _helper_complete_stream(helper_base_url: str, stream_id: str):
    """Best-effort stream completion cleanup on helper.

    Falls back to cancel for backwards compatibility if helper does not
    expose the completion endpoint yet.
    """
    try:
        with _HelperHttpClient(helper_base_url, default_timeout=3.0) as client:
            client.complete_stream(stream_id=stream_id, timeout_sec=3.0)
    except Exception:
        # Older helpers lack /complete; cancel clears their state instead.
        _helper_cancel_stream(helper_base_url, stream_id)
# ───────────────────────────────────────────────────────────────────
# Endpoints
# ───────────────────────────────────────────────────────────────────
async def health(warm_up: bool = False):
    """Liveness/readiness probe; optionally runs a model warmup pass.

    Also opportunistically purges expired internal-stream bookkeeping and
    reports its sizes for observability.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    with _internal_cancel_lock:
        _purge_internal_stream_state_locked()
        cancelled_count = len(_internal_cancelled_streams)
        voice_state_count = len(_internal_stream_voice_keys)
    status = {
        "status": "healthy" if wrapper else "loading",
        "model_loaded": wrapper is not None,
        "model_dtype": Config.MODEL_DTYPE,
        "streaming_supported": True,
        # NOTE(review): reaches into the wrapper's private voice cache —
        # confirm no public accessor exists.
        "voice_cache_entries": wrapper._voice_cache.size if wrapper else 0,
        "internal_cancelled_streams": cancelled_count,
        "internal_stream_voice_states": voice_state_count,
    }
    if warm_up and wrapper:
        try:
            loop = asyncio.get_running_loop()
            # Warmup is blocking inference; run it on the TTS thread pool.
            await loop.run_in_executor(tts_executor, wrapper.warmup)
            status["warm_up"] = "success"
        except Exception as e:
            status["warm_up"] = f"failed: {e}"
    return status
async def list_voices():
    """List the built-in voices plus usage hints for the form/JSON APIs.

    Raises:
        HTTPException 503: model not loaded yet.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    voices = wrapper.list_builtin_voices()
    return {
        "count": len(voices),
        "default_voice": wrapper.default_voice_name,
        "voices": voices,
        "usage": {
            "form_field": "voice_name",
            "json_field": "voice",
            "note": "If voice_ref is uploaded, it overrides voice_name.",
        },
    }
# ── POST /tts ──────────────────────────────────────────────────────
async def text_to_speech(
    text: str = Form(...),
    voice_ref: Optional[UploadFile] = File(None),
    voice_name: str = Form("default"),
    output_format: str = Form("wav"),
    max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
    repetition_penalty: float = Form(Config.REPETITION_PENALTY),
):
    """Generate complete audio for the given text.

    Non-streaming endpoint: synthesizes the whole utterance on the worker
    pool and returns it as a single downloadable file.

    Raises:
        HTTPException 400: empty text / bad voice / invalid parameters.
        HTTPException 503: model not loaded.
        HTTPException 500: unexpected synthesis failure.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    if not text or not text.strip():
        raise HTTPException(400, "Text is required")
    voice = await _resolve_voice(voice_ref, voice_name, wrapper)
    loop = asyncio.get_running_loop()
    try:
        # Blocking inference runs on the shared TTS thread pool.
        audio = await loop.run_in_executor(
            tts_executor,
            wrapper.generate_speech,
            text, voice, max_new_tokens, repetition_penalty,
        )
    except ValueError as e:
        raise HTTPException(400, str(e))
    except Exception as e:
        logger.error(f"TTS error: {e}")
        raise HTTPException(500, "Internal server error")
    data, media_type = _encode_audio(audio, output_format)
    return Response(
        content=data,
        media_type=media_type,
        headers={"Content-Disposition": f"attachment; filename=tts_output.{output_format}"},
    )
# ───────────────────────────────────────────────────────────────────
# Active Stream Registry (for cancellation)
# ───────────────────────────────────────────────────────────────────
# Streams currently served by this process: stream_id -> cancel event that
# the generator's finally-block and /tts/stop can set.
_active_streams: dict[str, threading.Event] = {}
# stream_id -> expires_at epoch seconds
_internal_cancelled_streams: dict[str, float] = {}
# Guards _internal_cancelled_streams and _internal_stream_voice_keys.
_internal_cancel_lock = threading.Lock()
# stream_id -> (voice_keys, expires_at)
_internal_stream_voice_keys: dict[str, tuple[set[str], float]] = {}
# stream_id -> helper base URLs (used to cancel helpers quickly on /tts/stop)
_stream_helper_routes: dict[str, set[str]] = {}
# Guards _stream_helper_routes.
_stream_routes_lock = threading.Lock()
def _purge_internal_stream_state_locked(now: Optional[float] = None):
    """Drop expired cancel flags and voice-key records.

    Caller must hold _internal_cancel_lock ("_locked" suffix convention).
    """
    cutoff = time.time() if now is None else now
    stale_cancels = [
        sid for sid, expires_at in _internal_cancelled_streams.items()
        if expires_at <= cutoff
    ]
    stale_voice_states = [
        sid for sid, (_, expires_at) in _internal_stream_voice_keys.items()
        if expires_at <= cutoff
    ]
    for sid in stale_cancels:
        _internal_cancelled_streams.pop(sid, None)
    for sid in stale_voice_states:
        _internal_stream_voice_keys.pop(sid, None)
def _touch_internal_stream_voice_keys_locked(stream_id: str):
    """Extend the TTL on a stream's voice-key record, if one exists.

    Caller must hold _internal_cancel_lock ("_locked" suffix convention).
    """
    if not stream_id:
        return
    existing = _internal_stream_voice_keys.get(stream_id)
    if existing is None:
        return
    voice_keys = existing[0]
    fresh_expiry = time.time() + max(1, Config.INTERNAL_STREAM_STATE_TTL_SEC)
    _internal_stream_voice_keys[stream_id] = (voice_keys, fresh_expiry)
def _clear_internal_stream_state_locked(stream_id: str):
    """Remove all per-stream internal bookkeeping.

    Caller must hold _internal_cancel_lock ("_locked" suffix convention).
    """
    for registry in (_internal_cancelled_streams, _internal_stream_voice_keys):
        registry.pop(stream_id, None)
# ───────────────────────────────────────────────────────────────────
# Pipeline Streaming Generator
# ───────────────────────────────────────────────────────────────────
def _pipeline_stream_generator(
    wrapper: ChatterboxWrapper,
    text: str,
    voice: VoiceProfile,
    max_new_tokens: int,
    repetition_penalty: float,
    stream_id: str,
) -> Generator[bytes, None, None]:
    """Two-stage producer-consumer pipeline for minimal inter-chunk gaps.

    Architecture:
        Producer thread (heavyweight, ~80% CPU):
            ONNX token generation → audio decoding → raw numpy arrays → queue
        Consumer (this generator, lightweight, ~20% CPU):
            queue → MP3 encode → yield to HTTP response

    Why this helps:
        - ONNX model runs CONTINUOUSLY without waiting for MP3 encode or HTTP
        - MP3 encoding (libsndfile, C code) releases GIL → true parallelism
        - ONNX inference (C++ code) also releases GIL → both run simultaneously
        - Queue(maxsize=2) lets producer stay 1-2 chunks ahead

    Cancellation:
        - cancel_event checked between chunks + every 25 autoregressive steps
        - Client disconnect triggers GeneratorExit → finally sets cancel
        - /tts/stop endpoint sets cancel externally
    """
    cancel_event = threading.Event()
    _active_streams[stream_id] = cancel_event
    # Raw audio buffer: producer puts numpy arrays, consumer takes them
    audio_buffer: stdlib_queue.Queue = stdlib_queue.Queue(maxsize=2)

    def _producer():
        """Heavyweight worker: runs ONNX model continuously."""
        try:
            for audio_chunk in wrapper.stream_speech(
                text, voice,
                max_new_tokens=max_new_tokens,
                repetition_penalty=repetition_penalty,
                is_cancelled=cancel_event.is_set,
            ):
                if cancel_event.is_set():
                    break
                # Bounded put with 100 ms timeout so cancellation is noticed
                # even if the consumer stalls.
                while not cancel_event.is_set():
                    try:
                        audio_buffer.put(audio_chunk, timeout=0.1)
                        break
                    except stdlib_queue.Full:
                        continue
        except GenerationCancelled:
            logger.info(f"[{stream_id}] Generation cancelled")
        except Exception as e:
            # Ship the exception to the consumer through the queue.
            while not cancel_event.is_set():
                try:
                    audio_buffer.put(e, timeout=0.1)
                    break
                except stdlib_queue.Full:
                    continue
        finally:
            # None is the end-of-stream sentinel.
            while not cancel_event.is_set():
                try:
                    audio_buffer.put(None, timeout=0.1)
                    break
                except stdlib_queue.Full:
                    continue

    producer = threading.Thread(target=_producer, daemon=True)
    producer.start()
    try:
        # Consumer: lightweight MP3 encoding + yield
        while True:
            item = audio_buffer.get()
            if item is None:
                break
            if isinstance(item, Exception):
                logger.error(f"[{stream_id}] Stream error: {item}")
                break
            if cancel_event.is_set():
                break
            # MP3 encode (C code, releases GIL, runs parallel with next ONNX step)
            buf = io.BytesIO()
            sf.write(buf, item, Config.SAMPLE_RATE, format="mp3")
            yield buf.getvalue()
    finally:
        # Cleanup: signal producer to stop + deregister
        cancel_event.set()
        _active_streams.pop(stream_id, None)
def _parallel_two_way_stream_generator(
    wrapper: ChatterboxWrapper,
    text: str,
    local_voice: VoiceProfile,
    helper_voice_bytes: Optional[bytes],
    max_new_tokens: int,
    repetition_penalty: float,
    stream_id: str,
    helper_base_url: str,
) -> Generator[bytes, None, None]:
    """Additive 2-way split streamer (primary + helper).

    Routing pattern:
        - chunk 0,2,4... -> primary (local)
        - chunk 1,3,5... -> helper

    Two worker threads publish MP3-encoded chunks into `ready`; this
    generator (the "stitcher") re-emits them in strict index order. Any
    helper failure permanently downgrades the helper lane to local
    synthesis for the rest of the stream.
    """
    cancel_event = threading.Event()
    _active_streams[stream_id] = cancel_event
    helper_base_url = (helper_base_url or "").strip()
    helper_route_set = {helper_base_url} if helper_base_url else set()
    if helper_route_set:
        # Record the route so /tts/stop can signal the helper promptly.
        with _stream_routes_lock:
            _stream_helper_routes[stream_id] = set(helper_route_set)
    clean_text = text_processor.sanitize(text.strip()[: Config.MAX_TEXT_LENGTH])
    chunks = text_processor.split_for_streaming(clean_text)
    total_chunks = len(chunks)
    if total_chunks == 0:
        # Nothing to synthesize: deregister immediately and end the stream.
        with _stream_routes_lock:
            _stream_helper_routes.pop(stream_id, None)
        _active_streams.pop(stream_id, None)
        return
    # Shared stitcher state; all of it is guarded by `cond`'s lock.
    lock = threading.Lock()
    cond = threading.Condition(lock)
    ready: dict[int, _ChunkPacket] = {}
    first_error: Optional[Exception] = None
    workers_done = 0
    expected_workers = 2
    stream_completed = False

    def _publish(packet: _ChunkPacket):
        # Hand a finished chunk to the stitcher and wake it.
        with cond:
            # First write wins for an index to avoid duplicate retry races.
            if packet.index not in ready:
                ready[packet.index] = packet
            cond.notify_all()

    def _set_error(err: Exception):
        # Record only the first failure; later ones are noise.
        nonlocal first_error
        with cond:
            if first_error is None:
                first_error = err
            cond.notify_all()

    def _worker_done():
        nonlocal workers_done
        with cond:
            workers_done += 1
            cond.notify_all()

    def _synth_local(chunk_text: str) -> bytes:
        # In-process synthesis + MP3 encode for one text chunk.
        audio = wrapper.generate_speech(
            chunk_text,
            local_voice,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
        )
        return _encode_mp3_chunk(audio)

    def _local_worker():
        # Primary lane: even-indexed chunks, synthesized locally.
        try:
            for idx in range(0, total_chunks, 2):
                if cancel_event.is_set():
                    break
                data = _synth_local(chunks[idx])
                _publish(
                    _ChunkPacket(
                        index=idx,
                        data=data,
                        lane="primary",
                        produced_at=time.perf_counter(),
                    )
                )
        except Exception as e:
            _set_error(e)
        finally:
            _worker_done()

    def _helper_worker():
        # Helper lane: odd-indexed chunks via remote helper, falling back to
        # local synthesis whenever the helper is unavailable or fails.
        helper_available = bool(helper_base_url)
        helper_voice_key: Optional[str] = None
        helper_timeout = max(1.0, Config.HELPER_TIMEOUT_SEC)
        helper_client: Optional[_HelperHttpClient] = None
        try:
            if helper_available:
                try:
                    helper_client = _HelperHttpClient(
                        helper_base_url,
                        default_timeout=helper_timeout,
                    )
                except Exception as conn_err:
                    helper_available = False
                    logger.warning(
                        f"[{stream_id}] helper keep-alive init failed ({conn_err}); "
                        "using local fallback for helper lane"
                    )
            if helper_available and helper_voice_bytes:
                # Register the reference voice once; retry once if configured.
                attempts = 2 if Config.HELPER_RETRY_ONCE else 1
                last_err: Optional[Exception] = None
                for _ in range(attempts):
                    try:
                        helper_voice_key = _helper_register_voice(
                            helper_base_url=helper_base_url,
                            stream_id=stream_id,
                            audio_bytes=helper_voice_bytes,
                            timeout_sec=helper_timeout,
                            helper_client=helper_client,
                        )
                        last_err = None
                        break
                    except Exception as reg_err:
                        last_err = reg_err
                        continue
                if last_err is not None:
                    helper_available = False
                    logger.warning(
                        f"[{stream_id}] helper voice registration failed; "
                        "falling back to local synthesis for helper lane"
                    )
            elif not helper_available:
                logger.info(
                    f"[{stream_id}] helper URL not configured; using local fallback"
                )
            for idx in range(1, total_chunks, 2):
                if cancel_event.is_set():
                    break
                if helper_available:
                    payload = {
                        "stream_id": stream_id,
                        "chunk_index": idx,
                        "text": chunks[idx],
                        "max_new_tokens": max_new_tokens,
                        "repetition_penalty": repetition_penalty,
                        "output_format": "mp3",
                    }
                    if helper_voice_key:
                        payload["voice_key"] = helper_voice_key
                    attempts = 2 if Config.HELPER_RETRY_ONCE else 1
                    last_err: Optional[Exception] = None
                    for _ in range(attempts):
                        try:
                            helper_data = _helper_request_chunk(
                                helper_base_url=helper_base_url,
                                payload=payload,
                                timeout_sec=helper_timeout,
                                helper_client=helper_client,
                            )
                            _publish(
                                _ChunkPacket(
                                    index=idx,
                                    data=helper_data,
                                    lane="helper",
                                    produced_at=time.perf_counter(),
                                )
                            )
                            last_err = None
                            break
                        except Exception as helper_err:
                            last_err = helper_err
                            continue
                    if last_err is None:
                        # Helper produced this chunk; skip the local fallback.
                        continue
                    # Helper is considered down for the rest of this stream.
                    helper_available = False
                    logger.warning(
                        f"[{stream_id}] helper failed at chunk {idx}; "
                        "falling back to local synthesis for remaining helper chunks"
                    )
                # Local fallback for helper lane
                data = _synth_local(chunks[idx])
                _publish(
                    _ChunkPacket(
                        index=idx,
                        data=data,
                        lane="helper-local-fallback",
                        produced_at=time.perf_counter(),
                    )
                )
        except Exception as e:
            _set_error(e)
        finally:
            if helper_client is not None:
                helper_client.close()
            _worker_done()

    local_thread = threading.Thread(target=_local_worker, daemon=True)
    helper_thread = threading.Thread(target=_helper_worker, daemon=True)
    local_thread.start()
    helper_thread.start()
    next_idx = 0
    try:
        # Stitcher: emit chunks strictly in index order as they become ready.
        while next_idx < total_chunks:
            with cond:
                while (
                    next_idx not in ready
                    and first_error is None
                    and not cancel_event.is_set()
                    and workers_done < expected_workers
                ):
                    cond.wait(timeout=0.1)
                if cancel_event.is_set():
                    break
                if next_idx in ready:
                    packet = ready.pop(next_idx)
                    buffered_chunks = len(ready)
                elif first_error is not None:
                    logger.error(f"[{stream_id}] Parallel stream error: {first_error}")
                    break
                elif workers_done >= expected_workers:
                    # Both workers exited without ever producing this index.
                    logger.error(
                        f"[{stream_id}] Parallel stream ended with missing chunk index {next_idx}"
                    )
                    break
                else:
                    continue
            # Yield outside the lock so workers can keep publishing.
            logger.debug(
                "[%s] stitch emit chunk %s/%s from %s (buffered=%s)",
                stream_id,
                next_idx + 1,
                total_chunks,
                packet.lane,
                buffered_chunks,
            )
            yield packet.data
            next_idx += 1
        stream_completed = (
            next_idx >= total_chunks
            and first_error is None
            and not cancel_event.is_set()
        )
    finally:
        cancel_event.set()
        # For fast stop/cancel, signal helpers first; for normal completion, wait for
        # workers to flush and then ask helpers to clear stream state.
        if not stream_completed:
            for base_url in helper_route_set:
                _helper_cancel_stream(base_url, stream_id)
        local_thread.join(timeout=1.0)
        helper_thread.join(timeout=1.0)
        if stream_completed:
            for base_url in helper_route_set:
                _helper_complete_stream(base_url, stream_id)
        with _stream_routes_lock:
            _stream_helper_routes.pop(stream_id, None)
        _active_streams.pop(stream_id, None)
# ── POST /tts/stream & /tts/true-stream ────────────────────────────
async def stream_text_to_speech(
    text: str = Form(...),
    voice_ref: Optional[UploadFile] = File(None),
    voice_name: str = Form("default"),
    max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
    repetition_penalty: float = Form(Config.REPETITION_PENALTY),
):
    """True real-time streaming: yields MP3 chunks as each sentence finishes.

    Response includes X-Stream-Id header for cancellation via /tts/stop.
    Compatible with frontend's MediaSource + ReadableStream pattern.

    Raises:
        HTTPException 400: empty text or unresolvable voice.
        HTTPException 503: model not loaded.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    if not text or not text.strip():
        raise HTTPException(400, "Text is required")
    voice = await _resolve_voice(voice_ref, voice_name, wrapper)
    # Short random id; clients pass it back to /tts/stop.
    stream_id = uuid.uuid4().hex[:12]
    return StreamingResponse(
        _pipeline_stream_generator(
            wrapper, text, voice, max_new_tokens, repetition_penalty, stream_id,
        ),
        media_type="audio/mpeg",
        headers={
            "Content-Disposition": "attachment; filename=tts_stream.mp3",
            "Transfer-Encoding": "chunked",
            "X-Stream-Id": stream_id,
            "X-Streaming-Type": "true-realtime",
            "Cache-Control": "no-cache",
        },
    )
async def parallel_stream_text_to_speech(
    text: str = Form(...),
    voice_ref: Optional[UploadFile] = File(None),
    voice_name: str = Form("default"),
    max_new_tokens: int = Form(Config.MAX_NEW_TOKENS),
    repetition_penalty: float = Form(Config.REPETITION_PENALTY),
    helper_url: Optional[str] = Form(None),
):
    """Additive 2-way split stream mode (primary + helper).

    Resolves the voice twice: a local VoiceProfile for the primary lane and
    raw audio bytes (`helper_voice_bytes`) that the helper lane registers
    remotely, so both lanes speak with the same voice.

    Raises:
        HTTPException 400: empty text, bad voice, or no helper configured.
        HTTPException 413: uploaded voice exceeds the size limit.
        HTTPException 503: model not loaded or parallel mode disabled.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    if not Config.ENABLE_PARALLEL_MODE:
        raise HTTPException(503, "Parallel mode is disabled")
    if not text or not text.strip():
        raise HTTPException(400, "Text is required")
    local_voice: VoiceProfile = wrapper.default_voice
    helper_voice_bytes: Optional[bytes] = None
    if voice_ref is not None and voice_ref.filename:
        # Uploaded reference voice: validate, encode locally, and keep the
        # raw bytes for helper-side registration.
        helper_voice_bytes = await voice_ref.read()
        if len(helper_voice_bytes) > Config.MAX_VOICE_UPLOAD_BYTES:
            raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
        if len(helper_voice_bytes) == 0:
            raise HTTPException(status_code=400, detail="Empty voice file")
        loop = asyncio.get_running_loop()
        try:
            local_voice = await loop.run_in_executor(
                tts_executor, wrapper.encode_voice_from_bytes, helper_voice_bytes
            )
        except Exception as e:
            logger.error(f"Parallel voice encoding failed: {e}")
            raise HTTPException(400, "Could not process voice file for parallel mode")
    else:
        try:
            selected_voice_id = wrapper.resolve_voice_id(voice_name)
            local_voice = wrapper.get_builtin_voice(selected_voice_id)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))
        # Ensure helper uses the same selected built-in voice.
        if selected_voice_id != wrapper.default_voice_name:
            helper_voice_bytes = wrapper.get_builtin_voice_bytes(selected_voice_id)
            if not helper_voice_bytes:
                raise HTTPException(
                    status_code=400,
                    detail=f"Selected voice '{voice_name}' is unavailable for helper registration",
                )
    resolved_helper = (helper_url or Config.HELPER_BASE_URL).strip()
    if not resolved_helper:
        raise HTTPException(
            400,
            "No helper configured. Set CB_HELPER_BASE_URL or pass helper_url.",
        )
    stream_id = uuid.uuid4().hex[:12]
    return StreamingResponse(
        _parallel_two_way_stream_generator(
            wrapper=wrapper,
            text=text,
            local_voice=local_voice,
            helper_voice_bytes=helper_voice_bytes,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            stream_id=stream_id,
            helper_base_url=resolved_helper,
        ),
        media_type="audio/mpeg",
        headers={
            "Content-Disposition": "attachment; filename=tts_parallel_stream.mp3",
            "Transfer-Encoding": "chunked",
            "X-Stream-Id": stream_id,
            "X-Streaming-Type": "parallel-2way",
            "Cache-Control": "no-cache",
        },
    )
| # ββ JSON body variant (Kokoro/OpenAI compatibility) βββββββββββββββ | |
| from pydantic import BaseModel, Field | |
class InternalChunkRequest(BaseModel):
    """Payload for the internal chunk-synthesis endpoint used by primary/helper routing."""
    # Correlates chunks belonging to one client stream (see X-Stream-Id header).
    stream_id: str = Field(..., min_length=1, max_length=64)
    # Position of this chunk within the stream; echoed back in X-Chunk-Index.
    chunk_index: int = Field(..., ge=0)
    text: str = Field(..., min_length=1, max_length=10000)
    max_new_tokens: int = Field(default=Config.MAX_NEW_TOKENS, ge=64, le=2048)
    repetition_penalty: float = Field(default=Config.REPETITION_PENALTY, ge=1.0, le=2.0)
    # Normalized downstream to one of {"mp3", "wav", "flac"}; anything else falls back to "mp3".
    output_format: str = Field(default="mp3")
    # Hash key of a previously registered voice; when absent the wrapper's default voice is used.
    voice_key: Optional[str] = Field(default=None, min_length=1, max_length=64)
class TTSJsonRequest(BaseModel):
    """JSON body for the OpenAI/Kokoro-compatible TTS endpoint (no file upload)."""
    text: str = Field(..., min_length=1, max_length=50000)
    # Built-in voice name; resolved via wrapper.get_builtin_voice() — not a cloned voice.
    voice: str = Field(default="default")
    speed: float = Field(default=1.0, ge=0.5, le=2.0)  # reserved for future use
    max_new_tokens: int = Field(default=Config.MAX_NEW_TOKENS, ge=64, le=2048)
    repetition_penalty: float = Field(default=Config.REPETITION_PENALTY, ge=1.0, le=2.0)
async def internal_voice_register(http_request: Request):
    """Register voice once for a stream; returns reusable voice_key."""
    # Optional shared-secret gate protecting all internal endpoints.
    if Config.INTERNAL_SHARED_SECRET and (
        http_request.headers.get("X-Internal-Secret", "") != Config.INTERNAL_SHARED_SECRET
    ):
        raise HTTPException(403, "Forbidden")
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    raw_audio = await http_request.body()
    if len(raw_audio) > Config.MAX_VOICE_UPLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Voice file too large (max 10 MB)")
    if not raw_audio:
        raise HTTPException(status_code=400, detail="Empty voice file")
    # Encode on the worker pool so the event loop stays responsive.
    try:
        profile = await asyncio.get_running_loop().run_in_executor(
            tts_executor, wrapper.encode_voice_from_bytes, raw_audio
        )
    except Exception as e:
        logger.error(f"[internal] voice register failed: {e}")
        raise HTTPException(400, "Voice registration failed")
    key = (profile.audio_hash or "").strip()
    if not key:
        raise HTTPException(500, "Voice key unavailable")
    sid = (http_request.query_params.get("stream_id") or "").strip()
    if sid:
        # Tie the key to the stream so completion/cancel cleanup can evict it.
        with _internal_cancel_lock:
            _purge_internal_stream_state_locked()
            known_keys, _ = _internal_stream_voice_keys.get(sid, (set(), 0.0))
            known_keys.add(key)
            expiry = time.time() + max(1, Config.INTERNAL_STREAM_STATE_TTL_SEC)
            _internal_stream_voice_keys[sid] = (known_keys, expiry)
    return {"status": "registered", "voice_key": key}
async def internal_chunk_synthesize(
    request: InternalChunkRequest,
    http_request: Request,
):
    """Internal endpoint used by primary/helper parallel routing."""
    # Shared-secret gate, same scheme as the other internal endpoints.
    if Config.INTERNAL_SHARED_SECRET and (
        http_request.headers.get("X-Internal-Secret", "") != Config.INTERNAL_SHARED_SECRET
    ):
        raise HTTPException(403, "Forbidden")
    # Reject work for streams the client already cancelled; refresh state TTLs.
    with _internal_cancel_lock:
        _purge_internal_stream_state_locked()
        if request.stream_id in _internal_cancelled_streams:
            raise HTTPException(409, "Stream already cancelled")
        _touch_internal_stream_voice_keys_locked(request.stream_id)
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    # Voice resolution order: TTL cache, then the permanent built-in registry
    # (built-ins survive even if their cache entry expired), else 409.
    profile = wrapper.default_voice
    if request.voice_key:
        candidate = wrapper._voice_cache.get(request.voice_key)
        if candidate is None:
            candidate = wrapper.get_builtin_voice_by_hash(request.voice_key)
        if candidate is None:
            raise HTTPException(409, "Voice key expired or not found")
        profile = candidate
    # Run blocking inference on the worker pool.
    try:
        audio = await asyncio.get_running_loop().run_in_executor(
            tts_executor,
            wrapper.generate_speech,
            request.text,
            profile,
            request.max_new_tokens,
            request.repetition_penalty,
        )
    except Exception as e:
        logger.error(f"[internal] chunk {request.chunk_index} failed: {e}")
        raise HTTPException(500, "Chunk synthesis failed")
    requested_fmt = (request.output_format or "mp3").lower()
    fmt = requested_fmt if requested_fmt in {"mp3", "wav", "flac"} else "mp3"
    payload, media_type = _encode_audio(audio, fmt=fmt)
    response_headers = {
        "X-Stream-Id": request.stream_id,
        "X-Chunk-Index": str(request.chunk_index),
    }
    return Response(content=payload, media_type=media_type, headers=response_headers)
async def internal_chunk_cancel(stream_id: str, http_request: Request):
    """Mark a stream cancelled (TTL-bounded) and drop its registered voice keys."""
    if Config.INTERNAL_SHARED_SECRET and (
        http_request.headers.get("X-Internal-Secret", "") != Config.INTERNAL_SHARED_SECRET
    ):
        raise HTTPException(403, "Forbidden")
    # The TTL keeps the tombstone around long enough for in-flight chunk
    # requests to observe the cancellation before it is purged.
    expiry = time.time() + max(1, Config.INTERNAL_CANCEL_TTL_SEC)
    with _internal_cancel_lock:
        _purge_internal_stream_state_locked()
        _internal_cancelled_streams[stream_id] = expiry
        _internal_stream_voice_keys.pop(stream_id, None)
    return {"status": "cancelled", "stream_id": stream_id}
async def internal_chunk_complete(stream_id: str, http_request: Request):
    """Best-effort immediate cleanup after stream completes normally."""
    # Same shared-secret gate as the other internal endpoints.
    if Config.INTERNAL_SHARED_SECRET and (
        http_request.headers.get("X-Internal-Secret", "") != Config.INTERNAL_SHARED_SECRET
    ):
        raise HTTPException(403, "Forbidden")
    with _internal_cancel_lock:
        _purge_internal_stream_state_locked()
        _clear_internal_stream_state_locked(stream_id)
    return {"status": "completed", "stream_id": stream_id}
async def openai_compatible_tts(request: TTSJsonRequest):
    """OpenAI-compatible streaming endpoint (JSON body, no file upload).
    Uses built-in voice selection via `voice`. For voice cloning, use /tts/stream with FormData.
    """
    wrapper: ChatterboxWrapper = getattr(app.state, "wrapper", None)
    if not wrapper:
        raise HTTPException(503, "Model not loaded")
    try:
        voice_profile = wrapper.get_builtin_voice(request.voice)
    except ValueError as e:
        raise HTTPException(400, str(e))
    sid = uuid.uuid4().hex[:12]
    # Pipeline generator yields encoded audio chunks as inference progresses.
    generator = _pipeline_stream_generator(
        wrapper,
        request.text,
        voice_profile,
        request.max_new_tokens,
        request.repetition_penalty,
        sid,
    )
    response_headers = {
        "Transfer-Encoding": "chunked",
        "X-Stream-Id": sid,
        "Cache-Control": "no-cache",
    }
    return StreamingResponse(generator, media_type="audio/mpeg", headers=response_headers)
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Stop / Cancel Endpoint | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
async def stop_stream(stream_id: str):
    """Stop an active TTS stream by its ID (from X-Stream-Id header).
    Cancels the ONNX generation loop mid-token, freeing CPU immediately.
    """
    cancel_event = _active_streams.get(stream_id)
    if cancel_event is None:
        return {"status": "not_found", "stream_id": stream_id}
    cancel_event.set()
    # Snapshot the helper routes under the lock, then notify outside it.
    with _stream_routes_lock:
        routed_helpers = set(_stream_helper_routes.pop(stream_id, set()))
    for url in routed_helpers:
        _helper_cancel_stream(url, stream_id)
    logger.info(f"Stream {stream_id} cancelled by client")
    return {"status": "stopped", "stream_id": stream_id}
async def stop_all_streams():
    """Emergency stop: cancel ALL active TTS streams."""
    snapshot = list(_active_streams.items())
    # Copy-and-clear the helper routing table under the lock; cancel calls
    # to helpers happen outside it.
    with _stream_routes_lock:
        routes_by_stream = {
            sid: set(urls) for sid, urls in _stream_helper_routes.items()
        }
        _stream_helper_routes.clear()
    for sid, cancel_event in snapshot:
        cancel_event.set()
        for url in routes_by_stream.get(sid, set()):
            _helper_cancel_stream(url, sid)
    _active_streams.clear()
    logger.info(f"Stopped all streams ({len(snapshot)} active)")
    return {"status": "stopped_all", "count": len(snapshot)}
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Entrypoint | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
if __name__ == "__main__":
    # Local import: uvicorn is only needed when run as a script, not when
    # the app is served by an external ASGI runner.
    import uvicorn
    uvicorn.run(app, host=Config.HOST, port=Config.PORT)