| |
| |
| |
|
|
|
|
| import torch |
| from torch import nn |
| import onnx |
| import torch.nn.utils.parametrize as parametrize |
| from styletts2.models.styletts2 import StyleTTS2Model, StyleTTS2Config |
| from onnx_toolkit import ONNXParser |
| from safetensors.torch import save_file |
| from huggingface_hub import hf_hub_download |
|
|
|
|
| |
| |
| |
def get_layer_from_key(model: nn.Module, key: str):
    """Resolve the parent module that owns the parameter named by ``key``.

    The final dotted component (the parameter name itself) is dropped; each
    remaining component is followed either as an attribute lookup or, when it
    is all digits, as an integer index (e.g. into an ``nn.Sequential``).
    """
    current = model
    for component in key.split(".")[:-1]:
        if component.isdigit():
            current = current[int(component)]
        else:
            current = getattr(current, component)
    return current
|
|
|
|
def is_bidirectional(node):
    """Return True iff an ONNX LSTM node's ``direction`` attribute is "bidirectional".

    Missing ``direction`` attribute means the ONNX default (forward), so the
    function falls through to False.
    """
    for attribute in node.attribute:
        if attribute.name != "direction":
            continue
        direction = onnx.helper.get_attribute_value(attribute)
        # ONNX string attributes come back as bytes; normalize before comparing.
        if isinstance(direction, bytes):
            direction = direction.decode("utf-8")
        return direction == "bidirectional"
    return False
|
|
|
|
| |
| |
| |
def convert_onnx_lstm(W, R, B, layer_name="lstm", bidirectional=False):
    """Translate one ONNX LSTM's initializers into PyTorch ``nn.LSTM`` state-dict entries.

    ONNX packs the four gates in i, o, f, c order while PyTorch expects
    i, f, g(=c), o, so every weight and bias is re-chunked and re-concatenated.
    The leading axis of W/R/B enumerates directions: direction 0 maps to the
    plain ``l0`` parameter names, direction 1 (if present) to ``l0_reverse``.
    NOTE(review): ``bidirectional`` is accepted for the caller's convenience
    but unused — the direction count is taken from ``W.shape[0]`` directly.
    """
    def _to_pytorch_gate_order(tensor):
        gate_i, gate_o, gate_f, gate_c = torch.chunk(tensor, 4, dim=0)
        return torch.cat((gate_i, gate_f, gate_c, gate_o), dim=0)

    converted = {}
    for direction in range(W.shape[0]):
        tag = "_reverse" if direction else ""

        input_weights = _to_pytorch_gate_order(torch.tensor(W[direction]))
        hidden_weights = _to_pytorch_gate_order(torch.tensor(R[direction]))

        # ONNX concatenates the input-hidden and hidden-hidden biases into one
        # vector; split first, then reorder each half.
        raw_b_ih, raw_b_hh = torch.chunk(torch.tensor(B[direction]), 2, dim=0)

        converted[f"{layer_name}.weight_ih_l0{tag}"] = input_weights
        converted[f"{layer_name}.weight_hh_l0{tag}"] = hidden_weights
        converted[f"{layer_name}.bias_ih_l0{tag}"] = _to_pytorch_gate_order(raw_b_ih)
        converted[f"{layer_name}.bias_hh_l0{tag}"] = _to_pytorch_gate_order(raw_b_hh)

    return converted
