Spaces:
Sleeping
Sleeping
| """ | |
Chatterbox Turbo TTS — ONNX Inference Wrapper
───────────────────────────────────────────────
Orchestrates the 4-component ONNX pipeline:
    embed_tokens → speech_encoder → language_model → conditional_decoder
Optimised for lowest-latency CPU inference on 2 vCPU:
  • Sequential execution, thread count = physical cores, no spinning
  • Token list pre-allocation (avoids O(n²) np.concatenate in loop)
  • In-memory voice caching (no disk writes for uploads)
  • Robust audio loading: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM
  • Sentence-level streaming for real-time playback
| """ | |
# ── Suppress harmless transformers warnings BEFORE import ─────────
| import os | |
| import warnings | |
| os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1" | |
| warnings.filterwarnings("ignore", message=".*model of type.*chatterbox.*") | |
| import hashlib | |
| import io | |
| import logging | |
| import subprocess | |
| import tempfile | |
| import time | |
| from collections import OrderedDict | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Callable, Generator, Optional | |
| import librosa | |
| import numpy as np | |
| import onnxruntime as ort | |
| import soundfile as soundfile_lib | |
| from huggingface_hub import hf_hub_download | |
| from transformers import AutoTokenizer | |
| from config import Config | |
| import text_processor | |
| logger = logging.getLogger(__name__) | |
# ── Supported audio file extensions for voice upload ──────────────
| _SUPPORTED_AUDIO_EXTENSIONS = { | |
| ".wav", ".mp3", ".mpeg", ".mpga", ".m4a", ".mp4", | |
| ".ogg", ".oga", ".opus", ".flac", ".webm", ".aac", ".wma", | |
| } | |
| def _slugify(text: str) -> str: | |
| buf = [] | |
| prev_underscore = False | |
| for ch in text.strip().lower(): | |
| if ch.isalnum(): | |
| buf.append(ch) | |
| prev_underscore = False | |
| else: | |
| if not prev_underscore: | |
| buf.append("_") | |
| prev_underscore = True | |
| slug = "".join(buf).strip("_") | |
| return slug or "voice" | |
# ───────────────────────────────────────────────────────────────────
# Data Structures
# ───────────────────────────────────────────────────────────────────
# eq=False keeps identity comparison and hashing: a field-wise __eq__ over
# numpy arrays would raise "truth value of an array is ambiguous".
@dataclass(eq=False)
class VoiceProfile:
    """Cached speaker embedding extracted from reference audio.

    The four array fields are the outputs of the speech encoder session, in
    the order it returns them. Instances are mutable so that ``audio_hash``
    can be (re)attached after construction.
    """

    cond_emb: np.ndarray            # conditioning embeddings prepended to the text embeddings
    prompt_token: np.ndarray        # prompt tokens prepended to the generated speech tokens
    speaker_embeddings: np.ndarray  # passed to the decoder as "speaker_embeddings"
    speaker_features: np.ndarray    # passed to the decoder as "speaker_features"
    audio_hash: str = ""            # hex digest of the source audio bytes; "" when unknown
class GenerationCancelled(Exception):
    """Signals that the client cancelled inference before it finished."""
