| import pdb |
| from typing import Tuple |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| import argparse |
| import importlib |
| import json |
| import math |
| import multiprocessing as mp |
| import os |
| import time |
| from argparse import Namespace |
| from pathlib import Path |
|
|
| |
| import scipy |
| import numpy as np |
|
|
# Compatibility shim: restore `scipy.inf` (as NumPy's inf) before the
# third-party imports below. Presumably one of them still reads `scipy.inf`,
# which recent SciPy releases no longer provide — TODO confirm which one.
scipy.inf = np.inf
|
|
| import librosa |
| import torch |
| from ema_pytorch import EMA |
| from loguru import logger |
| from muq import MuQ |
| from musicfm.model.musicfm_25hz import MusicFM25Hz |
| from omegaconf import OmegaConf |
| from tqdm import tqdm |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from configuration_songformer import SongFormerConfig |
| from model_config import ModelConfig |
|
|
| from model import Model |
| from omegaconf import OmegaConf |
|
|
| |
# Directory holding the MusicFM checkpoints. The default is the original
# author's machine-specific path; set the MUSICFM_HOME_PATH environment
# variable to point somewhere else without editing this file.
MUSICFM_HOME_PATH = os.environ.get(
    "MUSICFM_HOME_PATH",
    "/home/node59_tmpdata3/cbhao/SongFormer_kaiyuan_test/github_test/SongFormer/src/SongFormer/ckpts/MusicFM",
)

# Presumably the encoder frame rate in frames/second before the SongFormer
# head downsamples (cf. `MusicFM25Hz`) — not used in this part of the file.
BEFORE_DOWNSAMPLING_FRAME_RATES = 25
# Frames per second of the logits the SongFormer head returns; `forward` uses
# it to convert seconds into logit-frame indices. It looks like 25 / 3
# truncated to three decimals (TODO confirm); changing it shifts frame indices.
AFTER_DOWNSAMPLING_FRAME_RATES = 8.333

# Dataset whose allowed-label mask is applied at inference time, and the
# dataset id(s) passed to the SongFormer head alongside it.
DATASET_LABEL = "SongForm-HX-8Class"
DATASET_IDS = [5]

# Window length in seconds; `forward` sizes its frame buffer by rounding the
# recording duration up to the next multiple of this.
TIME_DUR = 420
# Sampling rate (Hz) audio files are loaded at; also used to convert seconds
# into sample indices.
INPUT_SAMPLING_RATE = 24000
|
|
| from dataset.label2id import DATASET_ID_ALLOWED_LABEL_IDS, DATASET_LABEL_TO_DATASET_ID |
| from postprocessing.functional import postprocess_functional_structure |
|
|
|
|
def rule_post_processing(msa_list):
    """Clean up the head and tail of a segment list with fixed heuristics.

    ``msa_list`` holds ``(time, label)`` entries. Each entry's time is where
    that segment starts, and the next entry's time is where it ends. The last
    entry therefore only marks the end of the final segment.

    Lists with at most two entries are returned as-is (the same object).
    Otherwise a new list is built and ``msa_list`` itself is left untouched.

    The rules run in order, and each repeats while more than two entries
    remain:

    1. A leading segment shorter than 1.0 is absorbed into its successor.
       The merged entry keeps the first start time and the successor's label.
    2. A trailing segment shorter than 1.0 is dropped, so the segment before
       it runs to the final boundary.
    3. If the first two segments share a label and the second starts at or
       before 10.0, the second one's start entry is removed.
    4. If the last two segments share a label and the last one lasts at most
       10.0, the last one's start entry is removed.
    """
    if len(msa_list) <= 2:
        return msa_list

    segs = msa_list.copy()

    # Rule 1: swallow very short leading segments into the next one.
    while len(segs) > 2 and segs[1][0] - segs[0][0] < 1.0:
        segs = [(segs[0][0], segs[1][1]), *segs[2:]]

    # Rule 2: drop very short trailing segments.
    while len(segs) > 2 and segs[-1][0] - segs[-2][0] < 1.0:
        del segs[-2]

    # Rule 3: merge an early second segment that repeats the first label.
    while len(segs) > 2 and segs[0][1] == segs[1][1] and segs[1][0] <= 10.0:
        segs = [(segs[0][0], segs[0][1]), *segs[2:]]

    # Rule 4: merge a short final segment that repeats the previous label.
    while (
        len(segs) > 2
        and segs[-2][1] == segs[-3][1]
        and segs[-1][0] - segs[-2][0] <= 10.0
    ):
        del segs[-2]

    return segs
