import json

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torchaudio.functional as TAF
from transformers import AutoModel, AutoProcessor, PreTrainedModel

from .configuration_apex import APEXConfig


# BUILDING BLOCKS

class SharedBlock(nn.Module):
    """Linear → BatchNorm → GELU → Dropout block for the shared trunk."""

    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.block(x)


class BranchBlock(nn.Module):
    """Same structure as SharedBlock, with optional BatchNorm."""

    def __init__(self, in_dim, out_dim, dropout, use_bn=True):
        super().__init__()
        layers = [nn.Linear(in_dim, out_dim)]
        if use_bn:
            layers.append(nn.BatchNorm1d(out_dim))
        layers += [nn.GELU(), nn.Dropout(dropout)]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class TaskBranch(nn.Module):
    """Per-task head that maps shared features to a bounded scalar."""

    def __init__(self, in_dim, branch_dims, dropout, scale, shift):
        super().__init__()
        layers = []
        prev = in_dim
        for dim in branch_dims:
            layers.append(BranchBlock(prev, dim, dropout=dropout, use_bn=True))
            prev = dim
        layers.append(nn.Linear(prev, 1))
        self.branch = nn.Sequential(*layers)
        self.scale = scale
        self.shift = shift

    def forward(self, x):
        # Sigmoid bounds the raw output to (0, 1); scale and shift map it
        # onto the task's target range, e.g. (0, 100) or (1, 5).
        return torch.sigmoid(self.branch(x)) * self.scale + self.shift
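
# Illustrative sketch (not part of the original module's API): a tiny
# self-check of TaskBranch in isolation. The dimensions below are assumed
# to match those used by APEXModel (256 → 128 → 64 → 1); with scale=4 and
# shift=1 the sigmoid output is constrained to the open interval (1, 5).
def _demo_task_branch():
    branch = TaskBranch(in_dim=256, branch_dims=[128, 64],
                        dropout=0.1, scale=4, shift=1)
    branch.eval()  # BatchNorm1d requires eval mode for a batch of size 1
    x = torch.randn(1, 256)
    out = branch(x)  # shape (1, 1)
    assert 1.0 < out.item() < 5.0
    return out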

# APEX MODEL

class APEXModel(PreTrainedModel):
    config_class = APEXConfig
    _keys_to_ignore_on_load_missing = [r"mert\..*", r"mert_processor\..*"]
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        return {}

    def _init_weights(self, module):
        # Weights are either loaded from a checkpoint or created in
        # __init__ under a fixed seed, so no extra initialization is done.
        pass

    def __init__(self, config: APEXConfig):
        super().__init__(config)

        # Load MERT processor and encoder fresh from HuggingFace
        self.mert_processor = AutoProcessor.from_pretrained(
            config.mert_model_name,
            trust_remote_code=True,
        )
        with torch.device("cpu"):
            self.mert = AutoModel.from_pretrained(
                config.mert_model_name,
                trust_remote_code=True,
                device_map=None,
                low_cpu_mem_usage=False,
            )
        self.mert.eval()
        for param in self.mert.parameters():
            param.requires_grad = False
        self.target_sr = self.mert_processor.sampling_rate

        # Conv1d aggregator with fixed seed
        torch.manual_seed(config.seed)
        self.aggregator = nn.Conv1d(
            in_channels=len(config.layer_indices),
            out_channels=1,
            kernel_size=1,
        )

        # Shared layers: 768 → 512 → 256
        shared_layers = []
        prev_dim = config.input_dim
        for dim in config.shared_dims:
            shared_layers.append(SharedBlock(prev_dim, dim, dropout=config.dropout_shared))
            prev_dim = dim
        self.shared = nn.Sequential(*shared_layers)

        out_dim = config.shared_dims[-1]  # 256

        # Task branches: 256 → 128 → 64 → 1
        self.branch_score_streams = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=100, shift=0)
        self.branch_score_likes   = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=100, shift=0)
        self.branch_coherence     = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_musicality    = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_memorability  = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_clarity       = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_naturalness   = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)

    def forward(self, embedding):
        shared = self.shared(embedding)
        return {
            "score_streams": self.branch_score_streams(shared).squeeze(1),
            "score_likes":   self.branch_score_likes(shared).squeeze(1),
            "coherence":     self.branch_coherence(shared).squeeze(1),
            "musicality":    self.branch_musicality(shared).squeeze(1),
            "memorability":  self.branch_memorability(shared).squeeze(1),
            "clarity":       self.branch_clarity(shared).squeeze(1),
            "naturalness":   self.branch_naturalness(shared).squeeze(1),
        }

    def _load_audio(self, audio_path):
        waveform, sr = sf.read(audio_path, dtype="float32")
        waveform = torch.from_numpy(waveform)
        if waveform.ndim > 1:
            # Downmix multi-channel audio to mono
            waveform = waveform.mean(dim=1)
        waveform = waveform.to(self.device)
        if sr != self.target_sr:
            waveform = TAF.resample(waveform, sr, self.target_sr)
        return waveform

    def _extract_embedding(self, waveform):
        # Cast to int in case segment_sec is a float
        segment_len = int(self.config.segment_sec * self.target_sr)
        segment_embeddings = []

        for start in range(0, waveform.shape[0], segment_len):
            segment = waveform[start:start + segment_len]
            if segment.numel() == 0:
                break
            if segment.shape[0] < segment_len:
                # Zero-pad the final partial segment to a full window
                pad_len = segment_len - segment.shape[0]
                segment = torch.nn.functional.pad(segment, (0, pad_len))

            inputs = self.mert_processor(
                segment.cpu().numpy(),
                sampling_rate=self.target_sr,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.mert(**inputs, output_hidden_states=True)

            # Mean-pool each selected hidden layer over time, then mix the
            # layers with the 1x1 Conv1d aggregator.
            all_hidden = torch.stack([
                outputs.hidden_states[i].mean(dim=1)
                for i in self.config.layer_indices
            ])
            all_hidden = all_hidden.squeeze(1)  # (num_layers, hidden)
            pooled = self.aggregator(all_hidden.unsqueeze(0)).squeeze()
            segment_embeddings.append(pooled)
            del segment, inputs, outputs, all_hidden, pooled

        # Average segment embeddings into one song-level embedding
        song_embedding = torch.stack(segment_embeddings).mean(dim=0)
        return song_embedding

    @torch.no_grad()
    def predict(self, audio_path, save_json=None):
        self.eval()
        print(f"\nProcessing: {audio_path}")

        waveform = self._load_audio(audio_path)
        duration = waveform.shape[0] / self.target_sr
        n_segs = int(np.ceil(duration / self.config.segment_sec))
        print(f"Duration: {duration:.1f}s | Segments: {n_segs}")

        print("Extracting MERT embeddings...")
        embedding = self._extract_embedding(waveform)

        print("Running APEX model...")
        preds = self.forward(embedding.unsqueeze(0))
        results = {task: float(preds[task].squeeze().cpu()) for task in preds}

        print(f"\n{'─' * 50}")
        print("  APEX Predictions")
        print(f"{'─' * 50}")
        print("\n  Popularity:")
        print(f"  {'-' * 40}")
        print(f"  {'Streams Score':<20} {results['score_streams']:>8.2f} / 100")
        print(f"  {'Likes Score':<20} {results['score_likes']:>8.2f} / 100")
        print("\n  Aesthetic Quality:")
        print(f"  {'-' * 40}")
        for dim in ["coherence", "musicality", "memorability", "clarity", "naturalness"]:
            print(f"  {dim.capitalize():<20} {results[dim]:>8.2f} / 5.00")

        if save_json:
            with open(save_json, "w") as f:
                json.dump({
                    "audio_path": audio_path,
                    "predictions": results,
                }, f, indent=2)
            print(f"Results saved to {save_json}")

        return results
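
if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the packaged API).
    # Run it as a module so the relative import above resolves, e.g.
    #   python -m apex.modeling_apex ./apex_checkpoint ./song.wav
    # "apex" is a placeholder package name; the checkpoint directory is
    # expected to hold an APEXConfig plus trained APEX weights (the frozen
    # MERT encoder is re-downloaded, so its keys may be absent from the
    # checkpoint, as allowed by _keys_to_ignore_on_load_missing).
    import sys

    checkpoint_dir = sys.argv[1] if len(sys.argv) > 1 else "./apex_checkpoint"
    audio_file = sys.argv[2] if len(sys.argv) > 2 else "./example.wav"

    model = APEXModel.from_pretrained(checkpoint_dir)
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.predict(audio_file, save_json="apex_predictions.json")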