# chatterbox-voice-studio / server/models/chatterbox_mtl.py
# Author: techfreakworm
# Last commit: feat(web): info icon on every param with hover/click tooltip; add help text to all params
# Commit: 200e3fe (unverified)
"""Chatterbox Multilingual adapter (23 languages)."""
from __future__ import annotations
import io
from typing import Any, ClassVar
import soundfile as sf
from server.schemas import Lang, ParamSpec
from server.seed import apply_seed
# Languages exposed by the Chatterbox Multilingual adapter: (code, UI label)
# pairs, listed alphabetically by code. There are 23 entries.
_MTL_LANGS: list[Lang] = [
    Lang(code=lang_code, label=display_name)
    for lang_code, display_name in (
        ("ar", "Arabic"),
        ("da", "Danish"),
        ("de", "German"),
        ("el", "Greek"),
        ("en", "English"),
        ("es", "Spanish"),
        ("fi", "Finnish"),
        ("fr", "French"),
        ("he", "Hebrew"),
        ("hi", "Hindi"),
        ("it", "Italian"),
        ("ja", "Japanese"),
        ("ko", "Korean"),
        ("ms", "Malay"),
        ("nl", "Dutch"),
        ("no", "Norwegian"),
        ("pl", "Polish"),
        ("pt", "Portuguese"),
        ("ru", "Russian"),
        ("sv", "Swedish"),
        ("sw", "Swahili"),
        ("tr", "Turkish"),
        ("zh", "Chinese"),
    )
]
class Adapter:
    """Server-side adapter for the Chatterbox Multilingual TTS model.

    Class attributes describe the model to the server/UI (identifier, label,
    supported languages, tunable parameters). Instances manage the model's
    lifecycle via ``load`` / ``unload`` and synthesize audio via ``generate``.
    """

    id: ClassVar[str] = "chatterbox-mtl"
    label: ClassVar[str] = "Chatterbox Multilingual"
    description: ClassVar[str] = (
        "23-language voice cloning. Pick a language at generate time."
    )
    languages: ClassVar[list[Lang]] = _MTL_LANGS
    paralinguistic_tags: ClassVar[list[str]] = []  # TBD on first manual run
    supports_voice_clone: ClassVar[bool] = True
    # Tunable generation parameters. Each ``name`` is the key looked up in the
    # ``params`` dict passed to ``generate``.
    params: ClassVar[list[ParamSpec]] = [
        ParamSpec(
            name="exaggeration", label="Exaggeration", type="float",
            default=0.5, min=0.0, max=2.0, step=0.05,
            help="How emotive the speech is. Higher pushes prosody and emphasis; lower stays flat and neutral.",
            group="basic",
        ),
        ParamSpec(
            name="cfg_weight", label="CFG weight", type="float",
            default=0.5, min=0.0, max=1.0, step=0.05,
            help="Classifier-free guidance. Higher sticks closer to the reference voice; lower allows more variation but may drift in identity.",
            group="basic",
        ),
        ParamSpec(
            name="temperature", label="Temperature", type="float",
            default=0.8, min=0.1, max=1.5, step=0.05,
            help="Sampling randomness. Lower = deterministic and safer; higher = more creative but riskier and prone to artifacts.",
            group="basic",
        ),
        ParamSpec(
            name="repetition_penalty", label="Repetition penalty", type="float",
            default=2.0, min=1.0, max=3.0, step=0.05,
            help="Discourages repeating the same tokens. Higher than for English because non-Latin scripts loop more easily.",
            group="basic",
        ),
        ParamSpec(
            name="seed", label="Seed", type="int",
            default=-1, min=-1, step=1,
            help="Reproducibility. -1 draws a fresh random seed every run; any non-negative value pins the result so you can reproduce it.",
            group="advanced",
        ),
        ParamSpec(
            name="min_p", label="Min p", type="float",
            default=0.05, min=0.0, max=1.0, step=0.01,
            help="Cuts off tokens whose probability is below this fraction of the top token's. Higher trims more aggressively.",
            group="advanced",
        ),
        ParamSpec(
            name="top_p", label="Top p", type="float",
            default=1.0, min=0.0, max=1.0, step=0.01,
            help="Nucleus sampling. Keep tokens until cumulative probability reaches this. Lower = safer/conservative.",
            group="advanced",
        ),
    ]

    def __init__(self, device: str) -> None:
        """Remember the target device; the model itself is not loaded yet.

        Args:
            device: Device string forwarded to ``from_pretrained`` in ``load``.
        """
        self.device = device
        self._model = None

    def load(self) -> None:
        """Construct the model via ``ChatterboxMultilingualTTS.from_pretrained``.

        The import is done here rather than at module level, presumably so the
        server can import this module without paying chatterbox's import cost.
        """
        from chatterbox.mtl_tts import ChatterboxMultilingualTTS

        self._model = ChatterboxMultilingualTTS.from_pretrained(device=self.device)

    def unload(self) -> None:
        """Drop this adapter's reference to the loaded model."""
        self._model = None

    @staticmethod
    def _float_param(params: dict[str, Any], name: str, default: float) -> float:
        """Read ``params[name]`` as a float, using ``default`` if missing or None."""
        value = params.get(name)
        # An explicit None (e.g. JSON null from a client) means "use default",
        # rather than crashing in float(None).
        return default if value is None else float(value)

    @staticmethod
    def _encode_pcm16_wav(wav: Any, sr: int) -> bytes:
        """Encode model output as an in-memory 16-bit PCM WAV file."""
        import numpy as np

        # Tensors (possibly on an accelerator) must be moved to host memory
        # before NumPy can read them.
        if hasattr(wav, "detach"):
            wav = wav.detach().cpu().numpy()
        # Drop singleton dimensions, e.g. a (1, n) output becomes (n,).
        arr = np.asarray(wav).squeeze()
        # libsndfile does not clip float->int conversions by default, so
        # samples outside [-1, 1] may wrap around instead of saturating.
        if np.issubdtype(arr.dtype, np.floating):
            arr = np.clip(arr, -1.0, 1.0)
        buf = io.BytesIO()
        sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
        return buf.getvalue()

    def generate(
        self,
        text: str,
        reference_wav_path: str | None,
        language: str | None,
        params: dict[str, Any],
    ) -> tuple[bytes, int, int]:
        """Synthesize ``text`` and return it as an in-memory WAV file.

        Args:
            text: Text to speak.
            reference_wav_path: Optional reference clip for voice cloning;
                forwarded to the model as ``audio_prompt_path``.
            language: Language code (e.g. one of the codes in ``languages``).
                Required for this model.
            params: Generation settings keyed by ``ParamSpec.name``. Missing
                keys, or keys explicitly set to None, fall back to defaults.

        Returns:
            ``(wav_bytes, sample_rate, seed_used)``. ``wav_bytes`` is a 16-bit
            PCM WAV file. ``sample_rate`` is the model's ``sr`` attribute
            (24000 if absent). ``seed_used`` is what ``apply_seed`` returned.

        Raises:
            RuntimeError: If the model is not loaded.
            ValueError: If ``language`` is empty or None.
        """
        if self._model is None:
            raise RuntimeError("model not loaded")
        if not language:
            raise ValueError("language is required for chatterbox-mtl")

        seed_used = apply_seed(params.get("seed"))
        # Fallback defaults mirror the ParamSpec defaults declared above.
        wav = self._model.generate(
            text,
            language_id=language,
            audio_prompt_path=reference_wav_path,
            exaggeration=self._float_param(params, "exaggeration", 0.5),
            cfg_weight=self._float_param(params, "cfg_weight", 0.5),
            temperature=self._float_param(params, "temperature", 0.8),
            repetition_penalty=self._float_param(params, "repetition_penalty", 2.0),
            min_p=self._float_param(params, "min_p", 0.05),
            top_p=self._float_param(params, "top_p", 1.0),
        )
        sr = getattr(self._model, "sr", 24000)
        return self._encode_pcm16_wav(wav, sr), sr, seed_used