# extract_feature_print.py
# Provenance: Hugging Face repo "123Test" (user Pj12),
# commit 1b0482d ("Update extract_feature_print.py").
import os
import sys
import json
import traceback
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import soundfile as sf
import librosa
from fairseq.tasks.audio_pretraining import AudioPretrainingTask
from fairseq import checkpoint_utils
from transformers import HubertModel, Wav2Vec2Model
# Environment settings for MPS fallback
# Let PyTorch transparently fall back to CPU for ops unsupported on Apple MPS.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# High-watermark ratio 0.0 removes the MPS allocation cap (no upper memory limit).
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
# Configuration class for device and precision management
class Config:
    """Resolve the compute device and half-precision setting for extraction.

    Loads the per-sample-rate JSON configs from ./configs and decides whether
    fp16 inference should be used on the selected device.
    """

    def __init__(self, device):
        # Fall back to CPU when CUDA is unavailable; device_config() may
        # still switch to MPS below.
        self.device = device if torch.cuda.is_available() else "cpu"
        self.is_half = self.device != "cpu"
        self.version_config_paths = [
            os.path.join("", f"{k}.json")
            for k in ["32k", "40k", "48k", "48k_v2", "32k_v2"]
        ]
        self.json_config = self.load_config_json()
        self.device_config()

    def load_config_json(self):
        """Read every known sample-rate config file into a dict keyed by filename.

        Raises FileNotFoundError if a config file is missing under ./configs.
        """
        configs = {}
        for config_file in self.version_config_paths:
            config_path = os.path.join("configs", config_file)
            with open(config_path, "r") as f:
                configs[config_file] = json.load(f)
        return configs

    def device_config(self):
        """Finalize device and precision: CUDA index parsing, MPS/CPU fallback."""
        if self.device.startswith("cuda"):
            # Bug fix: a bare "cuda" (no ":N" suffix) previously crashed on
            # int("cuda"); default to device 0 in that case.
            i_device = int(self.device.split(":")[-1]) if ":" in self.device else 0
            gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)
            # fp16 only on V100-class GPUs with more than 4 GB of memory.
            self.is_half = gpu_mem > 4 and "V100" in torch.cuda.get_device_name(i_device)
        elif torch.backends.mps.is_available():
            self.device = "mps"
            self.is_half = False  # fp16 is not reliable on MPS
        else:
            self.device = "cpu"
            self.is_half = False
# Model-specific definitions
class HubertModelWithFinalProj(HubertModel):
    """HuggingFace HubertModel extended with a `final_proj` linear layer.

    NOTE(review): presumably mirrors fairseq HuBERT's final_proj so checkpoints
    carrying that weight (e.g. ContentVec-style) load cleanly -- confirm against
    the checkpoints actually used.
    """

    def __init__(self, config):
        super().__init__(config)
        # Projects hidden_size down to classifier_proj_size.
        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
def load_hubert_fairseq(model_path, device, is_half):
    """Load a fairseq HuBERT checkpoint, forcing the audio_pretraining task.

    Returns a dict with the eval-mode model (optionally fp16) under "model"
    and the checkpoint's saved config under "saved_cfg".
    """
    # Read the checkpoint once on CPU so the task can be rebuilt from its
    # stored config before the model ensemble is instantiated.
    state = checkpoint_utils.load_checkpoint_to_cpu(model_path)
    task = AudioPretrainingTask.setup_task(state["cfg"].task)
    ensemble, cfg, _ = checkpoint_utils.load_model_ensemble_and_task(
        [model_path], task=task
    )
    hubert = ensemble[0].to(device)
    # fp16 only makes sense on CUDA; MPS and CPU stay fp32.
    if is_half and device not in ["mps", "cpu"]:
        hubert = hubert.half()
    hubert.eval()
    return {"model": hubert, "saved_cfg": cfg}
def load_huggingface_model(model_path, device, is_half, model_class=HubertModelWithFinalProj):
    """Load a HuggingFace model in eval mode; fp16 on CUDA when requested."""
    use_fp16 = is_half and "cuda" in device
    net = model_class.from_pretrained(model_path).to(device)
    net = net.to(torch.float16 if use_fp16 else torch.float32)
    net.eval()
    return {"model": net}
def hubert_preprocess(feats, saved_cfg):
    """Layer-normalize the waveform when the checkpoint was trained with normalization."""
    if not saved_cfg.task.normalize:
        return feats
    # Normalizes over the entire tensor shape (zero mean, unit variance across
    # all samples), as configured by the checkpoint's task settings.
    with torch.no_grad():
        return F.layer_norm(feats, feats.shape)
