# apex/modeling_apex.py
import json
import torch
import torch.nn as nn
import numpy as np
import soundfile as sf
import torchaudio.functional as TAF
from transformers import PreTrainedModel, AutoProcessor, AutoModel
from .configuration_apex import APEXConfig
# BUILDING BLOCKS
class SharedBlock(nn.Module):
    def __init__(self, in_dim, out_dim, dropout):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.BatchNorm1d(out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.block(x)

class BranchBlock(nn.Module):
    def __init__(self, in_dim, out_dim, dropout, use_bn=True):
        super().__init__()
        layers = [nn.Linear(in_dim, out_dim)]
        if use_bn:
            layers.append(nn.BatchNorm1d(out_dim))
        layers += [nn.GELU(), nn.Dropout(dropout)]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class TaskBranch(nn.Module):
    def __init__(self, in_dim, branch_dims, dropout, scale, shift):
        super().__init__()
        layers = []
        prev = in_dim
        for dim in branch_dims:
            layers.append(BranchBlock(prev, dim, dropout=dropout, use_bn=True))
            prev = dim
        layers.append(nn.Linear(prev, 1))
        self.branch = nn.Sequential(*layers)
        self.scale = scale
        self.shift = shift

    def forward(self, x):
        # Sigmoid bounds the raw output to (0, 1); scale and shift then map
        # it onto the task's target range, (shift, shift + scale).
        return torch.sigmoid(self.branch(x)) * self.scale + self.shift
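
# Output-range sanity check (a sketch with hypothetical dims and dropout;
# the dims mirror the "256 -> 128 -> 64 -> 1" branch layout used below):
#
#   head = TaskBranch(in_dim=256, branch_dims=[128, 64], dropout=0.1,
#                     scale=4, shift=1)
#   head.eval()
#   out = head(torch.zeros(2, 256))   # every value lies in (1.0, 5.0)
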
# APEX MODEL
class APEXModel(PreTrainedModel):
    config_class = APEXConfig
    # The frozen MERT encoder is reloaded from the Hub in __init__, so its
    # weights are intentionally absent from APEX checkpoints.
    _keys_to_ignore_on_load_missing = [r"mert\..*", r"mert_processor\..*"]
    # No weights are tied in this model.
    _tied_weights_keys = []

    @property
    def all_tied_weights_keys(self):
        return {}

    def _init_weights(self, module):
        # No-op: all trainable weights are loaded from the checkpoint.
        pass
    def __init__(self, config: APEXConfig):
        super().__init__(config)
        # Load MERT processor and encoder fresh from HuggingFace
        self.mert_processor = AutoProcessor.from_pretrained(
            config.mert_model_name,
            trust_remote_code=True,
        )
        with torch.device("cpu"):
            self.mert = AutoModel.from_pretrained(
                config.mert_model_name,
                trust_remote_code=True,
                device_map=None,
                low_cpu_mem_usage=False,
            )
        self.mert.eval()
        for param in self.mert.parameters():
            param.requires_grad = False
        self.target_sr = self.mert_processor.sampling_rate

        # Conv1d aggregator (one weight per MERT layer) with fixed seed
        torch.manual_seed(config.seed)
        self.aggregator = nn.Conv1d(
            in_channels=len(config.layer_indices),
            out_channels=1,
            kernel_size=1,
        )

        # Shared layers: 768 → 512 → 256
        shared_layers = []
        prev_dim = config.input_dim
        for dim in config.shared_dims:
            shared_layers.append(SharedBlock(prev_dim, dim, dropout=config.dropout_shared))
            prev_dim = dim
        self.shared = nn.Sequential(*shared_layers)
        out_dim = config.shared_dims[-1]  # 256

        # Task branches: 256 → 128 → 64 → 1
        self.branch_score_streams = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=100, shift=0)
        self.branch_score_likes = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=100, shift=0)
        self.branch_coherence = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_musicality = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_memorability = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_clarity = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
        self.branch_naturalness = TaskBranch(out_dim, config.branch_dims, config.dropout_branch, scale=4, shift=1)
    def forward(self, embedding):
        shared = self.shared(embedding)
        return {
            "score_streams": self.branch_score_streams(shared).squeeze(1),
            "score_likes": self.branch_score_likes(shared).squeeze(1),
            "coherence": self.branch_coherence(shared).squeeze(1),
            "musicality": self.branch_musicality(shared).squeeze(1),
            "memorability": self.branch_memorability(shared).squeeze(1),
            "clarity": self.branch_clarity(shared).squeeze(1),
            "naturalness": self.branch_naturalness(shared).squeeze(1),
        }
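
    # Shape sketch (assuming config.input_dim == 768, per the comments above):
    #   preds = model(embedding)   # embedding: (batch, 768)
    #   preds["coherence"]         # (batch,) tensor with values in (1, 5)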
    def _load_audio(self, audio_path):
        waveform, sr = sf.read(audio_path, dtype="float32")
        waveform = torch.from_numpy(waveform)
        # Downmix multi-channel audio to mono (this also drops a trailing
        # singleton channel dimension, which the old `shape[1] > 1` check missed)
        if waveform.ndim > 1:
            waveform = waveform.mean(dim=1)
        waveform = waveform.to(self.device)
        if sr != self.target_sr:
            waveform = TAF.resample(waveform, sr, self.target_sr)
        return waveform
    def _extract_embedding(self, waveform):
        # Cast to int so range() accepts the step even if segment_sec is a float
        segment_len = int(self.config.segment_sec * self.target_sr)
        segment_embeddings = []
        for start in range(0, waveform.shape[0], segment_len):
            segment = waveform[start:start + segment_len]
            if segment.numel() == 0:
                break
            if segment.shape[0] < segment_len:
                pad_len = segment_len - segment.shape[0]
                segment = torch.nn.functional.pad(segment, (0, pad_len))
            inputs = self.mert_processor(
                segment.cpu().numpy(),
                sampling_rate=self.target_sr,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = self.mert(**inputs, output_hidden_states=True)
            # Mean-pool each selected hidden layer over time:
            # (n_layers, 1, hidden) -> (n_layers, hidden)
            all_hidden = torch.stack([
                outputs.hidden_states[i].mean(dim=1)
                for i in self.config.layer_indices
            ])
            all_hidden = all_hidden.squeeze(1)
            # Learned weighted sum across layers: (1, n_layers, hidden) -> (hidden,)
            pooled = self.aggregator(
                all_hidden.unsqueeze(0)
            ).squeeze()
            segment_embeddings.append(pooled)
            del segment, inputs, outputs, all_hidden, pooled
        # Average the per-segment embeddings into one song-level embedding
        song_embedding = torch.stack(segment_embeddings).mean(dim=0)
        return song_embedding
    @torch.no_grad()
    def predict(self, audio_path, save_json=None):
        self.eval()
        print(f"\nProcessing: {audio_path}")
        waveform = self._load_audio(audio_path)
        duration = waveform.shape[0] / self.target_sr
        n_segs = int(np.ceil(duration / self.config.segment_sec))
        print(f"Duration: {duration:.1f}s | Segments: {n_segs}")
        print("Extracting MERT embeddings...")
        embedding = self._extract_embedding(waveform)
        print("Running APEX model...")
        preds = self.forward(embedding.unsqueeze(0))
        results = {
            task: float(preds[task].squeeze().cpu())
            for task in preds
        }
        print(f"\n{'─' * 50}")
        print(" APEX Predictions")
        print(f"{'─' * 50}")
        print("\n Popularity:")
        print(f" {'-' * 40}")
        print(f" {'Streams Score':<20} {results['score_streams']:>8.2f} / 100")
        print(f" {'Likes Score':<20} {results['score_likes']:>8.2f} / 100")
        print("\n Aesthetic Quality:")
        print(f" {'-' * 40}")
        for dim in ["coherence", "musicality", "memorability", "clarity", "naturalness"]:
            print(f" {dim.capitalize():<20} {results[dim]:>8.2f} / 5.00")
        if save_json:
            with open(save_json, "w") as f:
                json.dump({
                    "audio_path": audio_path,
                    "predictions": results,
                }, f, indent=2)
            print(f"Results saved to {save_json}")
        return results
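

if __name__ == "__main__":
    # Minimal usage sketch, assuming this file ships in a Hub repo alongside
    # configuration_apex.py and trained weights. The repo id and audio path
    # below are placeholders, not tested values.
    model = AutoModel.from_pretrained("Jaavid25/apex", trust_remote_code=True)
    model.predict("song.wav", save_json="apex_scores.json")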