|
|
|
|
| |
| |
| |
def convert_model_from_local(
    onnx_path: str,
    config,
    save_path: str
):
    """
    Port ONNX-exported StyleTTS2 weights back into a PyTorch safetensors checkpoint.

    onnx_path: local path to the ONNX model (a path returned by
        hf_hub_download works too)
    config: StyleTTS2Config OR dict of config fields
    save_path: output safetensors path
    """

    if isinstance(config, dict):
        config = StyleTTS2Config(**config)

    model = StyleTTS2Model(config)

    # Training-only submodules are absent from the inference ONNX graph; drop
    # them so their parameters are not expected in the state dict.
    for attr in [
        "wd", "msd", "mpd", "pitch_extractor",
        "text_aligner", "diffusion",
        "predictor_encoder", "style_encoder"
    ]:
        if hasattr(model, attr):
            delattr(model, attr)

    # Bake weight-norm-style parametrizations into plain "weight" tensors so
    # state-dict keys line up with the exported ONNX tensor names.
    for module in model.modules():
        if hasattr(module, "parametrizations") and "weight" in module.parametrizations:
            parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

    state_dict = model.state_dict()

    m = ONNXParser(onnx_path)

    pytorch_lstm = {
        name: module for name, module in model.named_modules()
        if isinstance(module, nn.LSTM)
    }

    def _node_to_module_name(node_name: str) -> str:
        # Map an ONNX node name (e.g. "/predictor/lstm/LSTM") to the dotted
        # PyTorch module path ("predictor.lstm").
        # BUG FIX: str.rstrip("/LSTM") strips a *character set*, not the
        # literal suffix, and could eat trailing L/S/T/M characters of the
        # module name itself; strip the exact "/LSTM" suffix instead.
        if node_name.endswith("/LSTM"):
            node_name = node_name[: -len("/LSTM")]
        return node_name.lstrip("/").replace("/", ".")

    onnx_lstm_layers = {
        _node_to_module_name(i.name): i
        for i in m.find().find_by_op_type("LSTM")
    }

    predefined_dict = {}

    for pt_name in pytorch_lstm:
        if pt_name not in onnx_lstm_layers:
            continue

        node = onnx_lstm_layers[pt_name]
        block = m.find().find_by_name(node.name, exact=True)

        # A well-formed LSTM node carries exactly its W, R, B initializers;
        # anything else means the node is wired unexpectedly, so skip it.
        tensors = list(block.tensor().values())
        if len(tensors) != 3:
            continue

        w, r, b = tensors

        converted = convert_onnx_lstm(
            w, r, b,
            pt_name,
            is_bidirectional(block.single_node)
        )

        predefined_dict.update(converted)

    finder = m.find()
    new_state_dict = {}

    for name, tensor in state_dict.items():
        # ONNX initializer names carry a "kmodel." prefix over the PyTorch key.
        full_key = "kmodel." + name
        results = finder.find_by_tensor(full_key)

        # Direct hit: the tensor was exported under its state-dict name.
        if results:
            new_state_dict[name] = torch.tensor(results[0].tensor()[full_key])
            continue

        module = get_layer_from_key(model, name)

        # LSTM weights are not exported under state-dict names; use the
        # gate-reordered tensors prepared above.
        if isinstance(module, nn.LSTM) and name in predefined_dict:
            new_state_dict[name] = predefined_dict[name]
            continue

        # Fall back to the freshly initialized value (e.g. buffers the ONNX
        # export folded away).
        new_state_dict[name] = tensor

    # The STFT window is a deterministic buffer recomputed at construction
    # time; take it from the freshly built model rather than the ONNX graph.
    new_state_dict['decoder.generator.stft.window'] = model.decoder.generator.stft.window.clone()
    model.load_state_dict(new_state_dict)

    # safetensors requires contiguous tensors.
    final_sd = {
        k: v.contiguous() if not v.is_contiguous() else v
        for k, v in model.state_dict().items()
    }

    save_file(final_sd, save_path)
|
|
from huggingface_hub import hf_hub_download
import numpy as np
import torch
from pathlib import Path

# Download the packed voice-embedding archive and re-save a selected subset
# as individual .pt files under human-friendly names.
file_path = hf_hub_download(
    repo_id="KittenML/kitten-tts-nano-0.8-fp32",
    filename="voices.npz"
)

# Archive key -> friendly voice name used for the output file.
voice_dict = {
    'expr-voice-2-m': 'Bella',
    'expr-voice-2-f': 'Jasper',
    'expr-voice-3-m': 'Luna',
    'expr-voice-3-f': 'Bruno',
    'expr-voice-4-m': 'Rosie',
    'expr-voice-4-f': 'Hugo',
    'expr-voice-5-m': 'Kiki',
    'expr-voice-5-f': 'Leo',
}

voice_dir = Path("voice_dir")
voice_dir.mkdir(exist_ok=True)

data = np.load(file_path)
for key in data.files:
    if key not in voice_dict:
        continue
    torch.save(torch.from_numpy(data[key]), voice_dir / f"{voice_dict[key]}.pt")
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
|
|
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|