techfreakworm commited on
Commit
b473465
·
unverified ·
1 Parent(s): f111e30

feat(models): expand chatterbox-mtl params (seed, repetition_penalty, min_p, top_p)

Browse files
Files changed (1) hide show
  1. server/models/chatterbox_mtl.py +40 -4
server/models/chatterbox_mtl.py CHANGED
@@ -47,10 +47,42 @@ class Adapter:
47
  paralinguistic_tags: ClassVar[list[str]] = [] # TBD on first manual run
48
  supports_voice_clone: ClassVar[bool] = True
49
  params: ClassVar[list[ParamSpec]] = [
50
- ParamSpec(name="exaggeration", label="Exaggeration", type="float",
51
- default=0.5, min=0.0, max=2.0, step=0.05),
52
- ParamSpec(name="cfg_weight", label="CFG weight", type="float",
53
- default=0.5, min=0.0, max=1.0, step=0.05),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  ]
55
 
56
  def __init__(self, device: str) -> None:
@@ -83,6 +115,10 @@ class Adapter:
83
  audio_prompt_path=reference_wav_path,
84
  exaggeration=float(params.get("exaggeration", 0.5)),
85
  cfg_weight=float(params.get("cfg_weight", 0.5)),
 
 
 
 
86
  )
87
  import numpy as np
88
  import torch
 
47
  paralinguistic_tags: ClassVar[list[str]] = [] # TBD on first manual run
48
  supports_voice_clone: ClassVar[bool] = True
49
  params: ClassVar[list[ParamSpec]] = [
50
+ ParamSpec(
51
+ name="exaggeration", label="Exaggeration", type="float",
52
+ default=0.5, min=0.0, max=2.0, step=0.05,
53
+ group="basic",
54
+ ),
55
+ ParamSpec(
56
+ name="cfg_weight", label="CFG weight", type="float",
57
+ default=0.5, min=0.0, max=1.0, step=0.05,
58
+ group="basic",
59
+ ),
60
+ ParamSpec(
61
+ name="temperature", label="Temperature", type="float",
62
+ default=0.8, min=0.1, max=1.5, step=0.05,
63
+ group="basic",
64
+ ),
65
+ ParamSpec(
66
+ name="repetition_penalty", label="Repetition penalty", type="float",
67
+ default=2.0, min=1.0, max=3.0, step=0.05,
68
+ group="basic",
69
+ ),
70
+ ParamSpec(
71
+ name="seed", label="Seed", type="int",
72
+ default=-1, min=-1, step=1,
73
+ help="-1 draws a random seed each time.",
74
+ group="advanced",
75
+ ),
76
+ ParamSpec(
77
+ name="min_p", label="Min p", type="float",
78
+ default=0.05, min=0.0, max=1.0, step=0.01,
79
+ group="advanced",
80
+ ),
81
+ ParamSpec(
82
+ name="top_p", label="Top p", type="float",
83
+ default=1.0, min=0.0, max=1.0, step=0.01,
84
+ group="advanced",
85
+ ),
86
  ]
87
 
88
  def __init__(self, device: str) -> None:
 
115
  audio_prompt_path=reference_wav_path,
116
  exaggeration=float(params.get("exaggeration", 0.5)),
117
  cfg_weight=float(params.get("cfg_weight", 0.5)),
118
+ temperature=float(params.get("temperature", 0.8)),
119
+ repetition_penalty=float(params.get("repetition_penalty", 2.0)),
120
+ min_p=float(params.get("min_p", 0.05)),
121
+ top_p=float(params.get("top_p", 1.0)),
122
  )
123
  import numpy as np
124
  import torch