# ───────────────────────────────────────────────────────────────────
# LRU Voice Cache
# ───────────────────────────────────────────────────────────────────
class _VoiceCache:
    """LRU cache for VoiceProfile objects with TTL-based expiration.

    Entries expire ``ttl_seconds`` after they were last stored (default:
    1 hour). Looking an entry up marks it as most recently used but does
    not extend its lifetime; storing it again does. When the cache is full
    the least recently used entry is dropped.

    A ``maxsize`` of zero or less disables caching: ``put`` becomes a
    no-op and ``get`` always misses.
    """

    def __init__(self, maxsize: int, ttl_seconds: int = 3600):
        # key -> (profile, monotonic timestamp of the last store)
        self._cache: OrderedDict[str, tuple[VoiceProfile, float]] = OrderedDict()
        self._maxsize = maxsize
        self._ttl = ttl_seconds

    def _evict_expired(self):
        """Remove all entries older than the TTL."""
        # Monotonic clock: expiry must not be affected by wall-clock jumps.
        now = time.monotonic()
        expired = [k for k, (_, ts) in self._cache.items() if now - ts > self._ttl]
        for k in expired:
            del self._cache[k]
            logger.debug("Voice cache expired: %s…", k[:8])

    def get(self, key: str) -> Optional[VoiceProfile]:
        """Return the cached profile for ``key``, or None on a miss/expiry."""
        self._evict_expired()
        entry = self._cache.get(key)
        if entry is None:
            return None
        profile, ts = entry
        remaining = self._ttl - (time.monotonic() - ts)
        self._cache.move_to_end(key)
        logger.info("Voice cache HIT: %s… (expires in %.0fs)", key[:8], remaining)
        return profile

    def put(self, key: str, profile: VoiceProfile):
        """Store ``profile`` under ``key`` and restart its TTL."""
        if self._maxsize <= 0:
            # Caching disabled; popitem() on an empty dict would raise KeyError.
            return
        self._evict_expired()
        if key in self._cache:
            self._cache.move_to_end(key)
        elif len(self._cache) >= self._maxsize:
            evicted_key, _ = self._cache.popitem(last=False)
            logger.debug("Voice cache evicted (LRU): %s…", evicted_key[:8])
        self._cache[key] = (profile, time.monotonic())
        logger.info(
            "Voice cache STORED: %s… (TTL: %ss, size: %s)",
            key[:8], self._ttl, len(self._cache),
        )

    def size(self) -> int:
        """Return the number of stored entries (expired ones included until evicted)."""
        return len(self._cache)
# ───────────────────────────────────────────────────────────────────
# Audio Loading (robust multi-format support)
# ───────────────────────────────────────────────────────────────────
def _load_audio_bytes(audio_bytes: bytes, sr: int = 24000) -> np.ndarray:
    """Decode an in-memory audio file to a mono float32 array at ``sr`` Hz.

    Three decoders are tried in order, falling through on any failure:
    soundfile, then librosa, then the ``ffmpeg`` command-line tool. Each
    failure is logged at debug level so an undecodable upload can be
    diagnosed.

    Args:
        audio_bytes: Raw contents of the audio file.
        sr: Target sample rate in Hz.

    Returns:
        The decoded samples as a float32 numpy array.

    Raises:
        ValueError: If none of the decoders could read the data.
    """
    buf = io.BytesIO(audio_bytes)

    # 1) soundfile first.
    try:
        audio, file_sr = soundfile_lib.read(buf)
        if audio.ndim > 1:
            audio = audio.mean(axis=1)  # multi-channel → mono
        if file_sr != sr:
            audio = librosa.resample(audio.astype(np.float32), orig_sr=file_sr, target_sr=sr)
        return audio.astype(np.float32)
    except Exception as exc:
        logger.debug("soundfile could not decode audio: %s", exc)
        buf.seek(0)

    # 2) librosa, which resamples and down-mixes itself.
    try:
        audio, _ = librosa.load(buf, sr=sr, mono=True)
        return audio.astype(np.float32)
    except Exception as exc:
        logger.debug("librosa could not decode audio: %s", exc)
        buf.seek(0)

    # 3) Last resort: let the ffmpeg CLI convert to mono WAV via pipes.
    try:
        proc = subprocess.run(
            ["ffmpeg", "-i", "pipe:0", "-f", "wav", "-ac", "1", "-ar", str(sr), "pipe:1"],
            input=audio_bytes, capture_output=True, timeout=30,
        )
        # More than 44 bytes — presumably "more than a bare WAV header".
        if proc.returncode == 0 and len(proc.stdout) > 44:
            audio, _ = soundfile_lib.read(io.BytesIO(proc.stdout))
            return audio.astype(np.float32)
        logger.debug(
            "ffmpeg produced no usable audio (exit code %s): %s",
            proc.returncode,
            proc.stderr[-500:].decode("utf-8", errors="replace"),
        )
    except Exception as exc:
        logger.debug("ffmpeg fallback failed: %s", exc)

    raise ValueError(
        "Could not decode audio file. Supported formats: "
        "WAV, MP3, MPEG, M4A, OGG, FLAC, WebM, AAC. "
        "Please upload a valid audio file."
    )
