feat(web): info icon on every param with hover/click tooltip; add help text to all params
200e3fe unverified | """Chatterbox Multilingual adapter (23 languages).""" | |
| from __future__ import annotations | |
| import io | |
| from typing import Any, ClassVar | |
| import soundfile as sf | |
| from server.schemas import Lang, ParamSpec | |
| from server.seed import apply_seed | |
# Languages offered by the multilingual adapter, as (code, display label)
# pairs listed in alphabetical order of the language code.
_MTL_LANGS: list[Lang] = [
    Lang(code=code, label=label)
    for code, label in (
        ("ar", "Arabic"),
        ("da", "Danish"),
        ("de", "German"),
        ("el", "Greek"),
        ("en", "English"),
        ("es", "Spanish"),
        ("fi", "Finnish"),
        ("fr", "French"),
        ("he", "Hebrew"),
        ("hi", "Hindi"),
        ("it", "Italian"),
        ("ja", "Japanese"),
        ("ko", "Korean"),
        ("ms", "Malay"),
        ("nl", "Dutch"),
        ("no", "Norwegian"),
        ("pl", "Polish"),
        ("pt", "Portuguese"),
        ("ru", "Russian"),
        ("sv", "Swedish"),
        ("sw", "Swahili"),
        ("tr", "Turkish"),
        ("zh", "Chinese"),
    )
]
class Adapter:
    """Adapter exposing the Chatterbox Multilingual TTS model.

    Class-level attributes describe the adapter (id, label, supported
    languages, tunable parameters); instances hold the target device and
    the lazily loaded model.
    """

    id: ClassVar[str] = "chatterbox-mtl"
    label: ClassVar[str] = "Chatterbox Multilingual"
    description: ClassVar[str] = (
        "23-language voice cloning. Pick a language at generate time."
    )
    languages: ClassVar[list[Lang]] = _MTL_LANGS
    paralinguistic_tags: ClassVar[list[str]] = []  # TBD on first manual run
    supports_voice_clone: ClassVar[bool] = True
    # Tunables for this adapter. The defaults here are mirrored by the
    # fallbacks used in ``generate``; keep the two in sync.
    params: ClassVar[list[ParamSpec]] = [
        ParamSpec(
            name="exaggeration", label="Exaggeration", type="float",
            default=0.5, min=0.0, max=2.0, step=0.05,
            help="How emotive the speech is. Higher pushes prosody and emphasis; lower stays flat and neutral.",
            group="basic",
        ),
        ParamSpec(
            name="cfg_weight", label="CFG weight", type="float",
            default=0.5, min=0.0, max=1.0, step=0.05,
            help="Classifier-free guidance. Higher sticks closer to the reference voice; lower allows more variation but may drift in identity.",
            group="basic",
        ),
        ParamSpec(
            name="temperature", label="Temperature", type="float",
            default=0.8, min=0.1, max=1.5, step=0.05,
            help="Sampling randomness. Lower = deterministic and safer; higher = more creative but riskier and prone to artifacts.",
            group="basic",
        ),
        ParamSpec(
            name="repetition_penalty", label="Repetition penalty", type="float",
            default=2.0, min=1.0, max=3.0, step=0.05,
            help="Discourages repeating the same tokens. Higher than for English because non-Latin scripts loop more easily.",
            group="basic",
        ),
        ParamSpec(
            name="seed", label="Seed", type="int",
            default=-1, min=-1, step=1,
            help="Reproducibility. -1 draws a fresh random seed every run; any non-negative value pins the result so you can reproduce it.",
            group="advanced",
        ),
        ParamSpec(
            name="min_p", label="Min p", type="float",
            default=0.05, min=0.0, max=1.0, step=0.01,
            help="Cuts off tokens whose probability is below this fraction of the top token's. Higher trims more aggressively.",
            group="advanced",
        ),
        ParamSpec(
            name="top_p", label="Top p", type="float",
            default=1.0, min=0.0, max=1.0, step=0.01,
            help="Nucleus sampling. Keep tokens until cumulative probability reaches this. Lower = safer/conservative.",
            group="advanced",
        ),
    ]

    def __init__(self, device: str) -> None:
        """Remember the target device; the model is not loaded yet."""
        self.device = device
        self._model = None

    def load(self) -> None:
        """Instantiate the pretrained multilingual model on ``self.device``."""
        # Imported lazily so importing this module does not pull in chatterbox.
        from chatterbox.mtl_tts import ChatterboxMultilingualTTS

        self._model = ChatterboxMultilingualTTS.from_pretrained(device=self.device)

    def unload(self) -> None:
        """Drop the adapter's reference to the loaded model."""
        self._model = None

    @staticmethod
    def _float_param(params: dict[str, Any], name: str, default: float) -> float:
        """Return ``params[name]`` as a float, or ``default`` if missing or None."""
        value = params.get(name)
        # An explicit None (e.g. a JSON null) would make float() raise
        # TypeError; treat it the same as an absent key.
        return default if value is None else float(value)

    def generate(
        self,
        text: str,
        reference_wav_path: str | None,
        language: str | None,
        params: dict[str, Any],
    ) -> tuple[bytes, int, int]:
        """Synthesize ``text`` and return it as 16-bit PCM WAV bytes.

        Args:
            text: Text to synthesize.
            reference_wav_path: Passed to the model as ``audio_prompt_path``;
                may be None.
            language: Language code forwarded as ``language_id``. Required.
            params: Tunable values keyed by parameter name. Missing or
                None-valued float parameters fall back to their defaults.

        Returns:
            A ``(wav_bytes, sample_rate, seed_used)`` tuple, where
            ``seed_used`` is the value returned by ``apply_seed``.

        Raises:
            RuntimeError: If ``load`` has not been called.
            ValueError: If ``language`` is empty or None.
        """
        if self._model is None:
            raise RuntimeError("model not loaded")
        if not language:
            raise ValueError("language is required for chatterbox-mtl")

        # Seed before sampling so the generation below is affected by it.
        seed_used = apply_seed(params.get("seed"))

        wav = self._model.generate(
            text,
            language_id=language,
            audio_prompt_path=reference_wav_path,
            exaggeration=self._float_param(params, "exaggeration", 0.5),
            cfg_weight=self._float_param(params, "cfg_weight", 0.5),
            temperature=self._float_param(params, "temperature", 0.8),
            repetition_penalty=self._float_param(params, "repetition_penalty", 2.0),
            min_p=self._float_param(params, "min_p", 0.05),
            top_p=self._float_param(params, "top_p", 1.0),
        )

        import numpy as np

        # Tensor-like outputs are detached and moved to host memory; anything
        # else is handed to NumPy as-is.
        if hasattr(wav, "detach"):
            wav = wav.detach().cpu().numpy()
        samples = np.asarray(wav).squeeze()

        # Fall back to 24 kHz when the model does not expose a sample rate.
        sr = getattr(self._model, "sr", 24000)

        buf = io.BytesIO()
        sf.write(buf, samples, sr, format="WAV", subtype="PCM_16")
        return buf.getvalue(), sr, seed_used