# StyleTTS / main.py
# shethjenil's picture
# Update main.py
# 71d986e verified
# %%bash
# pip uninstall -q styletts2 -y
# pip install git+https://github.com/dummyjenil/styletts2 -q
import torch
from torch import nn
import onnx
import torch.nn.utils.parametrize as parametrize
from styletts2.models.styletts2 import StyleTTS2Model, StyleTTS2Config
from onnx_toolkit import ONNXParser
from safetensors.torch import save_file
from huggingface_hub import hf_hub_download
# -----------------------------
# Utils
# -----------------------------
def get_layer_from_key(model: nn.Module, key: str):
    """Return the module that owns the entry named by a dotted state-dict key.

    The last component of ``key`` (e.g. ``weight``) is dropped; every
    remaining component is followed from ``model`` downwards. Purely numeric
    components are used as integer indices, all others as attribute names.
    A key without any dot therefore resolves to ``model`` itself.
    """
    *owner_path, _leaf = key.split(".")
    current = model
    for segment in owner_path:
        if segment.isdigit():
            # Numeric path components index into container modules.
            current = current[int(segment)]
        else:
            current = getattr(current, segment)
    return current
def is_bidirectional(node):
    """Tell whether an ONNX node carries ``direction == "bidirectional"``.

    Only the first attribute named ``direction`` is inspected. Its value is
    read with ``onnx.helper.get_attribute_value`` and decoded as UTF-8 when
    it comes back as bytes. A node without a ``direction`` attribute yields
    ``False``.
    """
    direction_attr = next(
        (candidate for candidate in node.attribute if candidate.name == "direction"),
        None,
    )
    if direction_attr is None:
        return False

    direction = onnx.helper.get_attribute_value(direction_attr)
    if isinstance(direction, bytes):
        direction = direction.decode("utf-8")
    return direction == "bidirectional"
# -----------------------------
# LSTM Conversion
# -----------------------------
def convert_onnx_lstm(W, R, B, layer_name="lstm", bidirectional=False):
    """Build single-layer ``nn.LSTM``-style state-dict entries from W, R and B.

    For every index along the first axis of ``W`` the matching slices of
    ``W``, ``R`` and ``B`` are converted to tensors. Each is treated as four
    equal gate blocks stacked on dim 0 in (i, o, f, c) order and re-stacked
    as (i, f, c, o). ``B`` is first split in half into the input-hidden and
    hidden-hidden biases.

    Keys are ``{layer_name}.weight_ih_l0``, ``.weight_hh_l0``, ``.bias_ih_l0``
    and ``.bias_hh_l0``; every index after the first gets a ``_reverse``
    suffix.

    ``bidirectional`` is accepted for interface compatibility but is not
    consulted: the number of directions comes from ``W.shape[0]``.
    """

    def to_torch_gate_order(stacked):
        gate_i, gate_o, gate_f, gate_c = torch.chunk(stacked, 4, dim=0)
        return torch.cat([gate_i, gate_f, gate_c, gate_o], dim=0)

    converted = {}
    num_directions = W.shape[0]
    for direction in range(num_directions):
        tag = "_reverse" if direction else ""
        prefix = f"{layer_name}."

        input_weights = to_torch_gate_order(torch.tensor(W[direction]))
        recurrent_weights = to_torch_gate_order(torch.tensor(R[direction]))
        # The bias slice holds the input bias followed by the recurrent bias.
        input_bias, recurrent_bias = torch.chunk(
            torch.tensor(B[direction]), 2, dim=0
        )

        converted[prefix + f"weight_ih_l0{tag}"] = input_weights
        converted[prefix + f"weight_hh_l0{tag}"] = recurrent_weights
        converted[prefix + f"bias_ih_l0{tag}"] = to_torch_gate_order(input_bias)
        converted[prefix + f"bias_hh_l0{tag}"] = to_torch_gate_order(recurrent_bias)
    return converted
# -----------------------------
# MAIN FUNCTION (UPDATED)
# -----------------------------
def _lstm_module_name(node_name: str) -> str:
    """Turn an ONNX LSTM node name into a dotted module path.

    Removes one exact trailing ``/LSTM`` and the leading slashes, then
    replaces ``/`` with ``.``. ``str.rstrip("/LSTM")`` must not be used here:
    it strips any trailing run of the characters ``/``, ``L``, ``S``, ``T``
    and ``M``, which would corrupt a path whose last component ends in one
    of those characters.
    """
    suffix = "/LSTM"
    if node_name.endswith(suffix):
        node_name = node_name[: -len(suffix)]
    return node_name.lstrip("/").replace("/", ".")