# ───────────────────────────────────────────────────────────────────
# Main Wrapper
# ───────────────────────────────────────────────────────────────────
class ChatterboxWrapper:
    """Text-to-speech wrapper around the four Chatterbox ONNX sessions.

    Construction downloads the ONNX graphs, creates one CPU session per
    component (embed_tokens, speech_encoder, language_model,
    conditional_decoder), loads the tokenizer and pre-encodes the built-in
    voices. With ``download_only=True`` only the download step runs; the
    sessions, tokenizer and voices are then never created, so the instance
    cannot be used for inference.
    """

    def __init__(self, download_only: bool = False):
        self.cfg = Config
        os.makedirs(self.cfg.MODELS_DIR, exist_ok=True)

        logger.info(f"Downloading ONNX models (dtype={self.cfg.MODEL_DTYPE}) …")
        self._model_paths = self._download_models()
        if download_only:
            return

        logger.info(
            f"Creating ONNX Runtime sessions "
            f"(intra_op_threads={self.cfg.CPU_THREADS}, workers={self.cfg.MAX_WORKERS}) …"
        )
        opts = self._make_session_options()
        providers = ["CPUExecutionProvider"]
        self.embed_session = ort.InferenceSession(self._model_paths["embed_tokens"], sess_options=opts, providers=providers)
        self.encoder_session = ort.InferenceSession(self._model_paths["speech_encoder"], sess_options=opts, providers=providers)
        self.lm_session = ort.InferenceSession(self._model_paths["language_model"], sess_options=opts, providers=providers)
        self.decoder_session = ort.InferenceSession(self._model_paths["conditional_decoder"], sess_options=opts, providers=providers)

        logger.info("Loading tokenizer …")
        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.MODEL_ID)

        self._voice_cache = _VoiceCache(
            maxsize=self.cfg.VOICE_CACHE_SIZE,
            ttl_seconds=self.cfg.VOICE_CACHE_TTL_SEC,
        )
        # Registries filled by _register_builtin_voice().
        self._builtin_voice_profiles: dict[str, VoiceProfile] = {}   # voice id -> profile
        self._builtin_voice_bytes: dict[str, bytes] = {}             # voice id -> raw file bytes
        self._builtin_voice_by_hash: dict[str, VoiceProfile] = {}    # audio hash -> profile
        self._voice_alias_to_id: dict[str, str] = {}                 # alias -> voice id
        self._builtin_voice_catalog: list[dict] = []                 # metadata, in registration order
        self._default_voice_id: str = "default"

        logger.info("Loading built-in voices (HF default + local samples) …")
        self.default_voice = self._load_builtin_voices()
        logger.info("✓ ChatterboxWrapper ready")

    # ─── Model download ──────────────────────────────────────────

    def _download_models(self) -> dict:
        """Download all 4 ONNX components + weight files from HuggingFace.

        Returns:
            Mapping of component name to the local path of its ONNX graph.
        """
        components = ("conditional_decoder", "speech_encoder", "embed_tokens", "language_model")
        paths = {}
        for name in components:
            paths[name] = self._download_component(name, self.cfg.MODEL_DTYPE)
        return paths

    def _download_component(self, name: str, dtype: str) -> str:
        """Download one ONNX graph (and its weight file, if any).

        Args:
            name: Component name, used as the file stem.
            dtype: "fp32" selects ``<name>.onnx``, "q8" selects
                ``<name>_quantized.onnx``, anything else ``<name>_<dtype>.onnx``.

        Returns:
            Local path of the downloaded graph file.
        """
        if dtype == "fp32":
            filename = f"{name}.onnx"
        elif dtype == "q8":
            filename = f"{name}_quantized.onnx"
        else:
            filename = f"{name}_{dtype}.onnx"

        graph = hf_hub_download(
            self.cfg.MODEL_ID, subfolder="onnx", filename=filename,
            cache_dir=self.cfg.MODELS_DIR,
        )
        # Download companion weight file (best effort).
        try:
            hf_hub_download(
                self.cfg.MODEL_ID, subfolder="onnx", filename=f"{filename}_data",
                cache_dir=self.cfg.MODELS_DIR,
            )
        except Exception as exc:
            # Some quantized variants embed weights in-graph, so a missing
            # companion file is not an error; record why it was skipped.
            logger.debug("No companion weight file for %s: %s", filename, exc)
        return graph

    # ─── Session configuration (optimised for 2 vCPU) ─────────────

    def _make_session_options(self) -> ort.SessionOptions:
        """Build the ONNX Runtime session options shared by all four sessions."""
        opts = ort.SessionOptions()
        # Sequential execution: no parallel graph scheduling overhead
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        # Match physical cores exactly (2 for HF Space free tier)
        opts.intra_op_num_threads = self.cfg.CPU_THREADS
        opts.inter_op_num_threads = 1
        # Full graph optimisations (constant folding, fusion, etc.)
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Disable thread spinning — wastes CPU cycles on busy-wait
        opts.add_session_config_entry("session.intra_op.allow_spinning", "0")
        opts.add_session_config_entry("session.inter_op.allow_spinning", "0")
        # Enable memory optimisations
        opts.enable_cpu_mem_arena = True
        opts.enable_mem_pattern = True
        opts.enable_mem_reuse = True
        return opts

    # ─── Built-in voices (HF default + local samples) ────────────

    def _download_hf_default_voice_bytes(self) -> bytes:
        """Download the default voice sample and return its raw bytes."""
        path = hf_hub_download(
            self.cfg.DEFAULT_VOICE_REPO,
            filename=self.cfg.DEFAULT_VOICE_FILE,
            cache_dir=self.cfg.MODELS_DIR,
        )
        return Path(path).read_bytes()

    def _list_local_voice_paths(self) -> list[Path]:
        """Find audio files lying directly in the candidate app directories.

        The directory of this module, the current working directory and the
        module's parent directory are scanned (non-recursively) for files
        whose suffix is in ``_SUPPORTED_AUDIO_EXTENSIONS``. Files reachable
        through more than one directory are returned once.
        """
        wrapper_dir = Path(__file__).resolve().parent
        # Support both module-level and repo-root deployment layouts.
        candidates = []
        for d in (wrapper_dir, Path.cwd().resolve(), wrapper_dir.parent):
            try:
                resolved = d.resolve()
            except Exception:
                continue
            if resolved.is_dir() and resolved not in candidates:
                candidates.append(resolved)

        voices: list[Path] = []
        seen_real_paths: set[str] = set()
        for root in candidates:
            try:
                entries = sorted(root.iterdir(), key=lambda x: x.name.lower())
            except Exception:
                continue
            for p in entries:
                if not p.is_file():
                    continue
                if p.suffix.lower() not in _SUPPORTED_AUDIO_EXTENSIONS:
                    continue
                real_path = str(p.resolve())
                if real_path in seen_real_paths:
                    continue
                seen_real_paths.add(real_path)
                voices.append(p)

        logger.info(
            "Local voice scan complete: %s files across %s",
            len(voices),
            [str(x) for x in candidates],
        )
        return voices

    def _make_unique_voice_id(self, preferred: str) -> str:
        """Slugify ``preferred`` and append ``_2``, ``_3``, … until it is unused."""
        base = _slugify(preferred)
        candidate = base
        idx = 2
        while candidate in self._builtin_voice_profiles:
            candidate = f"{base}_{idx}"
            idx += 1
        return candidate

    def _register_builtin_voice(
        self,
        *,
        preferred_id: str,
        display_name: str,
        source: str,
        source_ref: str,
        audio_bytes: bytes,
        is_default: bool = False,
    ) -> str:
        """Encode a voice sample and add it to the built-in registries.

        Args:
            preferred_id: Desired voice id; made unique if already taken.
            display_name: Name shown in the catalog; its stem becomes an alias.
            source: Origin label stored in the catalog entry.
            source_ref: Origin reference stored in the catalog entry.
            audio_bytes: Raw contents of the audio file.
            is_default: When true, this voice becomes the default and gains
                the "default" alias.

        Returns:
            The voice id under which the voice was registered.

        Raises:
            ValueError: If ``audio_bytes`` is empty or cannot be decoded.
        """
        if not audio_bytes:
            raise ValueError("Voice file is empty")

        voice_id = self._make_unique_voice_id(preferred_id)
        audio_hash = hashlib.md5(audio_bytes).hexdigest()

        profile = self._voice_cache.get(audio_hash)
        if profile is None:
            audio = _load_audio_bytes(audio_bytes, sr=self.cfg.SAMPLE_RATE)
            profile = self._encode_audio_array(audio, audio_hash=audio_hash)
            self._voice_cache.put(audio_hash, profile)
        else:
            # Keep hash attached to cached profile for metadata/voice-key usage.
            profile.audio_hash = audio_hash

        self._builtin_voice_profiles[voice_id] = profile
        self._builtin_voice_bytes[voice_id] = audio_bytes
        if audio_hash:
            self._builtin_voice_by_hash[audio_hash] = profile

        # An alias already claimed by an earlier voice is left with that voice.
        aliases: list[str] = []
        for alias in (voice_id, _slugify(Path(display_name).stem)):
            if alias not in self._voice_alias_to_id:
                self._voice_alias_to_id[alias] = voice_id
                aliases.append(alias)

        if is_default:
            self._default_voice_id = voice_id
            self._voice_alias_to_id["default"] = voice_id
            if "default" not in aliases:
                aliases.append("default")

        self._builtin_voice_catalog.append(
            {
                "id": voice_id,
                "display_name": display_name,
                "source": source,
                "source_ref": source_ref,
                "aliases": aliases,
                "voice_key": audio_hash,
            }
        )
        return voice_id

    def _load_builtin_voices(self) -> VoiceProfile:
        """Register the HuggingFace default voice plus any local samples.

        Local samples that fail to register are skipped with a warning.

        Returns:
            The profile of the default voice.

        Raises:
            RuntimeError: If no profile exists for the default voice id.
        """
        # 1) HF default voice (kept as true default fallback)
        hf_bytes = self._download_hf_default_voice_bytes()
        self._register_builtin_voice(
            preferred_id="default_hf_voice",
            display_name=self.cfg.DEFAULT_VOICE_FILE,
            source="huggingface",
            source_ref=f"{self.cfg.DEFAULT_VOICE_REPO}:{self.cfg.DEFAULT_VOICE_FILE}",
            audio_bytes=hf_bytes,
            is_default=True,
        )

        # 2) Local voice samples placed next to app files
        for path in self._list_local_voice_paths():
            # Avoid duplicate entry if someone also copied default_voice.wav locally.
            if path.name == self.cfg.DEFAULT_VOICE_FILE:
                continue
            try:
                self._register_builtin_voice(
                    preferred_id=path.stem,
                    display_name=path.name,
                    source="local",
                    source_ref=str(path.name),
                    audio_bytes=path.read_bytes(),
                    is_default=False,
                )
            except Exception as e:
                logger.warning(f"Skipping local voice {path.name}: {e}")

        default_profile = self._builtin_voice_profiles.get(self._default_voice_id)
        if default_profile is None:
            raise RuntimeError("Default built-in voice could not be initialized")
        logger.info(
            f"Built-in voices loaded: {len(self._builtin_voice_catalog)} "
            f"(default={self._default_voice_id})"
        )
        return default_profile

    def list_builtin_voices(self) -> list[dict]:
        """Return metadata for startup-preloaded voices (shallow copies)."""
        return [dict(v) for v in self._builtin_voice_catalog]

    def default_voice_name(self) -> str:
        """Return the id of the default built-in voice."""
        return self._default_voice_id

    def resolve_voice_id(self, voice_name: Optional[str]) -> str:
        """Map a user-supplied voice name or alias to a registered voice id.

        ``None`` and blank names resolve to the default voice. Any other
        name is slugified and looked up among the registered aliases.

        Raises:
            ValueError: If the name matches no registered alias.
        """
        if voice_name is None:
            return self._default_voice_id
        raw = str(voice_name).strip()
        # Must be tested before slugifying: _slugify() never returns an
        # empty string, so a blank name could not be detected afterwards.
        if not raw:
            return self._default_voice_id
        key = _slugify(raw)
        voice_id = self._voice_alias_to_id.get(key)
        if voice_id is None:
            available = ", ".join(sorted(self._voice_alias_to_id.keys()))
            raise ValueError(f"Unknown voice '{voice_name}'. Available: {available}")
        return voice_id

    def get_builtin_voice(self, voice_name: Optional[str]) -> VoiceProfile:
        """Return the profile of a built-in voice, re-storing it in the voice cache.

        Raises:
            ValueError: If ``voice_name`` matches no registered alias.
        """
        voice_id = self.resolve_voice_id(voice_name)
        profile = self._builtin_voice_profiles[voice_id]
        if profile.audio_hash:
            self._voice_cache.put(profile.audio_hash, profile)
        return profile

    def get_builtin_voice_bytes(self, voice_name: Optional[str]) -> Optional[bytes]:
        """Return the raw audio bytes a built-in voice was registered with.

        Raises:
            ValueError: If ``voice_name`` matches no registered alias.
        """
        voice_id = self.resolve_voice_id(voice_name)
        return self._builtin_voice_bytes.get(voice_id)

    def get_builtin_voice_by_hash(self, audio_hash: str) -> Optional[VoiceProfile]:
        """Return the built-in profile with this audio hash, or None."""
        return self._builtin_voice_by_hash.get((audio_hash or "").strip())

    # ─── Voice encoding ──────────────────────────────────────────

    def encode_voice_from_bytes(self, audio_bytes: bytes) -> VoiceProfile:
        """Encode reference audio from raw bytes (in-memory, no disk write).

        Accepts: WAV, MP3, MPEG, M4A, OGG, FLAC, WebM, AAC, WMA, Opus.
        Results are cached by the MD5 digest of the bytes. Audio longer
        than ``MAX_REF_DURATION_SEC`` is truncated to that length.

        Raises:
            ValueError: If the input is empty, cannot be decoded, or is
                shorter than ``MIN_REF_DURATION_SEC``.
        """
        if not audio_bytes:
            raise ValueError("Voice file is empty")

        audio_hash = hashlib.md5(audio_bytes).hexdigest()
        cached = self._voice_cache.get(audio_hash)
        if cached is not None:
            return cached

        # Robust multi-format audio loading
        audio = _load_audio_bytes(audio_bytes, sr=self.cfg.SAMPLE_RATE)

        # Validate duration
        duration = len(audio) / self.cfg.SAMPLE_RATE
        if duration < self.cfg.MIN_REF_DURATION_SEC:
            raise ValueError(
                f"Reference audio too short ({duration:.1f}s). "
                f"Minimum: {self.cfg.MIN_REF_DURATION_SEC}s"
            )
        if duration > self.cfg.MAX_REF_DURATION_SEC:
            audio = audio[: int(self.cfg.MAX_REF_DURATION_SEC * self.cfg.SAMPLE_RATE)]

        profile = self._encode_audio_array(audio, audio_hash=audio_hash)
        self._voice_cache.put(audio_hash, profile)
        return profile

    def _encode_audio_array(self, audio: np.ndarray, audio_hash: str = "") -> VoiceProfile:
        """Run speech_encoder on a mono audio array and wrap its four outputs."""
        # Add a leading batch axis; the encoder input is fed as float32.
        audio_input = audio[np.newaxis, :].astype(np.float32)
        cond_emb, prompt_token, speaker_emb, speaker_feat = self.encoder_session.run(
            None, {"audio_values": audio_input}
        )
        return VoiceProfile(
            cond_emb=cond_emb,
            prompt_token=prompt_token,
            speaker_embeddings=speaker_emb,
            speaker_features=speaker_feat,
            audio_hash=audio_hash,
        )

    # ─── Full generation (non-streaming) ──────────────────────────

    def generate_speech(
        self,
        text: str,
        voice: Optional[VoiceProfile] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
    ) -> np.ndarray:
        """Generate complete audio for the given text.

        Args:
            text: Input text; stripped, truncated to ``MAX_TEXT_LENGTH``
                characters and sanitized before synthesis.
            voice: Speaker profile; the default voice when omitted.
            max_new_tokens: Cap on generated speech tokens; falls back to
                the configured value when omitted or 0.
            repetition_penalty: Repetition penalty; falls back to the
                configured value when omitted or 0.

        Returns:
            The waveform produced by the decoder.

        Raises:
            ValueError: If nothing is left of the text after sanitization.
        """
        voice = voice or self.default_voice
        text = text_processor.sanitize(text.strip()[: self.cfg.MAX_TEXT_LENGTH])
        if not text:
            raise ValueError("Text is empty after sanitization")

        tokens = self._generate_tokens(
            text, voice,
            max_new_tokens or self.cfg.MAX_NEW_TOKENS,
            repetition_penalty or self.cfg.REPETITION_PENALTY,
        )
        return self._decode_tokens(tokens, voice)

    # ─── Streaming generation ─────────────────────────────────────

    def stream_speech(
        self,
        text: str,
        voice: Optional[VoiceProfile] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        is_cancelled: Optional[Callable[[], bool]] = None,
    ) -> Generator[np.ndarray, None, None]:
        """Yield audio chunks sentence-by-sentence for real-time streaming.

        Each sentence is independently processed through the full pipeline
        so the first chunk arrives as fast as possible (low TTFB). Nothing
        is yielded when the text is empty after sanitization. Errors other
        than cancellation are logged and re-raised.

        Args:
            text: Input text; stripped, truncated and sanitized as in
                ``generate_speech``.
            voice: Speaker profile; the default voice when omitted.
            max_new_tokens: Per-sentence cap on generated speech tokens.
            repetition_penalty: Repetition penalty for token selection.
            is_cancelled: Optional callable that returns True to abort generation.
                Checked between chunks and every 25 autoregressive steps.
        """
        voice = voice or self.default_voice
        text = text_processor.sanitize(text.strip()[: self.cfg.MAX_TEXT_LENGTH])
        if not text:
            return

        sentences = text_processor.split_for_streaming(text)
        _max = max_new_tokens or self.cfg.MAX_NEW_TOKENS
        _rep = repetition_penalty or self.cfg.REPETITION_PENALTY
        _check = is_cancelled or (lambda: False)

        for i, sentence in enumerate(sentences):
            # Check cancellation between chunks
            if _check():
                logger.info("Generation cancelled by client (between chunks)")
                return
            if not sentence.strip():
                continue
            t0 = time.perf_counter()
            try:
                tokens = self._generate_tokens(sentence, voice, _max, _rep, _check)
                if _check():
                    return
                audio = self._decode_tokens(tokens, voice)
                elapsed = time.perf_counter() - t0
                audio_duration = len(audio) / self.cfg.SAMPLE_RATE
                # Real-time factor: processing time per second of audio.
                rtf = elapsed / audio_duration if audio_duration > 0 else 0
                logger.info(
                    f"Chunk {i + 1}/{len(sentences)}: "
                    f"{len(sentence)} chars → {audio_duration:.1f}s audio "
                    f"in {elapsed:.2f}s (RTF: {rtf:.2f}x)"
                )
                yield audio
            except GenerationCancelled:
                logger.info(f"Generation cancelled mid-token at chunk {i + 1}")
                return
            except Exception as e:
                logger.error(f"Error on chunk {i + 1}: {e}")
                raise

    # ─── Autoregressive token generation (OPTIMISED) ──────────────

    def _generate_tokens(
        self,
        text: str,
        voice: VoiceProfile,
        max_new_tokens: int,
        repetition_penalty: float,
        is_cancelled: Callable[[], bool] = lambda: False,
    ) -> np.ndarray:
        """Run embed → LM autoregressive loop. Returns raw token array.

        Tokens are chosen greedily (argmax) after applying the repetition
        penalty. The result has shape (1, n), starts with
        ``START_SPEECH_TOKEN`` and ends with ``STOP_SPEECH_TOKEN`` unless
        the ``max_new_tokens`` limit was reached first.

        Optimisations:
        • Token list instead of repeated np.concatenate (O(n) → O(1) append)
        • Unique tokens set for inline repetition penalty (avoids exponential penalty bug)
        • Pre-allocated attention mask for zero-copy slicing
        • Correct dimensional routing for step 0 prompt processing

        Raises:
            GenerationCancelled: If ``is_cancelled()`` returns True at one
                of the periodic checks (every 25 steps).
        """
        input_ids = self.tokenizer(text, return_tensors="np")["input_ids"].astype(np.int64)

        # Pre-allocate collections
        token_list: list[int] = [self.cfg.START_SPEECH_TOKEN]
        unique_tokens: set[int] = {self.cfg.START_SPEECH_TOKEN}
        penalty = repetition_penalty

        past_key_values = None
        attention_mask_full = None
        seq_len = 0

        for step in range(max_new_tokens):
            if step > 0 and step % 25 == 0 and is_cancelled():
                raise GenerationCancelled()

            embeds = self.embed_session.run(None, {"input_ids": input_ids})[0]

            if step == 0:
                # Prepend speaker conditioning
                embeds = np.concatenate((voice.cond_emb, embeds), axis=1)
                batch, seq_len, _ = embeds.shape
                # Empty (length-0) cache for every past_key_values input of
                # the LM, in the dtype that input declares.
                past_key_values = {
                    inp.name: np.zeros(
                        [batch, self.cfg.NUM_KV_HEADS, 0, self.cfg.HEAD_DIM],
                        dtype=np.float16 if inp.type == "tensor(float16)" else np.float32,
                    )
                    for inp in self.lm_session.get_inputs()
                    if "past_key_values" in inp.name
                }
                # Pre-allocate full attention mask
                attention_mask_full = np.ones((batch, seq_len + max_new_tokens), dtype=np.int64)
                attention_mask = attention_mask_full[:, :seq_len]
                # Step 0 requires position_ids matching prompt sequence length
                position_ids = np.arange(seq_len, dtype=np.int64).reshape(batch, -1)
            else:
                # O(1) zero-copy slice for subsequent steps
                attention_mask = attention_mask_full[:, : seq_len + step]
                # Single position ID for the single new token
                position_ids = np.array([[seq_len + step - 1]], dtype=np.int64)

            # Language model forward pass
            logits, *present_kv = self.lm_session.run(
                None,
                dict(
                    inputs_embeds=embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    **past_key_values,
                ),
            )

            # ── Inline repetition penalty + token selection ──────────
            last_logits = logits[0, -1, :].copy()  # logits of the last position

            # Apply repetition penalty strictly to unique tokens to prevent over-penalization
            for tok_id in unique_tokens:
                if last_logits[tok_id] < 0:
                    last_logits[tok_id] *= penalty
                else:
                    last_logits[tok_id] /= penalty

            next_token = int(np.argmax(last_logits))
            token_list.append(next_token)
            unique_tokens.add(next_token)

            if next_token == self.cfg.STOP_SPEECH_TOKEN:
                break

            # Update state for next step
            input_ids = np.array([[next_token]], dtype=np.int64)
            # Relies on the LM emitting present-KV outputs in the same
            # order as its past_key_values inputs.
            for j, key in enumerate(past_key_values):
                past_key_values[key] = present_kv[j]

        return np.array([token_list], dtype=np.int64)

    # ─── Token → audio decoding ───────────────────────────────────

    def _decode_tokens(self, generated: np.ndarray, voice: VoiceProfile) -> np.ndarray:
        """Decode speech tokens to a waveform.

        Args:
            generated: Token array as returned by ``_generate_tokens``.
            voice: Speaker profile supplying the prompt tokens and the
                speaker embeddings/features for the decoder.

        Returns:
            The decoder output with its leading axis removed, or an empty
            float32 array when no speech tokens were generated.
        """
        # Strip START token; strip STOP token if present
        tokens = generated[:, 1:]
        if tokens.shape[1] > 0 and tokens[0, -1] == self.cfg.STOP_SPEECH_TOKEN:
            tokens = tokens[:, :-1]
        if tokens.shape[1] == 0:
            return np.zeros(0, dtype=np.float32)

        # Prepend prompt token + append silence
        silence = np.full(
            (tokens.shape[0], 3), self.cfg.SILENCE_TOKEN, dtype=np.int64
        )
        full_tokens = np.concatenate(
            [voice.prompt_token, tokens, silence], axis=1
        )
        wav = self.decoder_session.run(
            None,
            {
                "speech_tokens": full_tokens,
                "speaker_embeddings": voice.speaker_embeddings,
                "speaker_features": voice.speaker_features,
            },
        )[0].squeeze(axis=0)
        return wav

    # ─── Warmup ───────────────────────────────────────────────────

    def warmup(self):
        """Run a short inference to warm up the ONNX sessions.

        Failures are logged as warnings and never propagated.
        """
        try:
            t0 = time.perf_counter()
            _ = self.generate_speech("Hello.", self.default_voice, max_new_tokens=32)
            logger.info(f"Warmup done in {time.perf_counter() - t0:.2f}s")
        except Exception as e:
            logger.warning(f"Warmup failed (non-critical): {e}")