Upload 2 files

- config.yaml  +105  -0
- main.py      +195  -0

config.yaml  ADDED
@@ -0,0 +1,105 @@
log_dir: "Models/Output"
save_freq: 5
log_interval: 10
device: "cuda"
epochs: 50
batch_size: 8
max_len: 400
pretrained_model: ""
second_stage_load_pretrained: true
load_only_params: true

external_models:
  asr:
    input_dim: 80
    hidden_dim: 256
    n_token: 178
  plbert:
    vocab_size: 178
    hidden_size: 768
    num_attention_heads: 12
    intermediate_size: 2048
    dropout: 0.1

data_params:
  train_data: "shethjenil/audiodata"
  root_path: ""
  min_length: 50

preprocess_params:
  sr: 24000
  n_fft: 2048
  win_length: 1200
  hop_length: 300

model_params:
  multispeaker: true
  dim_in: 64
  hidden_dim: 128
  max_conv_dim: 512
  n_layer: 2
  n_mels: 80
  n_token: 178
  max_dur: 50
  style_dim: 128
  dropout: 0.2
  decoder:
    type: "istftnet"
    hidden_dim: 256
    decoder_out_dim: 256
    asr_res_in: 128
    resblock_kernel_sizes: [3, 3]
    upsample_rates: [10, 6]
    upsample_initial_channel: 256
    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]]
    upsample_kernel_sizes: [20, 12]
    gen_istft_n_fft: 20
    gen_istft_hop_size: 5
    disable_complex: true
  slm:
    model: "microsoft/wavlm-base-plus"
    sr: 16000
    hidden: 768
    nlayers: 13
    initial_channel: 64
  diffusion:
    embedding_mask_proba: 0.1
    transformer:
      num_layers: 3
      num_heads: 8
      head_features: 64
      multiplier: 2
    dist:
      sigma_data: 0.2
      estimate_sigma_data: true
      mean: -3.0
      std: 1.0

loss_params:
  lambda_mel: 5.0
  lambda_gen: 1.0
  lambda_slm: 1.0
  lambda_mono: 1.0
  lambda_s2s: 1.0
  lambda_F0: 1.0
  lambda_norm: 1.0
  lambda_dur: 1.0
  lambda_ce: 20.0
  lambda_sty: 1.0
  lambda_diff: 1.0
  diff_epoch: 10
  joint_epoch: 30

optimizer_params:
  lr: 0.0001
  bert_lr: 0.00001
  ft_lr: 0.0001

slmadv_params:
  min_len: 400
  max_len: 500
  batch_percentage: 0.5
  iter: 10
  thresh: 5.0
  scale: 0.01
  sig: 1.5
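A quick way to sanity-check this file before conversion is to load it as a plain dict; a minimal sketch, assuming PyYAML is available (main.py below hints at StyleTTS2Config.from_yaml("config.yaml") as the intended entry point, so this is only for inspection):

    import yaml

    # Plain-dict load for inspection only; the conversion script builds
    # StyleTTS2Config from a dict or via from_yaml.
    with open("config.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["model_params"]["decoder"]["type"])   # "istftnet"
    print(cfg["preprocess_params"]["sr"])           # 24000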
main.py  ADDED
@@ -0,0 +1,195 @@
# %%bash
# pip uninstall -q styletts2 -y
# pip install git+https://github.com/dummyjenil/styletts2 -q


import torch
from torch import nn
import onnx
import torch.nn.utils.parametrize as parametrize
from styletts2.models.styletts2 import StyleTTS2Model, StyleTTS2Config
from onnx_toolkit import ONNXParser
from safetensors.torch import save_file
from huggingface_hub import hf_hub_download


# -----------------------------
# Utils
# -----------------------------
def get_layer_from_key(model: nn.Module, key: str):
    # Walk the module tree following a state_dict key, stopping before the
    # final parameter name (e.g. "encoder.lstm.weight_ih_l0" -> encoder.lstm).
    module = model
    for part in key.split(".")[:-1]:
        module = module[int(part)] if part.isdigit() else getattr(module, part)
    return module


def is_bidirectional(node):
    for attr in node.attribute:
        if attr.name == "direction":
            value = onnx.helper.get_attribute_value(attr)
            if isinstance(value, bytes):
                value = value.decode("utf-8")
            return value == "bidirectional"
    return False


# -----------------------------
# LSTM Conversion
# -----------------------------
def convert_onnx_lstm(W, R, B, layer_name="lstm", bidirectional=False):
    # The bidirectional flag is informational only; the number of directions
    # is taken from W.shape[0] below.
    def reorder_gates(w):
        # ONNX packs LSTM gates as [i, o, f, c]; PyTorch expects [i, f, c, o].
        i, o, f, c = torch.chunk(w, 4, dim=0)
        return torch.cat([i, f, c, o], dim=0)

    state_dict = {}

    for d in range(W.shape[0]):
        suffix = "" if d == 0 else "_reverse"

        w_ih = reorder_gates(torch.tensor(W[d]))
        w_hh = reorder_gates(torch.tensor(R[d]))

        b_ih, b_hh = torch.chunk(torch.tensor(B[d]), 2, dim=0)
        b_ih = reorder_gates(b_ih)
        b_hh = reorder_gates(b_hh)

        state_dict[f"{layer_name}.weight_ih_l0{suffix}"] = w_ih
        state_dict[f"{layer_name}.weight_hh_l0{suffix}"] = w_hh
        state_dict[f"{layer_name}.bias_ih_l0{suffix}"] = b_ih
        state_dict[f"{layer_name}.bias_hh_l0{suffix}"] = b_hh

    return state_dict

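# Illustrative sketch (not part of the conversion flow): shapes follow the
# ONNX LSTM spec for one unidirectional layer with hidden_size=4 and
# input_size=3; "demo_lstm" is a made-up layer name.
# import numpy as np
# W = np.zeros((1, 4 * 4, 3), dtype=np.float32)   # [num_dirs, 4*hidden, input]
# R = np.zeros((1, 4 * 4, 4), dtype=np.float32)   # [num_dirs, 4*hidden, hidden]
# B = np.zeros((1, 8 * 4), dtype=np.float32)      # [num_dirs, 8*hidden]
# sd = convert_onnx_lstm(W, R, B, layer_name="demo_lstm")
# # -> keys: demo_lstm.weight_ih_l0 / weight_hh_l0 / bias_ih_l0 / bias_hh_l0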

# -----------------------------
# MAIN FUNCTION (UPDATED)
# -----------------------------
def convert_model_from_local(
    onnx_path: str,
    config,
    save_path: str
):
    """
    onnx_path: local path (a path returned by hf_hub_download also works)
    config: StyleTTS2Config OR dict
    save_path: output safetensors path
    """

    # ---- Config handling ----
    if isinstance(config, dict):
        config = StyleTTS2Config(**config)

    model = StyleTTS2Model(config)

    # Remove unused modules
    for attr in [
        "wd", "msd", "mpd", "pitch_extractor",
        "text_aligner", "diffusion",
        "predictor_encoder", "style_encoder"
    ]:
        if hasattr(model, attr):
            delattr(model, attr)

    # Remove parametrizations
    for module in model.modules():
        if hasattr(module, "parametrizations") and "weight" in module.parametrizations:
            parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

    state_dict = model.state_dict()

    # ---- Load ONNX (LOCAL PATH) ----
    m = ONNXParser(onnx_path)

    # -------- LSTM handling --------
    pytorch_lstm = {
        name: module for name, module in model.named_modules()
        if isinstance(module, nn.LSTM)
    }

    # Map ONNX node names like "/encoder/lstm/LSTM" to PyTorch module paths
    # ("encoder.lstm"); removesuffix/removeprefix strip exact strings (Python 3.9+).
    onnx_lstm_layers = {
        i.name.removesuffix("/LSTM").removeprefix("/").replace("/", "."): i
        for i in m.find().find_by_op_type("LSTM")
    }

    predefined_dict = {}

    for pt_name in pytorch_lstm:
        if pt_name not in onnx_lstm_layers:
            continue

        node = onnx_lstm_layers[pt_name]
        block = m.find().find_by_name(node.name, exact=True)

        tensors = list(block.tensor().values())
        if len(tensors) != 3:
            continue

        w, r, b = tensors

        converted = convert_onnx_lstm(
            w, r, b,
            pt_name,
            is_bidirectional(block.single_node)
        )

        predefined_dict.update(converted)

    # -------- Build state_dict --------
    finder = m.find()
    new_state_dict = {}

    for name, tensor in state_dict.items():
        full_key = "kmodel." + name
        results = finder.find_by_tensor(full_key)

        if results:
            new_state_dict[name] = torch.tensor(results[0].tensor()[full_key])
            continue

        module = get_layer_from_key(model, name)

        if isinstance(module, nn.LSTM) and name in predefined_dict:
            new_state_dict[name] = predefined_dict[name]
            continue

        new_state_dict[name] = tensor  # fallback

    # Load weights
    new_state_dict['decoder.generator.stft.window'] = model.decoder.generator.stft.window.clone()
    model.load_state_dict(new_state_dict)

    # Make contiguous
    final_sd = {
        k: v.contiguous() if not v.is_contiguous() else v
        for k, v in model.state_dict().items()
    }

    save_file(final_sd, save_path)



# config = StyleTTS2Config.from_yaml("config.yaml")
# convert_model_from_local(
#     onnx_path=hf_hub_download("KittenML/kitten-tts-nano-0.8-fp32", "kitten_tts_nano_v0_8.onnx"),
#     config=config,
#     save_path="mini_model.safetensors"
# )



# ----
# quantization is still pending


# config.model_params.style_dim = 512
# config.model_params.hidden_dim = 512
# config.model_params.decoder.hidden_dim = 1024
# config.model_params.decoder.decoder_out_dim = 512
# config.model_params.decoder.asr_res_in = 256
# config.model_params.decoder.upsample_initial_channel = 512
# convert_model_from_local(
#     onnx_path=hf_hub_download("KittenML/kitten-tts-mini-0.8", "kitten_tts_mini_v0_8.onnx"),
#     config=config,
#     save_path="mini_model.safetensors"
# )
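One way to verify the exported weights is to load the .safetensors file back into a freshly built model; a minimal sketch under the same assumptions as the commented usage above (the styletts2 fork and the config.yaml from this commit):

    from safetensors.torch import load_file
    from styletts2.models.styletts2 import StyleTTS2Model, StyleTTS2Config

    config = StyleTTS2Config.from_yaml("config.yaml")
    model = StyleTTS2Model(config)
    sd = load_file("mini_model.safetensors")
    # strict=False because convert_model_from_local deletes training-only
    # submodules (discriminators, aligner, diffusion) before saving.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print(len(missing), len(unexpected))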