techfreakworm committed on
Commit
85b2e31
·
unverified ·
1 Parent(s): 829be0a

feat(models): chatterbox-turbo and chatterbox-mtl adapters

Browse files
server/models/chatterbox_mtl.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chatterbox Multilingual adapter (23 languages)."""
2
+ from __future__ import annotations
3
+
4
+ import io
5
+ from typing import Any, ClassVar
6
+
7
+ import soundfile as sf
8
+
9
+ from server.schemas import Lang, ParamSpec
10
+
11
+
12
# Languages offered by the Chatterbox Multilingual adapter.
# Written as (code, label) pairs, sorted by language code, and expanded
# into Lang records below.
_MTL_LANGS: list[Lang] = [
    Lang(code=lang_code, label=lang_label)
    for lang_code, lang_label in (
        ("ar", "Arabic"),
        ("da", "Danish"),
        ("de", "German"),
        ("el", "Greek"),
        ("en", "English"),
        ("es", "Spanish"),
        ("fi", "Finnish"),
        ("fr", "French"),
        ("he", "Hebrew"),
        ("hi", "Hindi"),
        ("it", "Italian"),
        ("ja", "Japanese"),
        ("ko", "Korean"),
        ("ms", "Malay"),
        ("nl", "Dutch"),
        ("no", "Norwegian"),
        ("pl", "Polish"),
        ("pt", "Portuguese"),
        ("ru", "Russian"),
        ("sv", "Swedish"),
        ("sw", "Swahili"),
        ("tr", "Turkish"),
        ("zh", "Chinese"),
    )
]
37
+
38
+
39
class Adapter:
    """Adapter wrapping the Chatterbox Multilingual TTS model.

    The class attributes are static metadata describing the model (its
    identifier, display strings, supported languages and tunable
    parameters). An instance stores the target device and the model object,
    which stays ``None`` until :meth:`load` is called.
    """

    id: ClassVar[str] = "chatterbox-mtl"
    label: ClassVar[str] = "Chatterbox Multilingual"
    description: ClassVar[str] = (
        "23-language voice cloning. Pick a language at generate time."
    )
    languages: ClassVar[list[Lang]] = _MTL_LANGS
    paralinguistic_tags: ClassVar[list[str]] = []  # TBD on first manual run
    supports_voice_clone: ClassVar[bool] = True
    params: ClassVar[list[ParamSpec]] = [
        ParamSpec(name="exaggeration", label="Exaggeration", type="float",
                  default=0.5, min=0.0, max=2.0, step=0.05),
        ParamSpec(name="cfg_weight", label="CFG weight", type="float",
                  default=0.5, min=0.0, max=1.0, step=0.05),
    ]

    def __init__(self, device: str) -> None:
        """Remember the target device; the model itself is not loaded yet."""
        self.device = device
        self._model = None

    def load(self) -> None:
        """Instantiate the pretrained multilingual model on ``self.device``.

        ``chatterbox`` is imported here rather than at module level, so
        importing this module does not itself import the model library.
        """
        from chatterbox.mtl_tts import ChatterboxMultilingualTTS

        self._model = ChatterboxMultilingualTTS.from_pretrained(device=self.device)

    def unload(self) -> None:
        """Drop this adapter's reference to the loaded model."""
        self._model = None

    def generate(
        self,
        text: str,
        reference_wav_path: str | None,
        language: str | None,
        params: dict[str, Any],
    ) -> tuple[bytes, int]:
        """Synthesize ``text`` and return it as an in-memory WAV file.

        Args:
            text: The text to speak.
            reference_wav_path: Passed to the model as ``audio_prompt_path``;
                may be ``None``.
            language: Passed to the model as ``language_id``. Must be
                non-empty.
            params: Optional overrides for ``exaggeration`` and
                ``cfg_weight``; each falls back to 0.5 when absent.

        Returns:
            A ``(wav_bytes, sample_rate)`` tuple, where ``wav_bytes`` is a
            16-bit PCM WAV file and ``sample_rate`` is the model's ``sr``
            attribute (24000 if the model does not expose one).

        Raises:
            RuntimeError: If :meth:`load` has not been called.
            ValueError: If ``language`` is ``None`` or empty.
        """
        if self._model is None:
            raise RuntimeError("model not loaded")
        if not language:
            raise ValueError("language is required for chatterbox-mtl")
        wav = self._model.generate(
            text,
            language_id=language,
            audio_prompt_path=reference_wav_path,
            exaggeration=float(params.get("exaggeration", 0.5)),
            cfg_weight=float(params.get("cfg_weight", 0.5)),
        )
        import numpy as np

        # Tensor-like outputs expose ``detach``; move them to host memory and
        # convert to NumPy. Anything else is handed to ``np.asarray`` as-is.
        # (A separate ``isinstance(wav, torch.Tensor)`` check is unnecessary:
        # every tensor is already handled by this branch.)
        if hasattr(wav, "detach"):
            wav = wav.detach().cpu().numpy()
        # Drop singleton dimensions before writing the samples.
        arr = np.asarray(wav).squeeze()
        sr = getattr(self._model, "sr", 24000)
        buf = io.BytesIO()
        sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
        return buf.getvalue(), sr
