| |
| """ |
| Convert original WavTokenizer checkpoint to HuggingFace format. |
| |
| Usage: |
| python convert_wavtokenizer.py \ |
| --config_path configs/wavtokenizer_smalldata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml \ |
| --checkpoint_path checkpoints/wavtokenizer_small_320_24k_4096.ckpt \ |
| --output_dir ./wavtokenizer_hf_converted |
| |
| This will create a HuggingFace-compatible model directory that can be loaded with: |
| model = AutoModel.from_pretrained("./wavtokenizer_hf_converted", trust_remote_code=True) |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import shutil |
| from pathlib import Path |
|
|
| import torch |
| import yaml |
|
|
|
|
def _init_args(parent: dict, key: str) -> dict:
    """Return ``parent[key]["init_args"]``, or ``{}`` if either level is missing or null."""
    # ``or {}`` (rather than a .get default) also covers keys present with a YAML ``null`` value.
    section = parent.get(key) or {}
    return section.get("init_args") or {}


def _build_hf_config(yaml_cfg: dict) -> dict:
    """Map a WavTokenizer YAML config onto a HuggingFace ``config.json`` dict.

    Component settings are read from ``model.init_args.<component>.init_args`` for
    the ``head``, ``backbone``, ``quantizer`` and ``feature_extractor`` components.
    Any value missing from the YAML falls back to the default written below. Some
    fields (encoder, attention, dtype, ...) are always hard-coded.
    """
    model_args = _init_args(yaml_cfg, "model")
    head_args = _init_args(model_args, "head")
    backbone_args = _init_args(model_args, "backbone")
    quantizer_args = _init_args(model_args, "quantizer")
    feature_extractor_args = _init_args(model_args, "feature_extractor")

    # Read once; each is reused by several fields below.
    n_fft = head_args.get("n_fft", 1280)
    backbone_dim = backbone_args.get("dim", 512)

    return {
        "_name_or_path": "WavTokenizerSmall",
        "architectures": ["WavTokenizer"],
        "auto_map": {
            "AutoConfig": "configuration_wavtokenizer.WavTokenizerConfig",
            "AutoModel": "modeling_wavtokenizer.WavTokenizer",
        },
        "model_type": "wavtokenizer",

        # Audio / STFT parameters.
        "sample_rate": feature_extractor_args.get("sample_rate", 24000),
        "n_fft": n_fft,
        "hop_length": head_args.get("hop_length", 320),
        "n_mels": feature_extractor_args.get("n_mels", 128),
        "padding": head_args.get("padding", "center"),

        # Encoder. encoder_dim / encoder_rates are fixed, not read from the YAML.
        # The product of encoder_rates is 320, equal to the default hop_length.
        # NOTE(review): a YAML with a different hop_length is not reflected here; confirm.
        "feature_dim": backbone_dim,
        "encoder_dim": 64,
        "encoder_rates": [8, 5, 4, 2],
        "latent_dim": backbone_args.get("input_channels", 512),

        # Quantizer.
        "codebook_size": quantizer_args.get("codebook_size", 4096),
        "codebook_dim": quantizer_args.get("codebook_dim", 8),
        "num_quantizers": quantizer_args.get("num_quantizers", 1),

        # Backbone.
        "backbone_type": "vocos",
        "backbone_dim": backbone_dim,
        "backbone_num_blocks": backbone_args.get("num_layers", 8),
        "backbone_intermediate_dim": backbone_args.get("intermediate_dim", 1536),
        "backbone_kernel_size": 7,
        "backbone_layer_scale_init_value": 1e-6,

        # Head: n_fft // 2 + 1 is the number of bins of an n_fft-point real FFT.
        "head_type": "istft",
        "head_dim": n_fft // 2 + 1,

        # Attention (hard-coded, except that its width follows the backbone).
        "use_attention": True,
        "attention_dim": backbone_dim,
        "attention_heads": 8,
        "attention_layers": 1,

        "torch_dtype": "float32",
        "transformers_version": "4.40.0",
    }


def _strip_prefix(state_dict: dict, prefix: str = "model.") -> dict:
    """Return a copy of *state_dict* with a leading *prefix* removed from every key that has it.

    Keys without the prefix are kept unchanged. If a stripped key collides with
    an existing one, the entry encountered later in iteration order wins.
    """
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Files copied next to the weights. The two modules are the targets of ``auto_map``
# and are needed for ``trust_remote_code=True`` loading; the README is optional.
_SUPPORT_FILES = ("configuration_wavtokenizer.py", "modeling_wavtokenizer.py", "README.md")


def _copy_support_files(src_dir: Path, output_dir: str) -> None:
    """Copy each of ``_SUPPORT_FILES`` from *src_dir* to *output_dir*, warning about missing ones."""
    for name in _SUPPORT_FILES:
        src = src_dir / name
        if src.exists():
            shutil.copy(src, output_dir)
            print(f"Copied: {name}")
        else:
            # Best-effort copy, but don't stay silent: a missing module breaks auto_map loading.
            print(f"Warning: {name} not found in {src_dir}; not copied")


def convert_wavtokenizer(config_path: str, checkpoint_path: str, output_dir: str) -> None:
    """Convert a WavTokenizer checkpoint to HuggingFace format.

    Writes ``config.json`` and ``pytorch_model.bin`` into *output_dir* (created if
    needed). It also copies the remote-code modules and README that sit next to
    this script, when they exist.

    Args:
        config_path: Path to the WavTokenizer YAML config.
        checkpoint_path: Path to the ``.ckpt`` checkpoint. The file is unpickled,
            so only pass checkpoints you trust.
        output_dir: Destination directory for the HuggingFace model files.

    Raises:
        TypeError: If the checkpoint does not deserialize to a dict.
    """
    print(f"Loading config from: {config_path}")
    print(f"Loading checkpoint from: {checkpoint_path}")

    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        yaml_cfg = yaml.safe_load(f) or {}
    hf_config = _build_hf_config(yaml_cfg)

    # Load the checkpoint before touching the output directory, so an unreadable
    # checkpoint does not leave a half-written model behind.
    print("Loading checkpoint...")
    # SECURITY: weights_only=False unpickles arbitrary objects; only use trusted files.
    # It is passed explicitly because torch >= 2.6 defaults to weights_only=True,
    # which can reject checkpoints that store non-tensor objects alongside the weights.
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    if not isinstance(ckpt, dict):
        raise TypeError(
            f"Expected checkpoint to be a dict, got {type(ckpt).__name__}: {checkpoint_path}"
        )
    # Lightning-style checkpoints nest the weights under "state_dict"; otherwise
    # the file is treated as a bare state dict.
    new_state_dict = _strip_prefix(ckpt.get("state_dict", ckpt))

    os.makedirs(output_dir, exist_ok=True)

    config_out_path = os.path.join(output_dir, "config.json")
    with open(config_out_path, "w", encoding="utf-8") as f:
        json.dump(hf_config, f, indent=2)
    print(f"Saved config to: {config_out_path}")

    model_out_path = os.path.join(output_dir, "pytorch_model.bin")
    torch.save(new_state_dict, model_out_path)
    print(f"Saved model weights to: {model_out_path}")

    _copy_support_files(Path(__file__).parent, output_dir)

    print(f"\nConversion complete! Model saved to: {output_dir}")
    print("\nTo load the model:")
    print(f'  model = AutoModel.from_pretrained("{output_dir}", trust_remote_code=True)')
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="Convert WavTokenizer checkpoint to HuggingFace format") |
| parser.add_argument( |
| "--config_path", |
| type=str, |
| required=True, |
| help="Path to WavTokenizer YAML config file" |
| ) |
| parser.add_argument( |
| "--checkpoint_path", |
| type=str, |
| required=True, |
| help="Path to WavTokenizer .ckpt checkpoint file" |
| ) |
| parser.add_argument( |
| "--output_dir", |
| type=str, |
| default="./wavtokenizer_hf_converted", |
| help="Output directory for HuggingFace model" |
| ) |
| |
| args = parser.parse_args() |
| convert_wavtokenizer(args.config_path, args.checkpoint_path, args.output_dir) |
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()