File size: 4,388 Bytes
829be0a
 
 
 
 
 
 
 
 
b066638
829be0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200e3fe
cc6b3e5
829be0a
 
 
 
200e3fe
cc6b3e5
829be0a
 
 
 
200e3fe
cc6b3e5
 
 
 
 
200e3fe
cc6b3e5
 
 
 
 
200e3fe
cc6b3e5
 
 
 
 
200e3fe
cc6b3e5
 
 
 
 
200e3fe
cc6b3e5
829be0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b066638
829be0a
 
b066638
829be0a
 
 
 
 
 
cc6b3e5
 
 
829be0a
 
 
 
 
 
 
 
 
 
 
 
b066638
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Chatterbox English adapter (ResembleAI/chatterbox)."""
from __future__ import annotations

import io
from typing import Any, ClassVar

import soundfile as sf

from server.schemas import Lang, ParamSpec
from server.seed import apply_seed


class Adapter:
    """Server adapter wrapping ResembleAI's English ``ChatterboxTTS`` model.

    Class attributes describe the model to the server: its id, label,
    supported languages, voice-clone support, and the tunable ``params``
    exposed to clients. An instance holds the loaded model for one device.
    """

    id: ClassVar[str] = "chatterbox-en"
    label: ClassVar[str] = "Chatterbox (English)"
    description: ClassVar[str] = (
        "Original Chatterbox English voice cloning with CFG and exaggeration controls."
    )
    languages: ClassVar[list[Lang]] = [Lang(code="en", label="English")]
    paralinguistic_tags: ClassVar[list[str]] = []
    supports_voice_clone: ClassVar[bool] = True
    params: ClassVar[list[ParamSpec]] = [
        ParamSpec(
            name="exaggeration", label="Exaggeration", type="float",
            default=0.5, min=0.0, max=2.0, step=0.05,
            help="How emotive the speech is. Higher pushes prosody, pacing, and emphasis; lower stays flat and neutral.",
            group="basic",
        ),
        ParamSpec(
            name="cfg_weight", label="CFG weight", type="float",
            default=0.5, min=0.0, max=1.0, step=0.05,
            help="Classifier-free guidance. Higher sticks closer to the reference voice; lower allows more variation but may drift in identity.",
            group="basic",
        ),
        ParamSpec(
            name="temperature", label="Temperature", type="float",
            default=0.8, min=0.1, max=1.5, step=0.05,
            help="Sampling randomness. Lower = deterministic and safer; higher = more creative but riskier and prone to artifacts.",
            group="basic",
        ),
        ParamSpec(
            name="seed", label="Seed", type="int",
            default=-1, min=-1, step=1,
            help="Reproducibility. -1 draws a fresh random seed every run; any non-negative value pins the result so you can reproduce it.",
            group="advanced",
        ),
        ParamSpec(
            name="repetition_penalty", label="Repetition penalty", type="float",
            default=1.2, min=1.0, max=3.0, step=0.05,
            help="Discourages repeating the same tokens. >1 reduces stuttering and loops; too high hurts natural fluency.",
            group="advanced",
        ),
        ParamSpec(
            name="min_p", label="Min p", type="float",
            default=0.05, min=0.0, max=1.0, step=0.01,
            help="Cuts off tokens whose probability is below this fraction of the top token's. Higher trims more aggressively.",
            group="advanced",
        ),
        ParamSpec(
            name="top_p", label="Top p", type="float",
            default=1.0, min=0.0, max=1.0, step=0.01,
            help="Nucleus sampling. Keep tokens until cumulative probability reaches this. Lower = safer/conservative.",
            group="advanced",
        ),
    ]

    # Fallback values used by ``generate``, derived from ``params`` so the
    # defaults advertised to clients and the defaults actually applied
    # cannot drift apart.
    _defaults: ClassVar[dict[str, Any]] = {spec.name: spec.default for spec in params}

    def __init__(self, device: str) -> None:
        """Remember the target device; the model is not loaded until ``load()``."""
        self.device = device
        self._model = None  # set by load(), cleared by unload()

    def load(self) -> None:
        """Load pretrained Chatterbox weights onto ``self.device``."""
        # Imported here rather than at module top, presumably so that merely
        # importing this module does not pull in the model package.
        from chatterbox.tts import ChatterboxTTS

        self._model = ChatterboxTTS.from_pretrained(device=self.device)

    def unload(self) -> None:
        """Drop the reference to the loaded model so it can be garbage-collected."""
        # NOTE(review): this does not explicitly release accelerator memory
        # caches; confirm whether the server handles that elsewhere.
        self._model = None

    def _float_param(self, params: dict[str, Any], name: str) -> float:
        """Return ``params[name]`` as a float, or the declared default.

        A missing key and an explicit ``None`` both resolve to the default
        declared for ``name`` in ``params`` (the class-level ParamSpec list).
        """
        value = params.get(name)
        if value is None:
            value = self._defaults[name]
        return float(value)

    def generate(
        self,
        text: str,
        reference_wav_path: str | None,
        language: str | None,
        params: dict[str, Any],
    ) -> tuple[bytes, int, int]:
        """Synthesize ``text`` and return it as an in-memory 16-bit PCM WAV.

        Args:
            text: Text to speak.
            reference_wav_path: Optional path to a reference clip for voice
                cloning; passed to the model as ``audio_prompt_path``.
            language: Unused; this adapter only declares English.
            params: Client-supplied generation settings keyed by the names
                in ``params``. Missing or ``None`` entries use the declared
                defaults.

        Returns:
            ``(wav_bytes, sample_rate, seed_used)``, where ``seed_used`` is
            whatever ``apply_seed`` returned for the requested seed.

        Raises:
            RuntimeError: If ``load()`` has not been called.
        """
        if self._model is None:
            raise RuntimeError("model not loaded")
        seed_used = apply_seed(params.get("seed"))
        wav = self._model.generate(
            text,
            audio_prompt_path=reference_wav_path,
            exaggeration=self._float_param(params, "exaggeration"),
            cfg_weight=self._float_param(params, "cfg_weight"),
            temperature=self._float_param(params, "temperature"),
            repetition_penalty=self._float_param(params, "repetition_penalty"),
            min_p=self._float_param(params, "min_p"),
            top_p=self._float_param(params, "top_p"),
        )
        import numpy as np

        # Tensor-like outputs are detached and moved to host memory; anything
        # else is handed straight to numpy.
        if hasattr(wav, "detach"):
            wav = wav.detach().cpu().numpy()
        # squeeze() drops singleton axes such as a leading batch/channel dim.
        arr = np.asarray(wav).squeeze()
        # libsndfile does not clip float->int conversion by default, so samples
        # outside [-1, 1] would wrap around when written as PCM_16.
        arr = np.clip(arr, -1.0, 1.0)
        sr = getattr(self._model, "sr", 24000)
        buf = io.BytesIO()
        sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
        return buf.getvalue(), sr, seed_used