'''
@Project :Waveformer-main
@File    :CLAPSep.py
@IDE     :PyCharm
@Author  :Aisaka/Hao Ma @SDU
@Date    :2024/2/28 1:12 PM
'''

import copy

import torch
from torch import nn
import torchaudio
import laion_clap
import librosa
import loralib as lora
from torchlibrosa import ISTFT, STFT
from torchlibrosa.stft import magphase

from .CLAPSep_decoder import HTSAT_Decoder

def set_module(model, submodule_key, module):
    """Replace the submodule addressed by a dotted path (as produced by
    `named_modules()`) with `module`, in place."""
    tokens = submodule_key.split('.')
    cur_mod = model
    for s in tokens[:-1]:
        cur_mod = getattr(cur_mod, s)
    setattr(cur_mod, tokens[-1], module)

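# For example, set_module(encoder, 'layers.0.blocks.0.attn.qkv', new_linear)
# walks encoder.layers[0].blocks[0].attn and rebinds its `qkv` attribute.
# (The path here is illustrative; actual names depend on the encoder.)
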
def process_model(model, rank):
    """Swap every nn.Linear inside the Swin WindowAttention blocks for a
    rank-`rank` lora.Linear, reusing the pretrained weights and biases."""
    for n, module in model.named_modules():
        if 'WindowAttention' in str(type(module)):
            for n_, layer in module.named_modules():
                if isinstance(layer, torch.nn.Linear):
                    # nn.Linear always has a `bias` attribute (possibly None), so
                    # test its value rather than hasattr(), which is always True.
                    lora_layer = lora.Linear(layer.in_features, layer.out_features, r=rank,
                                             bias=layer.bias is not None, merge_weights=True)
                    lora_layer.weight = layer.weight
                    if layer.bias is not None:
                        lora_layer.bias = layer.bias
                    set_module(model, n + '.' + n_, lora_layer)
    return model

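# In a training setup one would typically freeze everything except the injected
# adapters, e.g. via loralib's helper (a sketch; this file only runs inference):
#   lora.mark_only_lora_as_trainable(audio_branch)
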
class CLAPSep(nn.Module):
    def __init__(self, model_config, CLAP_path, use_lora=True, rank=16, nfft=1024):
        super().__init__()
        # The separator operates on 32 kHz audio; CLAP's HTSAT encoder expects 48 kHz.
        self.resampler = torchaudio.transforms.Resample(32000, 48000)
        self.clap_model = laion_clap.CLAP_Module(enable_fusion=False, amodel='HTSAT-base', device='cpu')
        self.clap_model.load_ckpt(CLAP_path)
        for p in self.clap_model.parameters():
            p.requires_grad = False
        # Tunable copy of the CLAP audio encoder; the frozen original stays
        # available for computing query embeddings.
        self.audio_branch = copy.deepcopy(self.clap_model.model.audio_branch)
        if use_lora:
            process_model(self.audio_branch, rank)
        self.decoder_model = HTSAT_Decoder(**model_config)
        self.stft = STFT(n_fft=nfft, hop_length=320,
                         win_length=nfft, window='hann', center=True, pad_mode='reflect',
                         freeze_parameters=True)
        self.istft = ISTFT(n_fft=nfft, hop_length=320,
                           win_length=nfft, window='hann', center=True, pad_mode='reflect',
                           freeze_parameters=True)
        # List that the forward hooks fill with encoder features on every pass.
        self.features = self.install_forward_hooks()

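    # With nfft=1024 and hop 320 on 32 kHz input, the STFT yields 513 frequency
    # bins at 100 frames per second; presumably this is the spectrogram geometry
    # the HTSAT_Decoder config below is sized for.
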
    def wav_reconstruct(self, mask, mag_x, cos_x, sin_x, length):
        # Mask the mixture magnitude (clamped to be non-negative) and
        # resynthesise the waveform using the mixture phase.
        mag_y = torch.nn.functional.relu_(mag_x * mask)
        pred = self.istft(mag_y * cos_x, mag_y * sin_x, length=length)
        return pred

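    # Equivalently, pred = ISTFT(relu(mask * |X|) * exp(j*angle(X))): the usual
    # real-valued-mask reconstruction, which reuses the mixture phase unchanged.
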
    def inference_from_data(self, mixed, embed_pos, embed_neg):
        self.eval()
        real, imag = self.stft(mixed)
        mag, cos, sin = magphase(real, imag)
        # The mixture magnitude is the first entry of the skip-feature list.
        self.features.append(mag)
        with torch.no_grad():
            # Query = L2-normalised concatenation of positive and negative embeddings.
            embed = torch.nn.functional.normalize(torch.concat([embed_pos, embed_neg], dim=-1), dim=-1)
            # This forward pass fills self.features through the installed hooks.
            self.audio_branch({"waveform": self.resampler(mixed)})
            mask = self.decoder_model(hidden_state=self.features[-1], skip_features=self.features[:-1], embed=embed)
            pred = self.wav_reconstruct(mask, mag, cos, sin, length=mixed.size(-1))
        # Clear the hook buffer so repeated calls do not accumulate stale features.
        del self.features[:]
        return pred

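    # After one encoder pass, self.features holds [mixture magnitude,
    # patch_embed output, one output per Swin stage]; the final stage output is
    # the decoder's hidden state and the earlier entries act as skip connections.
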
    def install_forward_hooks(self):
        features = []

        def get_features_list(_, __, output):
            features.append(output)

        def get_features_list_basic_layer(_, __, output):
            # The Swin BasicLayer returns a tuple; keep only the feature tensor.
            features.append(output[0])

        def spectrogram_padding(_, __, out):
            # Zero-pad the spectrogram along time to the 1024 frames HTSAT expects.
            return torch.nn.functional.pad(out, (0, 0, 0, 1024 - out.size(2)))

        self.audio_branch.spectrogram_extractor.register_forward_hook(spectrogram_padding)
        self.audio_branch.patch_embed.register_forward_hook(get_features_list)
        for module in self.audio_branch.layers:
            module.register_forward_hook(get_features_list_basic_layer)
        return features

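    # Note on hook semantics: a forward hook that returns a value
    # (spectrogram_padding) replaces the module's output, whereas the
    # list-appending hooks return None and therefore only observe it.
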
if __name__ == '__main__':
    model_config = {"lan_embed_dim": 1024,
                    "depths": [1, 1, 1, 1],
                    "embed_dim": 128,
                    "encoder_embed_dim": 128,
                    "phase": False,
                    "spec_factor": 8,
                    "d_attn": 640,
                    "n_masker_layer": 3,
                    "conv": False}
    CLAP_path = "./music_audioset_epoch_15_esc_90.14.pt"

    model = CLAPSep(model_config, CLAP_path)
    ckpt = torch.load('best_model.ckpt', map_location='cpu')
    model.load_state_dict(ckpt, strict=False)
    model.eval()
    audio, fs = librosa.load("./510_25.221254348754883_mixture.wav", sr=32000)
    # inference_from_data takes precomputed query embeddings, so encode the
    # positive/negative text prompts with the frozen CLAP text branch first.
    embed_pos = model.clap_model.get_text_embedding([''], use_tensor=True)
    embed_neg = model.clap_model.get_text_embedding(['A vehicle engine revving then powering down.'], use_tensor=True)
    pred = model.inference_from_data(torch.tensor(audio).unsqueeze(0), embed_pos, embed_neg)
    import soundfile as sf
    sf.write('./pred.wav', pred.squeeze().numpy(), 32000)
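    # The query embeddings need not come from text: laion_clap also provides
    # get_audio_embedding_from_data(x, use_tensor=...) (48 kHz input), which
    # could replace get_text_embedding above for audio-example queries.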