def hubert_prepare_input(feats, device, version, is_half=True):
    """Build the kwargs dict for fairseq HuBERT's extract_features.

    Args:
        feats: (1, T) float32 waveform tensor.
        device: target device string ("cpu", "mps", "cuda[:N]").
        version: "v1" selects output layer 9 (projected downstream);
            anything else selects layer 12.
        is_half: whether to cast the source to fp16 on CUDA devices. Defaults
            to True, preserving the previous always-half-on-CUDA behavior.

    Returns:
        dict with "source", "padding_mask" (all False) and "output_layer".
    """
    padding_mask = torch.BoolTensor(feats.shape).fill_(False).to(device)
    output_layer = 9 if version == "v1" else 12
    # Bug fix (opt-in): previously the source was cast to fp16 on every
    # non-MPS/CPU device even when the model stayed fp32 (Config.is_half can
    # be False on CUDA), producing a dtype mismatch. is_half lets callers keep
    # input and model dtypes consistent.
    if is_half and device not in ["mps", "cpu"]:
        source = feats.half().to(device)
    else:
        source = feats.to(device)
    return {
        "source": source,
        "padding_mask": padding_mask,
        "output_layer": output_layer,
    }
def hubert_extract_features(model, inputs):
    """Run fairseq HuBERT feature extraction without gradient tracking.

    Layer-9 features (v1 path) are passed through the model's final_proj;
    layer-12 features are returned as-is.
    """
    with torch.no_grad():
        hidden = model.extract_features(**inputs)[0]
        if inputs["output_layer"] == 9:
            return model.final_proj(hidden)
        return hidden
def general_preprocess(feats, *args):
    """Identity preprocessing for models that need no normalization step.

    Extra positional args (e.g. a saved config) are accepted and ignored so
    every preprocess hook shares one call signature.
    """
    return feats
def general_prepare_input(feats, device, version=None):
    """Move the feature tensor to the target device.

    Args:
        feats: input tensor.
        device: target device string.
        version: accepted and ignored. Bug fix: the shared dispatch loop
            passes a version argument to every prepare_input hook, and the
            old two-parameter signature raised TypeError for the
            contentvec/wav2vec paths.

    Returns:
        The tensor on `device`.
    """
    return feats.to(device)
def general_extract_features(model, inputs):
    """Forward pass for HuggingFace-style models; returns last_hidden_state."""
    with torch.no_grad():
        output = model(inputs)
    return output["last_hidden_state"]
# Model configurations: per-model hooks (loader, preprocess, input prep,
# feature extraction) used by the dispatch loop in main().
model_configs = {
    "hubert": {
        "target_sr": 16000,
        "load_model": load_hubert_fairseq,
        "preprocess": hubert_preprocess,
        "prepare_input": hubert_prepare_input,
        "extract_features": hubert_extract_features,
    },
    "contentvec": {
        "target_sr": 16000,
        # Bug fix: the original referenced an undefined name `ContentVecModel`,
        # so selecting contentvec raised NameError as soon as the loader ran.
        # Use HubertModelWithFinalProj -- load_huggingface_model's own default
        # class, which carries the final_proj layer these checkpoints expect.
        "load_model": lambda path, dev, half: load_huggingface_model(
            path, dev, half, HubertModelWithFinalProj
        ),
        "preprocess": general_preprocess,
        "prepare_input": general_prepare_input,
        "extract_features": general_extract_features,
    },
    "wav2vec": {
        "target_sr": 16000,
        "load_model": lambda path, dev, half: load_huggingface_model(
            path, dev, half, Wav2Vec2Model
        ),
        "preprocess": general_preprocess,
        "prepare_input": general_prepare_input,
        "extract_features": general_extract_features,
    },
}
# Utility functions
def load_audio(file, target_sr):
    """Read an audio file as mono float samples resampled to target_sr."""
    path = file.strip()
    audio, sr = sf.read(path)
    # soundfile returns (frames, channels) for multi-channel audio;
    # librosa.to_mono expects (channels, frames), hence the transpose.
    if audio.ndim > 1:
        audio = librosa.to_mono(audio.T)
    if sr != target_sr:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
    return audio
def printt(f, strr):
    """Echo a message to stdout and mirror it into the open log handle `f`."""
    line = f"{strr}\n"
    print(strr)
    f.write(line)
    f.flush()  # keep the log file current even if the process dies