|
|
|
|
class SongFormerModel(PreTrainedModel):
    """Hugging Face wrapper that runs SongFormer music structure analysis.

    The model owns two audio encoders (``self.muq`` and ``self.musicfm``).
    Their layer-10 hidden states are fused and fed to the SongFormer head
    (``self.songformer``). ``forward`` maps one recording to a list of
    labelled segments.
    """

    config_class = SongFormerConfig

    def __init__(self, config: SongFormerConfig):
        """Build the encoders, the SongFormer head and the per-dataset label masks.

        Args:
            config: model configuration; ``num_classes`` sizes the label masks.

        Raises:
            KeyError: if the ``SONGFORMER_LOCAL_DIR`` environment variable is
                unset. That directory must contain ``muq_config2.json`` and
                ``msd_stats.json``.
        """
        super().__init__(config)
        root_dir = os.environ["SONGFORMER_LOCAL_DIR"]
        with open(os.path.join(root_dir, "muq_config2.json"), "r") as f:
            muq_config_file = OmegaConf.load(f)

        self.muq = MuQ(muq_config_file)
        self.musicfm = MusicFM25Hz(
            is_flash=False,
            stat_path=os.path.join(root_dir, "msd_stats.json"),
        )
        self.songformer = Model(ModelConfig())

        # dataset id -> bool mask over all classes. True marks label ids the
        # dataset does NOT allow; its allowed ids are set to False.
        num_classes = config.num_classes
        dataset_id2label_mask = {}
        for key, allowed_ids in DATASET_ID_ALLOWED_LABEL_IDS.items():
            mask = np.ones(num_classes, dtype=bool)
            mask[allowed_ids] = False
            dataset_id2label_mask[key] = mask

        self.num_classes = num_classes
        self.dataset_id2label_mask = dataset_id2label_mask
        self.config = config

    def forward(self, input):
        """Run structure analysis on one recording.

        The audio is processed in windows of ``config.win_size`` seconds that
        start every ``config.hop_size`` seconds. The first hop's worth of each
        window's logits is written into a recording-long buffer. Frames that
        two windows both wrote are averaged. The result is then
        post-processed into segments.

        Args:
            input: a waveform as a ``torch.Tensor`` / ``np.ndarray``, or a path
                to an audio file. Array input is not resampled; presumably it
                must already be 1-D at ``INPUT_SAMPLING_RATE`` Hz (TODO confirm
                with callers). A path is loaded via ``librosa`` at that rate.

        Returns:
            A list of ``{"label", "start", "end"}`` dicts, one per segment.
            Times are as produced by ``postprocess_functional_structure``.

        Raises:
            ValueError: if ``input`` is neither an array/tensor nor an existing
                path, or if the fused encoder outputs differ in length by more
                than 4 frames.
        """
        with torch.no_grad():
            device = next(self.parameters()).device
            audio = self._load_waveform(input, device)
            num_samples = audio.shape[-1]

            win_size = self.config.win_size
            hop_size = self.config.hop_size
            num_classes = self.config.num_classes
            # Buffer length: duration rounded up to the next TIME_DUR multiple.
            total_len = (
                (num_samples // INPUT_SAMPLING_RATE) // TIME_DUR
            ) * TIME_DUR + TIME_DUR
            total_frames = math.ceil(total_len * AFTER_DOWNSAMPLING_FRAME_RATES)

            # Running sums of window logits plus how many windows wrote each
            # frame, so shared frames can be averaged at the end.
            function_sum = np.zeros([total_frames, num_classes])
            boundary_sum = np.zeros([total_frames])
            frame_counts = np.zeros([total_frames])

            # Loop-invariant inputs for the SongFormer head.
            dataset_ids = torch.Tensor(DATASET_IDS).to(device, dtype=torch.long)
            label_id_masks = (
                torch.Tensor(
                    self.dataset_id2label_mask[
                        DATASET_LABEL_TO_DATASET_ID[DATASET_LABEL]
                    ]
                )
                .to(device, dtype=bool)
                .unsqueeze(0)
                .unsqueeze(0)
            )
            hop_frames = math.ceil(hop_size * AFTER_DOWNSAMPLING_FRAME_RATES)

            lens = 0  # one past the last frame any window wrote
            i = 0  # window start, in seconds
            while i * INPUT_SAMPLING_RATE < num_samples:
                start_idx = i * INPUT_SAMPLING_RATE
                end_idx = min((i + win_size) * INPUT_SAMPLING_RATE, num_samples)
                if end_idx - start_idx <= 1024:
                    # Too little audio left to encode, and every later window
                    # would be shorter still: stop instead of skipping, since
                    # skipping without advancing `i` never terminates.
                    break

                embd = self._encode_window(
                    audio, audio[start_idx:end_idx], i, hop_size
                )
                _, chunk_logits = self.songformer.infer(
                    input_embeddings=embd,
                    dataset_ids=dataset_ids,
                    label_id_masks=label_id_masks,
                    with_logits=True,
                )
                chunk_function = (
                    chunk_logits["function_logits"][0].detach().cpu().numpy()
                )
                chunk_boundary = (
                    chunk_logits["boundary_logits"][0].detach().cpu().numpy()
                )

                # Keep at most one hop of this window's frames. Truncating the
                # start frame with int() can make consecutive windows share a
                # frame, so everything is accumulated and averaged afterwards.
                start_frame = int(i * AFTER_DOWNSAMPLING_FRAME_RATES)
                end_frame = min(
                    start_frame + min(hop_frames, chunk_boundary.shape[0]),
                    total_frames,
                )
                n_frames = end_frame - start_frame
                function_sum[start_frame:end_frame] += chunk_function[:n_frames]
                boundary_sum[start_frame:end_frame] += chunk_boundary[:n_frames]
                frame_counts[start_frame:end_frame] += 1
                lens = max(lens, end_frame)

                i += hop_size

            # Average shared frames; a frame no window wrote stays 0 (not 0/0).
            counts = np.maximum(frame_counts[:lens], 1)
            logits = {
                "function_logits": torch.from_numpy(
                    function_sum[:lens] / counts[:, None]
                ).unsqueeze(0),
                "boundary_logits": torch.from_numpy(
                    boundary_sum[:lens] / counts
                ).unsqueeze(0),
            }

            msa_infer_output = postprocess_functional_structure(logits, self.config)

            assert msa_infer_output[-1][-1] == "end"
            if not self.config.no_rule_post_processing:
                msa_infer_output = rule_post_processing(msa_infer_output)
            # Each entry's time starts a segment; the next entry's time ends it.
            return [
                {"label": cur[1], "start": cur[0], "end": nxt[0]}
                for cur, nxt in zip(msa_infer_output, msa_infer_output[1:])
            ]

    @staticmethod
    def _load_waveform(input, device):
        """Return ``input`` as a waveform tensor on ``device``.

        Tensors and arrays are copied without resampling. Anything else is
        treated as a path and loaded with ``librosa`` at
        ``INPUT_SAMPLING_RATE`` Hz.

        Raises:
            ValueError: if ``input`` is neither an array/tensor nor an
                existing path.
        """
        if isinstance(input, torch.Tensor):
            # Same copy as torch.tensor(input), minus its copy-construct warning.
            return input.clone().detach().to(device)
        if isinstance(input, np.ndarray):
            return torch.tensor(input).to(device)
        if os.path.exists(input):
            wav, _ = librosa.load(input, sr=INPUT_SAMPLING_RATE)
            return torch.tensor(wav).to(device)
        raise ValueError("input should be a tensor/numpy or a valid file path")

    def _encode_window(self, audio, audio_seg, start_sec, hop_size):
        """Fuse MuQ and MusicFM layer-10 features for one analysis window.

        Args:
            audio: the full waveform tensor.
            audio_seg: ``audio`` sliced to the current window.
            start_sec: window start, in seconds.
            hop_size: seconds, from ``start_sec``, re-encoded in 30 s pieces.

        Returns:
            The four feature streams concatenated on the last dimension, in
            this order: MusicFM 30 s pieces, MuQ 30 s pieces, MusicFM whole
            window, MuQ whole window. All are truncated to the shortest length
            along dim 1.

        Raises:
            ValueError: if the streams' lengths differ by more than 4 frames.
        """
        sr = INPUT_SAMPLING_RATE
        num_samples = audio.shape[-1]

        # Long-context features: the whole window in one pass.
        muq_output = self.muq(audio_seg.unsqueeze(0), output_hidden_states=True)
        muq_embd_420s = muq_output["hidden_states"][10]
        del muq_output
        torch.cuda.empty_cache()

        _, musicfm_hidden_states = self.musicfm.get_predictions(
            audio_seg.unsqueeze(0)
        )
        musicfm_embd_420s = musicfm_hidden_states[10]
        del musicfm_hidden_states
        torch.cuda.empty_cache()

        # Short-context features: the first hop_size seconds in 30 s pieces.
        hop_end = min(num_samples, (start_sec + hop_size) * sr)
        muq_pieces = []
        musicfm_pieces = []
        for piece_sec in range(start_sec, start_sec + hop_size, 30):
            piece_start = piece_sec * sr
            piece_end = min((piece_sec + 30) * sr, hop_end)
            if piece_start >= num_samples:
                break
            if piece_end - piece_start <= 1024:
                continue
            piece = audio[piece_start:piece_end].unsqueeze(0)
            muq_pieces.append(
                self.muq(piece, output_hidden_states=True)["hidden_states"][10]
            )
            torch.cuda.empty_cache()
            musicfm_pieces.append(self.musicfm.get_predictions(piece)[1][10])
            torch.cuda.empty_cache()

        all_embds = [
            torch.concatenate(musicfm_pieces, dim=1),
            torch.concatenate(muq_pieces, dim=1),
            musicfm_embd_420s,
            muq_embd_420s,
        ]

        embd_lens = [x.shape[1] for x in all_embds]
        max_embd_len = max(embd_lens)
        min_embd_len = min(embd_lens)
        if abs(max_embd_len - min_embd_len) > 4:
            raise ValueError(
                f"Embedding shapes differ too much: {max_embd_len} vs {min_embd_len}"
            )
        return torch.concatenate(
            [x[:, :min_embd_len, :] for x in all_embds], axis=-1
        )

    @staticmethod
    def _fix_state_dict_key_on_load(key: str) -> Tuple[str, bool]:
        """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight.

        Overrides the ``PreTrainedModel`` hook. Keys starting with ``muq.`` are
        always returned unchanged. All other keys get two kinds of rename:
        LayerNorm ``beta``/``gamma`` -> ``bias``/``weight``, and weight-norm
        names translated to whichever form the installed torch uses.

        Returns:
            ``(new_key, renamed)``; ``renamed`` is True when a rename applied.
        """
        # NOTE(review): presumably the MuQ submodule's checkpoint keys already
        # match its modules, so the generic renames must not touch them — confirm.
        if key.startswith("muq."):
            return key, False

        # Legacy LayerNorm parameter names.
        if key.endswith("LayerNorm.beta"):
            return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
        if key.endswith("LayerNorm.gamma"):
            return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

        # Weight-norm parameters: torch versions that provide
        # nn.utils.parametrizations.weight_norm name them
        # parametrizations.weight.original0/1; older ones use weight_g/weight_v.
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            if key.endswith("weight_g"):
                return key.replace(
                    "weight_g", "parametrizations.weight.original0"
                ), True
            if key.endswith("weight_v"):
                return key.replace(
                    "weight_v", "parametrizations.weight.original1"
                ), True
        else:
            if key.endswith("parametrizations.weight.original0"):
                return key.replace(
                    "parametrizations.weight.original0", "weight_g"
                ), True
            if key.endswith("parametrizations.weight.original1"):
                return key.replace(
                    "parametrizations.weight.original1", "weight_v"
                ), True

        return key, False
|
|