import json
import torch
import torch.nn as nn
import numpy as np
import soundfile as sf
import torchaudio.functional as TAF
from transformers import PreTrainedModel, AutoProcessor, AutoModel

from .configuration_apex import APEXConfig

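# APEX scores a song from frozen MERT embeddings: a shared MLP trunk feeds
# seven regression heads (two popularity scores and five aesthetic ratings).
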
class SharedBlock(nn.Module):
    """Linear -> BatchNorm1d -> GELU -> Dropout block for the shared trunk."""

    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.block(x)


class BranchBlock(nn.Module):
    """Same Linear -> GELU -> Dropout block, with optional batch normalization."""

    def __init__(self, in_dim, out_dim, dropout, use_bn=True):
        super().__init__()
        layers = [nn.Linear(in_dim, out_dim)]
        if use_bn:
            layers.append(nn.BatchNorm1d(out_dim))
        layers += [nn.GELU(), nn.Dropout(dropout)]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class TaskBranch(nn.Module):
    """Per-task regression head: a stack of BranchBlocks ending in a single
    logit, squashed by a sigmoid and affinely mapped to the task's range."""

    def __init__(self, in_dim, branch_dims, dropout, scale, shift):
        super().__init__()
        layers = []
        prev = in_dim
        for dim in branch_dims:
            layers.append(BranchBlock(prev, dim, dropout=dropout, use_bn=True))
            prev = dim
        layers.append(nn.Linear(prev, 1))
        self.branch = nn.Sequential(*layers)
        self.scale = scale
        self.shift = shift

    def forward(self, x):
        return torch.sigmoid(self.branch(x)) * self.scale + self.shift

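# Since the sigmoid output lies in (0, 1), a head with scale=100, shift=0
# emits values in (0, 100), and one with scale=4, shift=1 emits values in
# (1, 5), matching a 1-5 rating scale.

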
class APEXModel(PreTrainedModel):
    config_class = APEXConfig
    # The frozen MERT backbone is reloaded from its own checkpoint in
    # __init__, so its weights are intentionally absent from this state dict.
    _keys_to_ignore_on_load_missing = [r"mert\..*", r"mert_processor\..*"]
    # No tied weights in this model.
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        return {}

    def _init_weights(self, module):
        # All weights come from the checkpoint; nothing to re-initialize.
        pass

    def __init__(self, config: APEXConfig):
        super().__init__(config)

        # Frozen MERT backbone, loaded on CPU so device placement stays
        # predictable; it is set to eval mode and excluded from training.
        self.mert_processor = AutoProcessor.from_pretrained(
            config.mert_model_name,
            trust_remote_code=True
        )
        with torch.device("cpu"):
            self.mert = AutoModel.from_pretrained(
                config.mert_model_name,
                trust_remote_code=True,
                device_map=None,
                low_cpu_mem_usage=False
            )
        self.mert.eval()
        for param in self.mert.parameters():
            param.requires_grad = False

        self.target_sr = self.mert_processor.sampling_rate

        # Learned 1x1 convolution that mixes the selected MERT hidden layers
        # into a single embedding; seeded for deterministic initialization.
        torch.manual_seed(config.seed)
        self.aggregator = nn.Conv1d(
            in_channels=len(config.layer_indices),
            out_channels=1,
            kernel_size=1
        )

        # Shared MLP trunk.
        shared_layers = []
        prev_dim = config.input_dim
        for dim in config.shared_dims:
            shared_layers.append(SharedBlock(prev_dim, dim, dropout=config.dropout_shared))
            prev_dim = dim
        self.shared = nn.Sequential(*shared_layers)

        out_dim = config.shared_dims[-1]

        # Task heads: popularity scores in (0, 100), aesthetic ratings in (1, 5).
        self.branch_score_streams = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=100, shift=0)
        self.branch_score_likes = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=100, shift=0)
        self.branch_coherence = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_musicality = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_memorability = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_clarity = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_naturalness = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)

    def forward(self, embedding):
        # embedding: (batch, input_dim); every head returns shape (batch,).
        shared = self.shared(embedding)
        return {
            "score_streams": self.branch_score_streams(shared).squeeze(1),
            "score_likes": self.branch_score_likes(shared).squeeze(1),
            "coherence": self.branch_coherence(shared).squeeze(1),
            "musicality": self.branch_musicality(shared).squeeze(1),
            "memorability": self.branch_memorability(shared).squeeze(1),
            "clarity": self.branch_clarity(shared).squeeze(1),
            "naturalness": self.branch_naturalness(shared).squeeze(1),
        }

    def _load_audio(self, audio_path):
        waveform, sr = sf.read(audio_path, dtype="float32")
        waveform = torch.from_numpy(waveform)

        # Down-mix multi-channel audio to mono.
        if waveform.ndim > 1:
            waveform = waveform.mean(dim=1)

        waveform = waveform.to(self.device)

        # Resample to the rate the MERT processor expects.
        if sr != self.target_sr:
            waveform = TAF.resample(waveform, sr, self.target_sr)

        return waveform

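    # Embedding extraction works segment by segment: each fixed-length chunk
    # goes through frozen MERT, its selected hidden layers are time-pooled and
    # mixed by the 1x1 aggregator, and the per-segment vectors are averaged
    # into a single song embedding.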
    def _extract_embedding(self, waveform):
        # Cast to int so segment_len is a valid range step and slice bound.
        segment_len = int(self.config.segment_sec * self.target_sr)
        segment_embeddings = []

        for start in range(0, waveform.shape[0], segment_len):
            segment = waveform[start:start + segment_len]
            if segment.numel() == 0:
                break

            # Zero-pad the final segment to the full segment length.
            if segment.shape[0] < segment_len:
                pad_len = segment_len - segment.shape[0]
                segment = torch.nn.functional.pad(segment, (0, pad_len))

            inputs = self.mert_processor(
                segment.cpu().numpy(),
                sampling_rate=self.target_sr,
                return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.mert(**inputs, output_hidden_states=True)

            # Time-pool each selected hidden layer: (num_layers, 1, hidden).
            all_hidden = torch.stack([
                outputs.hidden_states[i].mean(dim=1)
                for i in self.config.layer_indices
            ])
            all_hidden = all_hidden.squeeze(1)

            # Mix layers with the 1x1 conv: (1, num_layers, hidden) -> (hidden,).
            pooled = self.aggregator(
                all_hidden.unsqueeze(0)
            ).squeeze()

            segment_embeddings.append(pooled)

            # Free per-segment tensors promptly to keep memory flat.
            del segment, inputs, outputs, all_hidden, pooled

        song_embedding = torch.stack(segment_embeddings).mean(dim=0)
        return song_embedding

    @torch.no_grad()
    def predict(self, audio_path, save_json=None):
        self.eval()
        print(f"\nProcessing: {audio_path}")

        waveform = self._load_audio(audio_path)
        duration = waveform.shape[0] / self.target_sr
        n_segs = int(np.ceil(duration / self.config.segment_sec))
        print(f"Duration: {duration:.1f}s | Segments: {n_segs}")

        print("Extracting MERT embeddings...")
        embedding = self._extract_embedding(waveform)

        print("Running APEX model...")
        preds = self.forward(embedding.unsqueeze(0))

        results = {
            task: float(preds[task].squeeze().cpu())
            for task in preds
        }

        print(f"\n{'─' * 50}")
        print(" APEX Predictions")
        print(f"{'─' * 50}")
        print("\n Popularity:")
        print(f" {'-' * 40}")
        print(f" {'Streams Score':<20} {results['score_streams']:>8.2f} / 100")
        print(f" {'Likes Score':<20} {results['score_likes']:>8.2f} / 100")
        print("\n Aesthetic Quality:")
        print(f" {'-' * 40}")
        for dim in ["coherence", "musicality", "memorability", "clarity", "naturalness"]:
            print(f" {dim.capitalize():<20} {results[dim]:>8.2f} / 5.00")

        if save_json:
            with open(save_json, "w") as f:
                json.dump({
                    "audio_path": audio_path,
                    "predictions": results
                }, f, indent=2)
            print(f"Results saved to {save_json}")

        return results
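

# Minimal usage sketch (the checkpoint and audio paths are placeholders, and
# loading via AutoModel assumes this repo registers the class for remote code;
# run from outside this module, since it uses a relative import):
#
#     from transformers import AutoModel
#     model = AutoModel.from_pretrained("path/to/apex-checkpoint", trust_remote_code=True)
#     scores = model.predict("path/to/song.wav", save_json="song_scores.json")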