Lyman committed on
Commit
48a5a98
·
verified ·
1 Parent(s): 676a232

Upload generate_npy.py

Browse files
Files changed (1) hide show
  1. generate_npy.py +93 -0
generate_npy.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import numpy as np
import torch
import librosa
import os
import sys
import argparse

# Add the project root to sys.path so the `tts` package resolves even when
# this script is launched from outside the repository directory.
sys.path.append('/mnt/data/MegaTTS3')

try:
    from tts.infer_cli import MegaTTS3DiTInfer, hparams
except ImportError as e:
    # Report the failure on stderr (not stdout) so callers piping normal
    # output still see the diagnostic, then abort with a non-zero status.
    print(f"Failed to import MegaTTS3DiTInfer and hparams: {e}", file=sys.stderr)
    sys.exit(1)
17
def generate_npy_file(audio_path, output_npy_path, model, sample_rate=24000):
    """
    Generate and save a .npy file containing the latent representation of an audio file.

    :param audio_path: Path to the input audio file (e.g., .wav, .mp3).
    :param output_npy_path: Path where the .npy file will be saved.
    :param model: Instance of MegaTTS3DiTInfer with a loaded WaveVAE encoder.
    :param sample_rate: Sample rate the audio is resampled to (default: 24000).
    :return: True if successful, False otherwise.
    """
    try:
        if not os.path.exists(audio_path):
            raise FileNotFoundError(f"Input audio file not found: {audio_path}")

        # Ensure the output directory exists.  os.path.dirname() returns ''
        # for a bare filename, and os.makedirs('') raises FileNotFoundError,
        # so only create the directory when one is actually specified.
        out_dir = os.path.dirname(output_npy_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Load and resample the audio to `sample_rate` (librosa returns mono
        # float32 samples in [-1, 1]).
        wav, _ = librosa.core.load(audio_path, sr=sample_rate)

        # Pad the waveform so its length aligns with the model's analysis
        # window, then append 12000 zero samples (0.5 s at 24 kHz) of tail
        # silence before encoding.
        ws = hparams['win_size']
        if len(wav) % ws < ws - 1:
            wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant',
                         constant_values=0.0).astype(np.float32)
        wav = np.pad(wav, (0, 12000), mode='constant',
                     constant_values=0.0).astype(np.float32)

        # Encode to the WaveVAE latent representation and persist it.
        if model.has_vae_encoder:
            wav = torch.FloatTensor(wav)[None].to(model.device)
            with torch.inference_mode():
                vae_latent = model.wavvae.encode_latent(wav)  # Note: Changed from wavvae_en to wavvae
            np.save(output_npy_path, vae_latent.cpu().numpy())
            return True
        else:
            raise ValueError("WaveVAE encoder model is not available. Cannot generate .npy file.")
    except Exception as e:
        # Best-effort API: report the problem and signal failure through the
        # return value rather than propagating the exception to the caller.
        print(f"Error generating .npy file: {e}")
        return False
54
+
55
def extract_vae_features(input_wav, output_npy, ckpt_root='/mnt/data/MegaTTS3/checkpoints'):
    """
    Initialize the MegaTTS3 model and generate the .npy latent file.

    :param input_wav: Path to the input WAV file.
    :param output_npy: Path where the .npy file will be saved.
    :param ckpt_root: Root directory of the model checkpoints
                      (default preserves the original hard-coded path).
    :return: True if successful, False otherwise.
    """
    try:
        # Initialize the MegaTTS3DiTInfer model from the checkpoint root.
        model = MegaTTS3DiTInfer(ckpt_root=ckpt_root)

        # Generate the .npy file.
        success = generate_npy_file(input_wav, output_npy, model)

        # Drop every sub-model reference so the (potentially large) weights
        # become garbage-collectable, then release cached GPU memory.
        model.wavvae = None
        model.dur_model = None
        model.dit = None
        model.g2p_model = None
        model.aligner_lm = None
        torch.cuda.empty_cache()

        return success
    except Exception as e:
        # Any failure (including model construction) is reported and folded
        # into a False return value instead of raising.
        print(f"Error in extract_vae_features: {e}")
        return False
82
+
83
+ if __name__ == "__main__":
84
+ parser = argparse.ArgumentParser(description="Extract VAE features from a WAV file and save as .npy")
85
+ parser.add_argument('--input_wav', type=str, required=True, help='输入WAV文件路径 (Path to input WAV file)')
86
+ parser.add_argument('--output_npy', type=str, required=True, help='输出NPY文件路径 (Path to output NPY file)')
87
+ args = parser.parse_args()
88
+
89
+ success = extract_vae_features(args.input_wav, args.output_npy)
90
+ if success:
91
+ print("特征提取完成! (Feature extraction completed!)")
92
+ else:
93
+ print("特征提取失败 (Feature extraction failed)")