def convert_model_from_local(
    onnx_path: str,
    config,
    save_path: str
):
    """Copy weights from a local ONNX file into a StyleTTS2Model and save them.

    Args:
        onnx_path: Local path of the ONNX file (a path returned by
            ``hf_hub_download`` works as well).
        config: A ``StyleTTS2Config`` instance, or a dict that is expanded
            into ``StyleTTS2Config(**config)``.
        save_path: Output path of the safetensors file.

    For every key of the model's state dict the value is taken, in order of
    preference, from (1) an ONNX tensor named ``"kmodel." + key``, (2) the
    converted ONNX LSTM weights, or (3) the value the freshly constructed
    model already holds.
    """
    # ---- Config handling ----
    if isinstance(config, dict):
        config = StyleTTS2Config(**config)
    model = StyleTTS2Model(config)

    # Remove unused modules
    for attr in (
        "wd", "msd", "mpd", "pitch_extractor",
        "text_aligner", "diffusion",
        "predictor_encoder", "style_encoder",
    ):
        if hasattr(model, attr):
            delattr(model, attr)

    # Remove parametrizations. Snapshot the module list first: removing a
    # parametrization edits the module tree that model.modules() walks.
    for module in list(model.modules()):
        if hasattr(module, "parametrizations") and "weight" in module.parametrizations:
            parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)
    state_dict = model.state_dict()

    # ---- Load ONNX (LOCAL PATH) ----
    m = ONNXParser(onnx_path)

    # -------- LSTM handling --------
    pytorch_lstm = {
        name: module for name, module in model.named_modules()
        if isinstance(module, nn.LSTM)
    }
    onnx_lstm_layers = {
        _lstm_module_name(i.name): i
        for i in m.find().find_by_op_type("LSTM")
    }
    predefined_dict = {}
    for pt_name in pytorch_lstm:
        if pt_name not in onnx_lstm_layers:
            continue
        node = onnx_lstm_layers[pt_name]
        block = m.find().find_by_name(node.name, exact=True)
        tensors = list(block.tensor().values())
        # Only nodes exposing exactly three tensors (used as W, R, B) are converted.
        if len(tensors) != 3:
            continue
        w, r, b = tensors
        converted = convert_onnx_lstm(
            w, r, b,
            pt_name,
            is_bidirectional(block.single_node)
        )
        predefined_dict.update(converted)

    # -------- Build state_dict --------
    finder = m.find()
    new_state_dict = {}
    for name, tensor in state_dict.items():
        full_key = "kmodel." + name
        results = finder.find_by_tensor(full_key)
        if results:
            new_state_dict[name] = torch.tensor(results[0].tensor()[full_key])
            continue
        # Resolve the owning module only for keys that have converted LSTM weights.
        if name in predefined_dict and isinstance(get_layer_from_key(model, name), nn.LSTM):
            new_state_dict[name] = predefined_dict[name]
            continue
        new_state_dict[name] = tensor  # fallback: keep the model's own value

    # Load weights; the STFT window is always taken from the model itself.
    new_state_dict['decoder.generator.stft.window'] = model.decoder.generator.stft.window.clone()
    model.load_state_dict(new_state_dict)

    # Make contiguous (Tensor.contiguous() returns the tensor itself when it
    # is already contiguous, so no separate is_contiguous() check is needed).
    final_sd = {k: v.contiguous() for k, v in model.state_dict().items()}
    save_file(final_sd, save_path)
from huggingface_hub import hf_hub_download
import numpy as np
import torch
from pathlib import Path
# Fetch voices.npz from the KittenML/kitten-tts-nano-0.8-fp32 Hub repo.
file_path = hf_hub_download(
    repo_id="KittenML/kitten-tts-nano-0.8-fp32",
    filename="voices.npz"
)
# Array names inside voices.npz that should be exported ...
keys = [
    'expr-voice-2-m', 'expr-voice-2-f',
    'expr-voice-3-m', 'expr-voice-3-f',
    'expr-voice-4-m', 'expr-voice-4-f',
    'expr-voice-5-m', 'expr-voice-5-f'
]
# ... and the file stem each one is saved under (paired by position).
# NOTE(review): the "-m" keys are paired with Bella/Luna/Rosie/Kiki and the
# "-f" keys with Jasper/Bruno/Hugo/Leo, which looks gender-swapped -- confirm
# against the model's published voice aliases before relying on these names.
values = [
    'Bella', 'Jasper',
    'Luna', 'Bruno',
    'Rosie', 'Hugo',
    'Kiki', 'Leo'
]
voice_dict = dict(zip(keys, values))
voice_dir = Path("voice_dir")
voice_dir.mkdir(exist_ok=True)
# np.load on an .npz archive returns a lazily-reading NpzFile that holds an
# open file handle; the context manager closes it once all voices are saved.
with np.load(file_path) as data:
    for key in data.files:
        if key in voice_dict:
            tensor = torch.from_numpy(data[key])
            torch.save(tensor, voice_dir / f"{voice_dict[key]}.pt")
# config = StyleTTS2Config.from_yaml("config.yaml")
# convert_model_from_local(
# onnx_path=hf_hub_download("KittenML/kitten-tts-nano-0.8-fp32","kitten_tts_nano_v0_8.onnx"),
# config=config,
# save_path="mini_model.safetensors"
# )
# ----
# TODO: quantization is still pending
# config.model_params.style_dim = 512
# config.model_params.hidden_dim = 512
# config.model_params.decoder.hidden_dim = 1024
# config.model_params.decoder.decoder_out_dim = 512
# config.model_params.decoder.asr_res_in = 256
# config.model_params.decoder.upsample_initial_channel = 512
# convert_model_from_local(
# onnx_path=hf_hub_download("KittenML/kitten-tts-mini-0.8","kitten_tts_mini_v0_8.onnx"),
# config=config,
# save_path="mini_model.safetensors"
# )