# Main script
def main():
    """CLI entry point: extract features for one shard of an experiment's wavs.

    Expected argv (long form): device, n_part, i_part, gpu_index, exp_dir,
    version, model_path, model_name. Reads wavs from <exp_dir>/1_16k_wavs and
    writes .npy features to <exp_dir>/3_feature<dim>.
    """
    # Parse arguments
    device = sys.argv[1]
    n_part = int(sys.argv[2])  # total number of parallel shards
    i_part = int(sys.argv[3])  # this worker's shard index
    # NOTE(review): the len(sys.argv) == 6 branches below conflict with the
    # unconditional sys.argv[7]/[8] reads -- a 6-argument invocation would hit
    # IndexError at model_path. Confirm the intended CLI shapes.
    exp_dir = sys.argv[4] if len(sys.argv) == 6 else sys.argv[5]
    version = sys.argv[5] if len(sys.argv) == 6 else sys.argv[6]
    model_path = sys.argv[7]
    model_name = sys.argv[8]
    if len(sys.argv) > 6:
        # Pin this worker to one GPU; argv[4] is the GPU index in the long form.
        os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[4]
    config = Config(device)
    log_file = open(f"{exp_dir}/extract_f0_feature.log", "a+")
    printt(log_file, f"Args: {sys.argv}")
    # Resolve model path and name
    # Maps the UI's "Custom" placeholder path to a concrete checkpoint file
    # plus the internal model_configs key.
    custom_mappings = {
        "wav2vec" : ("wav2vec_small_960h.pt", "wav2vec"),
        "hubert_base": ("hubert_base.pt", "hubert"),
        "contentvec_base": ("contentvec_base.pt", "contentvec"),
        "hubert_base_japanese" : ("hubert_base_japanese.pt","hubert_base_japanese")
    }
    if os.path.split(model_path)[-1] == "Custom" and model_name in custom_mappings:
        model_path, resolved_model_name = custom_mappings[model_name]
        model_name = resolved_model_name
    if not os.path.exists(model_path):
        printt(log_file, f"Error: {model_path} does not exist.")
        sys.exit(1)
    # Load model
    # NOTE(review): .get() with model_configs[model_name] as the default is
    # equivalent to plain indexing -- the default is evaluated eagerly and
    # raises KeyError for unknown names anyway. "hubert_base_japanese" from
    # custom_mappings has no model_configs entry; confirm that is intentional.
    model_config = model_configs.get(model_name, model_configs[model_name])
    model_dict = model_config["load_model"](model_path, config.device, config.is_half)
    model = model_dict["model"]
    # Only the fairseq loader returns "saved_cfg"; HF loaders yield None here.
    additional_configs = model_dict.get("saved_cfg")
    printt(log_file, f"Loaded model from {model_path} on {config.device}")
    # Setup directories
    # NOTE(review): precedence is 256 if v1 else (768 if not-large else 1024),
    # i.e. v1 always yields 256 regardless of model size -- confirm intended.
    feature_dim = 256 if version == "v1" else 768 if model_name != "hubert_large_ll60k" else 1024
    wav_path = f"{exp_dir}/1_16k_wavs"
    out_path = f"{exp_dir}/3_feature{feature_dim}"
    os.makedirs(out_path, exist_ok=True)
    # Process audio files
    # Shard the sorted file list: every n_part-th file starting at i_part.
    todo = sorted(os.listdir(wav_path))[i_part::n_part]
    printt(log_file, f"Total files to process: {len(todo)}")
    if not todo:
        printt(log_file, "No files to process.")
        return
    target_sr = model_config["target_sr"]
    for idx, file in enumerate(todo):
        if not file.endswith(".wav"):
            continue
        try:
            wav_file = f"{wav_path}/{file}"
            out_file = f"{out_path}/{file.replace('.wav', '.npy')}"
            if os.path.exists(out_file):
                # Skip already-extracted files so interrupted runs can resume.
                continue
            # Load and preprocess audio
            wav = load_audio(wav_file, target_sr)
            feats = torch.from_numpy(wav).float().view(1, -1)
            if feats.dim() > 2:
                feats = feats.mean(-1)
            preprocessed_feats = model_config["preprocess"](feats, additional_configs)
            # NOTE(review): version is forwarded to every prepare_input hook;
            # verify each hook's signature actually accepts a third argument.
            inputs = model_config["prepare_input"](preprocessed_feats, config.device, version)
            feats = model_config["extract_features"](model, inputs)
            # Save features
            feats = feats.squeeze(0).float().cpu().numpy()
            if not np.isnan(feats).any():
                np.save(out_file, feats, allow_pickle=False)
                printt(log_file, f"Processed {file}: {feats.shape}")
            else:
                # NaNs usually indicate a precision/dtype problem upstream.
                printt(log_file, f"{file} contains NaN values")
        except Exception:
            # Log the full traceback but keep processing the remaining files.
            printt(log_file, traceback.format_exc())
    printt(log_file, "Feature extraction completed.")
    log_file.close()


if __name__ == "__main__":
    main()