# %%bash
# pip uninstall -q styletts2 -y
# pip install git+https://github.com/dummyjenil/styletts2 -q

from pathlib import Path

import numpy as np
import onnx
import torch
import torch.nn.utils.parametrize as parametrize
from torch import nn

from huggingface_hub import hf_hub_download
from onnx_toolkit import ONNXParser
from safetensors.torch import save_file
from styletts2.models.styletts2 import StyleTTS2Config, StyleTTS2Model


# -----------------------------
# Utils
# -----------------------------
def get_layer_from_key(model: nn.Module, key: str) -> nn.Module:
    """Return the submodule that owns the state-dict entry named ``key``.

    Walks every dotted component of ``key`` except the last one (the
    parameter/buffer name itself). Numeric components index into container
    modules (``nn.Sequential`` / ``nn.ModuleList``); the rest are plain
    attribute lookups.
    """
    module = model
    for part in key.split(".")[:-1]:
        module = module[int(part)] if part.isdigit() else getattr(module, part)
    return module


def is_bidirectional(node) -> bool:
    """Return True when an ONNX LSTM node's ``direction`` attribute is
    ``"bidirectional"``.

    Returns False when the attribute is absent (the ONNX default direction
    is ``"forward"``).
    """
    for attr in node.attribute:
        if attr.name == "direction":
            value = onnx.helper.get_attribute_value(attr)
            if isinstance(value, bytes):
                value = value.decode("utf-8")
            return value == "bidirectional"
    return False


# -----------------------------
# LSTM Conversion
# -----------------------------
def convert_onnx_lstm(W, R, B, layer_name="lstm", bidirectional=False):
    """Convert ONNX LSTM weights (W, R, B) into PyTorch ``nn.LSTM`` state-dict entries.

    ONNX orders the gates i, o, f, c while PyTorch expects i, f, c, o, so
    every weight/bias block is reordered. The ONNX bias ``B`` stores the
    input and recurrent biases concatenated, which splits onto PyTorch's
    separate ``bias_ih`` / ``bias_hh``.

    Args:
        W: ONNX input weights, shape (num_directions, 4*H, input_size).
        R: ONNX recurrent weights, shape (num_directions, 4*H, H).
        B: ONNX biases, shape (num_directions, 8*H).
        layer_name: dotted prefix used for the produced state-dict keys.
        bidirectional: kept for backward compatibility; unused — the
            direction count is inferred from ``W.shape[0]`` instead.

    Returns:
        dict mapping PyTorch LSTM state-dict keys to reordered tensors.
    """
    def reorder_gates(w):
        # ONNX gate layout iofc -> PyTorch gate layout ifco.
        i, o, f, c = torch.chunk(w, 4, dim=0)
        return torch.cat([i, f, c, o], dim=0)

    state_dict = {}
    for d in range(W.shape[0]):  # one pass per direction
        suffix = "" if d == 0 else "_reverse"

        w_ih = reorder_gates(torch.tensor(W[d]))
        w_hh = reorder_gates(torch.tensor(R[d]))
        b_ih, b_hh = torch.chunk(torch.tensor(B[d]), 2, dim=0)
        b_ih = reorder_gates(b_ih)
        b_hh = reorder_gates(b_hh)

        state_dict[f"{layer_name}.weight_ih_l0{suffix}"] = w_ih
        state_dict[f"{layer_name}.weight_hh_l0{suffix}"] = w_hh
        state_dict[f"{layer_name}.bias_ih_l0{suffix}"] = b_ih
        state_dict[f"{layer_name}.bias_hh_l0{suffix}"] = b_hh
    return state_dict


