File size: 5,420 Bytes
85b2e31
 
 
 
 
 
 
 
 
b066638
85b2e31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b473465
 
 
200e3fe
b473465
 
 
 
 
200e3fe
b473465
 
 
 
 
200e3fe
b473465
 
 
 
 
200e3fe
b473465
 
 
 
 
200e3fe
b473465
 
 
 
 
200e3fe
b473465
 
 
 
 
200e3fe
b473465
 
85b2e31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b066638
85b2e31
 
 
 
b066638
85b2e31
 
 
 
 
 
b473465
 
 
 
85b2e31
 
 
 
 
 
 
 
 
 
 
 
b066638
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Chatterbox Multilingual adapter (23 languages)."""
from __future__ import annotations

import io
from typing import Any, ClassVar

import soundfile as sf

from server.schemas import Lang, ParamSpec
from server.seed import apply_seed


# Languages offered by the multilingual adapter, as (code, label) pairs kept
# in alphabetical order of their code. Each pair becomes one ``Lang`` entry.
_MTL_LANGS: list[Lang] = [
    Lang(code=code, label=label)
    for code, label in (
        ("ar", "Arabic"),
        ("da", "Danish"),
        ("de", "German"),
        ("el", "Greek"),
        ("en", "English"),
        ("es", "Spanish"),
        ("fi", "Finnish"),
        ("fr", "French"),
        ("he", "Hebrew"),
        ("hi", "Hindi"),
        ("it", "Italian"),
        ("ja", "Japanese"),
        ("ko", "Korean"),
        ("ms", "Malay"),
        ("nl", "Dutch"),
        ("no", "Norwegian"),
        ("pl", "Polish"),
        ("pt", "Portuguese"),
        ("ru", "Russian"),
        ("sv", "Swedish"),
        ("sw", "Swahili"),
        ("tr", "Turkish"),
        ("zh", "Chinese"),
    )
]


class Adapter:
    """Adapter around the Chatterbox multilingual TTS model.

    The class attributes describe the model (identifier, labels, supported
    languages, tunable parameters). An instance remembers the target device
    and holds the model object once ``load()`` has been called.
    """

    id: ClassVar[str] = "chatterbox-mtl"
    label: ClassVar[str] = "Chatterbox Multilingual"
    description: ClassVar[str] = (
        "23-language voice cloning. Pick a language at generate time."
    )
    languages: ClassVar[list[Lang]] = _MTL_LANGS
    paralinguistic_tags: ClassVar[list[str]] = []  # TBD on first manual run
    supports_voice_clone: ClassVar[bool] = True
    # Tunable generation parameters. The defaults declared here mirror the
    # fallbacks used in ``generate`` -- keep the two in sync.
    params: ClassVar[list[ParamSpec]] = [
        ParamSpec(
            name="exaggeration", label="Exaggeration", type="float",
            default=0.5, min=0.0, max=2.0, step=0.05,
            help="How emotive the speech is. Higher pushes prosody and emphasis; lower stays flat and neutral.",
            group="basic",
        ),
        ParamSpec(
            name="cfg_weight", label="CFG weight", type="float",
            default=0.5, min=0.0, max=1.0, step=0.05,
            help="Classifier-free guidance. Higher sticks closer to the reference voice; lower allows more variation but may drift in identity.",
            group="basic",
        ),
        ParamSpec(
            name="temperature", label="Temperature", type="float",
            default=0.8, min=0.1, max=1.5, step=0.05,
            help="Sampling randomness. Lower = deterministic and safer; higher = more creative but riskier and prone to artifacts.",
            group="basic",
        ),
        ParamSpec(
            name="repetition_penalty", label="Repetition penalty", type="float",
            default=2.0, min=1.0, max=3.0, step=0.05,
            help="Discourages repeating the same tokens. Higher than for English because non-Latin scripts loop more easily.",
            group="basic",
        ),
        ParamSpec(
            name="seed", label="Seed", type="int",
            default=-1, min=-1, step=1,
            help="Reproducibility. -1 draws a fresh random seed every run; any non-negative value pins the result so you can reproduce it.",
            group="advanced",
        ),
        ParamSpec(
            name="min_p", label="Min p", type="float",
            default=0.05, min=0.0, max=1.0, step=0.01,
            help="Cuts off tokens whose probability is below this fraction of the top token's. Higher trims more aggressively.",
            group="advanced",
        ),
        ParamSpec(
            name="top_p", label="Top p", type="float",
            default=1.0, min=0.0, max=1.0, step=0.01,
            help="Nucleus sampling. Keep tokens until cumulative probability reaches this. Lower = safer/conservative.",
            group="advanced",
        ),
    ]

    def __init__(self, device: str) -> None:
        """Remember the target *device*; the model itself is loaded lazily."""
        self.device = device
        self._model = None

    def load(self) -> None:
        """Instantiate the pretrained multilingual model on ``self.device``."""
        # Imported here so that importing this module does not pull in the
        # model package until a load is actually requested.
        from chatterbox.mtl_tts import ChatterboxMultilingualTTS

        self._model = ChatterboxMultilingualTTS.from_pretrained(device=self.device)

    def unload(self) -> None:
        """Drop the adapter's reference to the model."""
        self._model = None

    @staticmethod
    def _float_param(params: dict[str, Any], name: str, default: float) -> float:
        """Return ``params[name]`` as a float, or *default* if missing or None."""
        value = params.get(name)
        return default if value is None else float(value)

    def generate(
        self,
        text: str,
        reference_wav_path: str | None,
        language: str | None,
        params: dict[str, Any],
    ) -> tuple[bytes, int, int]:
        """Synthesize *text* and return it as an in-memory WAV file.

        Args:
            text: Text to speak.
            reference_wav_path: Forwarded to the model as
                ``audio_prompt_path``; may be ``None``.
            language: Language code forwarded to the model as
                ``language_id``. Required.
            params: Generation parameters keyed by the names declared in
                ``Adapter.params``. A missing key or an explicit ``None``
                value falls back to the built-in default.

        Returns:
            A ``(wav_bytes, sample_rate, seed_used)`` tuple: 16-bit PCM WAV
            bytes, the model's ``sr`` attribute (24000 if the model has
            none), and the value returned by ``apply_seed``.

        Raises:
            RuntimeError: If ``load()`` has not been called.
            ValueError: If *language* is empty or ``None``.
        """
        if self._model is None:
            raise RuntimeError("model not loaded")
        if not language:
            raise ValueError("language is required for chatterbox-mtl")

        # Seed before sampling so the generation below is affected by it.
        seed_used = apply_seed(params.get("seed"))
        wav = self._model.generate(
            text,
            language_id=language,
            audio_prompt_path=reference_wav_path,
            exaggeration=self._float_param(params, "exaggeration", 0.5),
            cfg_weight=self._float_param(params, "cfg_weight", 0.5),
            temperature=self._float_param(params, "temperature", 0.8),
            repetition_penalty=self._float_param(params, "repetition_penalty", 2.0),
            min_p=self._float_param(params, "min_p", 0.05),
            top_p=self._float_param(params, "top_p", 1.0),
        )
        import numpy as np

        if hasattr(wav, "detach"):
            # Tensor-like output: detach from autograd and copy to host memory.
            wav = wav.detach().cpu().numpy()
        # squeeze() drops singleton axes; atleast_1d keeps a one-sample clip
        # from collapsing to a 0-d array, which cannot be written as audio.
        arr = np.atleast_1d(np.asarray(wav).squeeze())
        sr = getattr(self._model, "sr", 24000)
        buf = io.BytesIO()
        sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
        return buf.getvalue(), sr, seed_used