techfreakworm commited on
Commit
f111e30
·
unverified ·
1 Parent(s): cc6b3e5

feat(models): expand chatterbox-turbo params (seed, top_k, exaggeration, cfg_weight, etc.)

Browse files
Files changed (1) hide show
  1. server/models/chatterbox_turbo.py +41 -5
server/models/chatterbox_turbo.py CHANGED
@@ -31,10 +31,42 @@ class Adapter:
31
  ]
32
  supports_voice_clone: ClassVar[bool] = True
33
  params: ClassVar[list[ParamSpec]] = [
34
- ParamSpec(name="cfg_weight", label="CFG weight", type="float",
35
- default=0.5, min=0.0, max=1.0, step=0.05),
36
- ParamSpec(name="temperature", label="Temperature", type="float",
37
- default=0.8, min=0.1, max=1.5, step=0.05),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  ]
39
 
40
  def __init__(self, device: str) -> None:
@@ -62,8 +94,12 @@ class Adapter:
62
  wav = self._model.generate(
63
  text,
64
  audio_prompt_path=reference_wav_path,
65
- cfg_weight=float(params.get("cfg_weight", 0.5)),
 
66
  temperature=float(params.get("temperature", 0.8)),
 
 
 
67
  )
68
  import numpy as np
69
  import torch
 
31
  ]
32
  supports_voice_clone: ClassVar[bool] = True
33
  params: ClassVar[list[ParamSpec]] = [
34
+ ParamSpec(
35
+ name="temperature", label="Temperature", type="float",
36
+ default=0.8, min=0.1, max=1.5, step=0.05,
37
+ group="basic",
38
+ ),
39
+ ParamSpec(
40
+ name="top_p", label="Top p", type="float",
41
+ default=0.95, min=0.0, max=1.0, step=0.01,
42
+ group="basic",
43
+ ),
44
+ ParamSpec(
45
+ name="repetition_penalty", label="Repetition penalty", type="float",
46
+ default=1.2, min=1.0, max=3.0, step=0.05,
47
+ group="basic",
48
+ ),
49
+ ParamSpec(
50
+ name="seed", label="Seed", type="int",
51
+ default=-1, min=-1, step=1,
52
+ help="-1 draws a random seed each time.",
53
+ group="advanced",
54
+ ),
55
+ ParamSpec(
56
+ name="top_k", label="Top k", type="int",
57
+ default=1000, min=1, max=4000, step=1,
58
+ group="advanced",
59
+ ),
60
+ ParamSpec(
61
+ name="exaggeration", label="Exaggeration", type="float",
62
+ default=0.0, min=0.0, max=2.0, step=0.05,
63
+ group="advanced",
64
+ ),
65
+ ParamSpec(
66
+ name="cfg_weight", label="CFG weight", type="float",
67
+ default=0.0, min=0.0, max=1.0, step=0.05,
68
+ group="advanced",
69
+ ),
70
  ]
71
 
72
  def __init__(self, device: str) -> None:
 
94
  wav = self._model.generate(
95
  text,
96
  audio_prompt_path=reference_wav_path,
97
+ exaggeration=float(params.get("exaggeration", 0.0)),
98
+ cfg_weight=float(params.get("cfg_weight", 0.0)),
99
  temperature=float(params.get("temperature", 0.8)),
100
+ top_p=float(params.get("top_p", 0.95)),
101
+ top_k=int(params.get("top_k", 1000)),
102
+ repetition_penalty=float(params.get("repetition_penalty", 1.2)),
103
  )
104
  import numpy as np
105
  import torch