server/models/chatterbox_turbo.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Chatterbox Turbo adapter — fast English with paralinguistic tags."""
2
+ from __future__ import annotations
3
+
4
+ import io
5
+ from typing import Any, ClassVar
6
+
7
+ import soundfile as sf
8
+
9
+ from server.schemas import Lang, ParamSpec
10
+
11
+
12
class Adapter:
    """Adapter wrapping the Chatterbox Turbo TTS model.

    The class attributes are static metadata describing the model (its
    identifier, display strings, supported languages, paralinguistic tags
    and tunable parameters). An instance stores the target device and the
    model object, which stays ``None`` until :meth:`load` is called.
    """

    id: ClassVar[str] = "chatterbox-turbo"
    label: ClassVar[str] = "Chatterbox Turbo"
    description: ClassVar[str] = (
        "Faster, lower-VRAM English variant. Supports [laugh], [cough], [chuckle] tags."
    )
    languages: ClassVar[list[Lang]] = [Lang(code="en", label="English")]
    paralinguistic_tags: ClassVar[list[str]] = ["[laugh]", "[cough]", "[chuckle]"]
    supports_voice_clone: ClassVar[bool] = True
    params: ClassVar[list[ParamSpec]] = [
        ParamSpec(name="cfg_weight", label="CFG weight", type="float",
                  default=0.5, min=0.0, max=1.0, step=0.05),
        ParamSpec(name="temperature", label="Temperature", type="float",
                  default=0.8, min=0.1, max=1.5, step=0.05),
    ]

    def __init__(self, device: str) -> None:
        """Remember the target device; the model itself is not loaded yet."""
        self.device = device
        self._model = None

    def load(self) -> None:
        """Instantiate the pretrained Turbo model on ``self.device``.

        ``chatterbox`` is imported here rather than at module level, so
        importing this module does not itself import the model library.
        """
        from chatterbox.tts_turbo import ChatterboxTurboTTS

        self._model = ChatterboxTurboTTS.from_pretrained(device=self.device)

    def unload(self) -> None:
        """Drop this adapter's reference to the loaded model."""
        self._model = None

    def generate(
        self,
        text: str,
        reference_wav_path: str | None,
        language: str | None,
        params: dict[str, Any],
    ) -> tuple[bytes, int]:
        """Synthesize ``text`` and return it as an in-memory WAV file.

        Args:
            text: The text to speak.
            reference_wav_path: Passed to the model as ``audio_prompt_path``;
                may be ``None``.
            language: Accepted for signature parity with the other adapters;
                this method does not use it.
            params: Optional overrides for ``cfg_weight`` (default 0.5) and
                ``temperature`` (default 0.8).

        Returns:
            A ``(wav_bytes, sample_rate)`` tuple, where ``wav_bytes`` is a
            16-bit PCM WAV file and ``sample_rate`` is the model's ``sr``
            attribute (24000 if the model does not expose one).

        Raises:
            RuntimeError: If :meth:`load` has not been called.
        """
        if self._model is None:
            raise RuntimeError("model not loaded")
        wav = self._model.generate(
            text,
            audio_prompt_path=reference_wav_path,
            cfg_weight=float(params.get("cfg_weight", 0.5)),
            temperature=float(params.get("temperature", 0.8)),
        )
        import numpy as np

        # Tensor-like outputs expose ``detach``; move them to host memory and
        # convert to NumPy. Anything else is handed to ``np.asarray`` as-is.
        # (A separate ``isinstance(wav, torch.Tensor)`` check is unnecessary:
        # every tensor is already handled by this branch.)
        if hasattr(wav, "detach"):
            wav = wav.detach().cpu().numpy()
        # Drop singleton dimensions before writing the samples.
        arr = np.asarray(wav).squeeze()
        sr = getattr(self._model, "sr", 24000)
        buf = io.BytesIO()
        sf.write(buf, arr, sr, format="WAV", subtype="PCM_16")
        return buf.getvalue(), sr
tests/test_adapter_contract.py CHANGED
@@ -8,6 +8,8 @@ from server.schemas import ParamSpec
8
 
9
  ADAPTER_MODULES = [
10
  "server.models.chatterbox_en",
 
 
11
  ]
12
 
13
 
 
8
 
9
  ADAPTER_MODULES = [
10
  "server.models.chatterbox_en",
11
+ "server.models.chatterbox_turbo",
12
+ "server.models.chatterbox_mtl",
13
  ]
14
 
15