Jaavid25 committed on
Commit
0919d3c
·
verified ·
1 Parent(s): 40bdd7b

Upload modeling_apex.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_apex.py +219 -0
modeling_apex.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import torchaudio.functional as TAF
8
+ from transformers import PreTrainedModel, AutoProcessor, AutoModel
9
+ from .configuration_apex import APEXConfig
10
+
11
+
12
+ # Building blocks
13
class SharedBlock(nn.Module):
    """One stage of the shared trunk: Linear -> BatchNorm1d -> GELU -> Dropout."""

    def __init__(self, in_dim, out_dim, dropout):
        """Build the stage.

        Args:
            in_dim: Number of input features.
            out_dim: Number of output features.
            dropout: Dropout probability applied after the activation.
        """
        super().__init__()
        projection = nn.Linear(in_dim, out_dim)
        normalizer = nn.BatchNorm1d(out_dim)
        # The attribute name and layer order define the state-dict keys
        # (block.0.*, block.1.*), so both are kept stable.
        self.block = nn.Sequential(
            projection,
            normalizer,
            nn.GELU(),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Run ``x`` through the stage and return the transformed features."""
        return self.block(x)
25
+
26
+
27
class BranchBlock(nn.Module):
    """One stage of a task head: Linear -> [BatchNorm1d] -> GELU -> Dropout."""

    def __init__(self, in_dim, out_dim, dropout, use_bn=True):
        """Build the stage.

        Args:
            in_dim: Number of input features.
            out_dim: Number of output features.
            dropout: Dropout probability applied after the activation.
            use_bn: When true, insert a ``BatchNorm1d`` right after the
                linear layer.
        """
        super().__init__()
        linear = nn.Linear(in_dim, out_dim)
        # Optional normalisation is spliced in between the projection and
        # the activation; an empty tuple leaves the sequence untouched.
        norm = (nn.BatchNorm1d(out_dim),) if use_bn else ()
        self.block = nn.Sequential(linear, *norm, nn.GELU(), nn.Dropout(dropout))

    def forward(self, x):
        """Run ``x`` through the stage and return the transformed features."""
        return self.block(x)
38
+
39
+
40
class TaskBranch(nn.Module):
    """Regression head whose output is confined to a fixed interval.

    A stack of ``BranchBlock`` stages narrows the features, a final linear
    layer maps them to one value, and the result is passed through a sigmoid
    and mapped to ``[shift, shift + scale]``.
    """

    def __init__(self, in_dim, branch_dims, dropout, scale, shift):
        """Build the head.

        Args:
            in_dim: Number of input features.
            branch_dims: Output width of each successive ``BranchBlock``.
            dropout: Dropout probability used inside every ``BranchBlock``.
            scale: Multiplier applied to the sigmoid output.
            shift: Offset added after scaling.
        """
        super().__init__()
        # Consecutive width pairs describe each hidden stage: (in, d0), (d0, d1), ...
        widths = [in_dim, *branch_dims]
        stages = [
            BranchBlock(src, dst, dropout=dropout, use_bn=True)
            for src, dst in zip(widths, widths[1:])
        ]
        stages.append(nn.Linear(widths[-1], 1))
        self.branch = nn.Sequential(*stages)
        self.scale = scale
        self.shift = shift

    def forward(self, x):
        """Return the bounded prediction for ``x`` with a trailing dim of size 1."""
        squashed = torch.sigmoid(self.branch(x))
        return squashed * self.scale + self.shift
55
+
56
+
57
+ # APEX MODEL
58
class APEXModel(PreTrainedModel):
    """Multi-task regression heads on top of a frozen MERT audio encoder.

    Pipeline used by :meth:`predict`:

    1. Load the audio file, down-mix to mono and resample to the MERT rate.
    2. Cut the waveform into fixed-length segments, run each through MERT,
       mean-pool the selected hidden layers over time and merge them with a
       1x1 ``Conv1d`` into one vector per segment.
    3. Average the segment vectors into one song-level embedding.
    4. Feed the embedding through the shared trunk and one bounded head per
       task.
    """

    config_class = APEXConfig

    # Output range of every head as (scale, shift); a head returns
    # sigmoid(x) * scale + shift. Insertion order fixes both the order in
    # which the ``branch_<task>`` submodules are registered and the key order
    # of the dict returned by ``forward``.
    _TASK_RANGES = {
        "score_streams": (100, 0),
        "score_likes": (100, 0),
        "coherence": (4, 1),
        "musicality": (4, 1),
        "memorability": (4, 1),
        "clarity": (4, 1),
        "naturalness": (4, 1),
    }

    def __init__(self, config: APEXConfig):
        """Instantiate the MERT encoder, the layer aggregator and all heads.

        Args:
            config: Model configuration. Fields read here are
                ``mert_model_name``, ``seed``, ``layer_indices``,
                ``input_dim``, ``shared_dims``, ``dropout_shared``,
                ``branch_dims`` and ``dropout_branch``.
        """
        super().__init__(config)

        # MERT encoder + processor (loaded from the hub; runs remote code).
        self.mert_processor = AutoProcessor.from_pretrained(
            config.mert_model_name, trust_remote_code=True
        )
        self.mert = AutoModel.from_pretrained(
            config.mert_model_name, trust_remote_code=True
        )
        # The encoder is a fixed feature extractor: eval mode, no gradients.
        self.mert.eval()
        for param in self.mert.parameters():
            param.requires_grad = False

        self.target_sr = self.mert_processor.sampling_rate

        # Conv1d aggregator with fixed seed.
        # NOTE: this reseeds the *global* torch RNG, so every layer created
        # after this line is initialised from the same seed as well.
        torch.manual_seed(config.seed)
        self.aggregator = nn.Conv1d(
            in_channels=len(config.layer_indices),
            out_channels=1,
            kernel_size=1,
        )

        # Shared trunk: input_dim -> shared_dims[0] -> ... -> shared_dims[-1]
        # (the original comment gives 768 -> 512 -> 256 for the shipped config).
        shared_layers = []
        prev_dim = config.input_dim
        for dim in config.shared_dims:
            shared_layers.append(
                SharedBlock(prev_dim, dim, dropout=config.dropout_shared)
            )
            prev_dim = dim
        self.shared = nn.Sequential(*shared_layers)

        out_dim = config.shared_dims[-1]

        # One bounded head per task, registered as ``branch_<task>`` so the
        # parameter names match the explicit attribute names used previously.
        for task, (scale, shift) in self._TASK_RANGES.items():
            setattr(
                self,
                f"branch_{task}",
                TaskBranch(
                    out_dim,
                    config.branch_dims,
                    config.dropout_branch,
                    scale=scale,
                    shift=shift,
                ),
            )

    def forward(self, embedding):
        """Predict every task from a batch of song-level embeddings.

        Args:
            embedding: Tensor fed to the shared trunk; presumably shaped
                ``(batch, config.input_dim)`` - confirm against the caller.

        Returns:
            Dict mapping each task name to the head output with dimension 1
            squeezed away.
        """
        shared = self.shared(embedding)
        return {
            task: getattr(self, f"branch_{task}")(shared).squeeze(1)
            for task in self._TASK_RANGES
        }

    def _load_audio(self, audio_path):
        """Read an audio file as a mono waveform at ``self.target_sr``.

        Args:
            audio_path: Anything accepted by ``soundfile.read``.

        Returns:
            1-D float32 tensor on ``self.device``.
        """
        waveform, sr = sf.read(audio_path, dtype="float32")
        waveform = torch.from_numpy(waveform)

        # Multi-channel data arrives as (frames, channels); average the
        # channels. This also flattens a (frames, 1) array, which would
        # otherwise stay 2-D and break the 1-D segmentation downstream.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=1)

        waveform = waveform.to(self.device)

        # Resample if needed
        if sr != self.target_sr:
            waveform = TAF.resample(waveform, sr, self.target_sr)

        return waveform

    def _extract_embedding(self, waveform):
        """Turn a mono waveform into one song-level embedding.

        Args:
            waveform: 1-D tensor of samples at ``self.target_sr``.

        Returns:
            The mean over all segments of the aggregated MERT features.

        Raises:
            ValueError: If ``waveform`` contains no samples.
        """
        total_samples = waveform.shape[0]
        if total_samples == 0:
            raise ValueError("Cannot extract an embedding from empty audio.")

        # ``segment_sec`` may be a float; range() and slicing need an int.
        segment_len = int(round(self.config.segment_sec * self.target_sr))
        segment_embeddings = []

        for start in range(0, total_samples, segment_len):
            segment = waveform[start:start + segment_len]

            # Zero-pad the trailing segment to the full segment length.
            if segment.shape[0] < segment_len:
                pad_len = segment_len - segment.shape[0]
                segment = torch.nn.functional.pad(segment, (0, pad_len))

            # MERT forward (the processor works on CPU numpy arrays).
            inputs = self.mert_processor(
                segment.cpu().numpy(),
                sampling_rate=self.target_sr,
                return_tensors="pt",
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.mert(**inputs, output_hidden_states=True)

            # Mean-pool each selected layer over dim 1 and stack them:
            # [n_layers, 1, hidden] -> [n_layers, hidden].
            all_hidden = torch.stack([
                outputs.hidden_states[i].mean(dim=1)
                for i in self.config.layer_indices
            ])
            all_hidden = all_hidden.squeeze(1)

            # The 1x1 Conv1d treats the layers as channels and mixes them
            # into a single channel: [1, n_layers, hidden] -> [hidden].
            pooled = self.aggregator(all_hidden.unsqueeze(0)).squeeze()

            segment_embeddings.append(pooled)

            # Drop per-segment references before the next iteration.
            del segment, inputs, outputs, all_hidden, pooled

        # Average across segments to song-level embedding
        return torch.stack(segment_embeddings).mean(dim=0)

    @torch.no_grad()
    def predict(self, audio_path, save_json=None):
        """Score one audio file, print a report and optionally save it.

        Puts the model in eval mode as a side effect.

        Args:
            audio_path: Audio file to score (anything ``soundfile.read``
                accepts).
            save_json: Optional output path; when truthy the predictions are
                written there as JSON.

        Returns:
            Dict mapping each task name to its prediction as a Python float.

        Raises:
            ValueError: If the audio file contains no samples.
        """
        self.eval()
        print(f"\nProcessing: {audio_path}")

        waveform = self._load_audio(audio_path)
        duration = waveform.shape[0] / self.target_sr
        n_segs = int(np.ceil(duration / self.config.segment_sec))
        print(f"Duration: {duration:.1f}s | Segments: {n_segs}")

        print("Extracting MERT embeddings...")
        embedding = self._extract_embedding(waveform)

        print("Running APEX model...")
        preds = self.forward(embedding.unsqueeze(0))

        results = {
            task: float(value.squeeze().cpu())
            for task, value in preds.items()
        }

        print(f"\n{'='*50}")
        print(" APEX Predictions")
        print(f"{'='*50}")
        print("\n Popularity:")
        print(f" {'─'*40}")
        print(f" {'Streams Score':<20} {results['score_streams']:>8.2f} / 100")
        print(f" {'Likes Score':<20} {results['score_likes']:>8.2f} / 100")
        print("\n Aesthetic Quality:")
        print(f" {'─'*40}")
        for dim in ["coherence", "musicality", "memorability", "clarity", "naturalness"]:
            print(f" {dim.capitalize():<20} {results[dim]:>8.2f} / 5.00")
        print(f"{'='*50}\n")

        if save_json:
            with open(save_json, "w", encoding="utf-8") as f:
                json.dump(
                    {
                        # str() keeps Path-like inputs JSON-serialisable.
                        "audio_path": str(audio_path),
                        "predictions": results,
                    },
                    f,
                    indent=2,
                )
            print(f"Results saved to {save_json}")

        return results