# -----------------------------
# MAIN FUNCTION (UPDATED)
# -----------------------------
def convert_model_from_local(
    onnx_path: str,
    config,
    save_path: str
):
    """Transfer weights from a local ONNX export into a StyleTTS2 model and
    save them as safetensors.

    Args:
        onnx_path: local path (a path returned by hf_hub_download works too).
        config: StyleTTS2Config OR a plain dict of config fields.
        save_path: output safetensors path.
    """
    # ---- Config handling ----
    if isinstance(config, dict):
        config = StyleTTS2Config(**config)

    model = StyleTTS2Model(config)

    # Remove training-only submodules that the inference ONNX graph does not carry.
    for attr in [
        "wd", "msd", "mpd", "pitch_extractor",
        "text_aligner", "diffusion",
        "predictor_encoder", "style_encoder"
    ]:
        if hasattr(model, attr):
            delattr(model, attr)

    # Bake parametrizations (e.g. weight norm) into plain ``weight`` tensors
    # so state-dict keys line up with the exported graph.
    for module in model.modules():
        if hasattr(module, "parametrizations") and "weight" in module.parametrizations:
            parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

    state_dict = model.state_dict()

    # ---- Load ONNX (local path) ----
    m = ONNXParser(onnx_path)

    # -------- LSTM handling --------
    pytorch_lstm = {
        name: module
        for name, module in model.named_modules()
        if isinstance(module, nn.LSTM)
    }

    # Map "/a/b/LSTM" node names onto PyTorch "a.b" module paths.
    # BUGFIX: the previous rstrip("/LSTM").lstrip("/") stripped *character
    # sets* (any trailing run of '/', 'L', 'S', 'T', 'M'), silently
    # corrupting node names ending in those letters; removesuffix /
    # removeprefix (Python 3.9+) remove the exact affixes instead.
    onnx_lstm_layers = {
        node.name.removesuffix("/LSTM").removeprefix("/").replace("/", "."): node
        for node in m.find().find_by_op_type("LSTM")
    }

    predefined_dict = {}
    for pt_name in pytorch_lstm:
        if pt_name not in onnx_lstm_layers:
            continue

        node = onnx_lstm_layers[pt_name]
        block = m.find().find_by_name(node.name, exact=True)
        tensors = list(block.tensor().values())
        if len(tensors) != 3:
            # Expect exactly W, R, B; skip anything else (e.g. missing bias).
            continue

        # Assumes the parser yields initializers in ONNX input order W, R, B
        # — TODO(review): confirm ONNXParser guarantees this ordering.
        w, r, b = tensors
        converted = convert_onnx_lstm(
            w, r, b, pt_name, is_bidirectional(block.single_node)
        )
        predefined_dict.update(converted)

    # -------- Build state_dict --------
    finder = m.find()
    new_state_dict = {}

    for name, tensor in state_dict.items():
        # 1) Direct tensor match in the ONNX graph (export uses a "kmodel." prefix).
        full_key = "kmodel." + name
        results = finder.find_by_tensor(full_key)
        if results:
            new_state_dict[name] = torch.tensor(results[0].tensor()[full_key])
            continue

        # 2) LSTM weights reconstructed from the fused ONNX LSTM nodes.
        module = get_layer_from_key(model, name)
        if isinstance(module, nn.LSTM) and name in predefined_dict:
            new_state_dict[name] = predefined_dict[name]
            continue

        # 3) Fallback: keep the freshly-initialized tensor.
        new_state_dict[name] = tensor

    # The STFT window buffer is not part of the ONNX export; keep the one
    # the model constructed itself.
    new_state_dict['decoder.generator.stft.window'] = model.decoder.generator.stft.window.clone()
    model.load_state_dict(new_state_dict)

    # safetensors requires contiguous tensors.
    final_sd = {
        k: v if v.is_contiguous() else v.contiguous()
        for k, v in model.state_dict().items()
    }
    save_file(final_sd, save_path)


# -----------------------------
# Voice embedding export (runs at import time, notebook-style)
# -----------------------------
file_path = hf_hub_download(
    repo_id="KittenML/kitten-tts-nano-0.8-fp32",
    filename="voices.npz"
)

keys = [
    'expr-voice-2-m', 'expr-voice-2-f',
    'expr-voice-3-m', 'expr-voice-3-f',
    'expr-voice-4-m', 'expr-voice-4-f',
    'expr-voice-5-m', 'expr-voice-5-f'
]
values = [
    'Bella', 'Jasper', 'Luna', 'Bruno',
    'Rosie', 'Hugo', 'Kiki', 'Leo'
]
# Repo voice keys -> friendly display names.
voice_dict = dict(zip(keys, values))

voice_dir = Path("voice_dir")
voice_dir.mkdir(exist_ok=True)

# Re-save each known voice embedding as an individual torch .pt file.
data = np.load(file_path)
for key in data.files:
    if key in voice_dict:
        tensor = torch.from_numpy(data[key])
        torch.save(tensor, voice_dir / f"{voice_dict[key]}.pt")


# Example usage (kept for reference):
# config = StyleTTS2Config.from_yaml("config.yaml")
# convert_model_from_local(
#     onnx_path=hf_hub_download("KittenML/kitten-tts-nano-0.8-fp32","kitten_tts_nano_v0_8.onnx"),
#     config=config,
#     save_path="mini_model.safetensors"
# )

# ----
# TODO: quantization still pending
# config.model_params.style_dim = 512
# config.model_params.hidden_dim = 512
# config.model_params.decoder.hidden_dim = 1024
# config.model_params.decoder.decoder_out_dim = 512
# config.model_params.decoder.asr_res_in = 256
# config.model_params.decoder.upsample_initial_channel = 512
# convert_model_from_local(
#     onnx_path=hf_hub_download("KittenML/kitten-tts-mini-0.8","kitten_tts_mini_v0_8.onnx"),
#     config=config,
#     save_path="mini_model.safetensors"
# )