Jaavid25 committed on
Commit
7831472
·
verified ·
1 Parent(s): f59314f

Upload modeling_apex.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_apex.py +14 -16
modeling_apex.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import json
3
  import torch
4
  import torch.nn as nn
@@ -9,7 +8,9 @@ from transformers import PreTrainedModel, AutoProcessor, AutoModel
9
  from .configuration_apex import APEXConfig
10
 
11
 
12
- # Building blocks
 
 
13
  class SharedBlock(nn.Module):
14
  def __init__(self, in_dim, out_dim, dropout):
15
  super().__init__()
@@ -54,18 +55,22 @@ class TaskBranch(nn.Module):
54
  return torch.sigmoid(self.branch(x)) * self.scale + self.shift
55
 
56
 
 
57
  # APEX MODEL
 
58
  class APEXModel(PreTrainedModel):
59
- config_class = APEXConfig
 
60
 
61
  def __init__(self, config: APEXConfig):
62
  super().__init__(config)
63
 
64
- # MERT encoder + processor
65
  self.mert_processor = AutoProcessor.from_pretrained(
66
- config.mert_model_name, trust_remote_code=True
 
67
  )
68
- with torch.device('cpu'):
69
  self.mert = AutoModel.from_pretrained(
70
  config.mert_model_name,
71
  trust_remote_code = True,
@@ -121,13 +126,11 @@ class APEXModel(PreTrainedModel):
121
  waveform, sr = sf.read(audio_path, dtype="float32")
122
  waveform = torch.from_numpy(waveform)
123
 
124
- # Stereo to mono
125
  if len(waveform.shape) > 1 and waveform.shape[1] > 1:
126
  waveform = waveform.mean(dim=1)
127
 
128
  waveform = waveform.to(self.device)
129
 
130
- # Resample if needed
131
  if sr != self.target_sr:
132
  waveform = TAF.resample(waveform, sr, self.target_sr)
133
 
@@ -142,12 +145,10 @@ class APEXModel(PreTrainedModel):
142
  if segment.numel() == 0:
143
  break
144
 
145
- # Zero-pad if needed
146
  if segment.shape[0] < segment_len:
147
  pad_len = segment_len - segment.shape[0]
148
  segment = torch.nn.functional.pad(segment, (0, pad_len))
149
 
150
- # MERT forward
151
  inputs = self.mert_processor(
152
  segment.cpu().numpy(),
153
  sampling_rate = self.target_sr,
@@ -158,23 +159,20 @@ class APEXModel(PreTrainedModel):
158
  with torch.no_grad():
159
  outputs = self.mert(**inputs, output_hidden_states=True)
160
 
161
- # Extract layers and aggregate
162
  all_hidden = torch.stack([
163
  outputs.hidden_states[i].mean(dim=1)
164
  for i in self.config.layer_indices
165
- ]) # [4, 1, 768]
166
- all_hidden = all_hidden.squeeze(1) # [4, 768]
167
 
168
- # Conv1d aggregation
169
  pooled = self.aggregator(
170
  all_hidden.unsqueeze(0)
171
- ).squeeze() # [768]
172
 
173
  segment_embeddings.append(pooled)
174
 
175
  del segment, inputs, outputs, all_hidden, pooled
176
 
177
- # Average across segments to song-level embedding
178
  song_embedding = torch.stack(segment_embeddings).mean(dim=0)
179
  return song_embedding
180
 
 
 
1
  import json
2
  import torch
3
  import torch.nn as nn
 
8
  from .configuration_apex import APEXConfig
9
 
10
 
11
+ # -------------------------------
12
+ # BUILDING BLOCKS
13
+ # -------------------------------
14
  class SharedBlock(nn.Module):
15
  def __init__(self, in_dim, out_dim, dropout):
16
  super().__init__()
 
55
  return torch.sigmoid(self.branch(x)) * self.scale + self.shift
56
 
57
 
58
+ # -------------------------------
59
  # APEX MODEL
60
+ # -------------------------------
61
  class APEXModel(PreTrainedModel):
62
+ config_class = APEXConfig
63
+ _keys_to_ignore_on_load_missing = [r"mert\..*", r"mert_processor\..*"]
64
 
65
  def __init__(self, config: APEXConfig):
66
  super().__init__(config)
67
 
68
+ # Load MERT processor and encoder fresh from HuggingFace
69
  self.mert_processor = AutoProcessor.from_pretrained(
70
+ config.mert_model_name,
71
+ trust_remote_code = True
72
  )
73
+ with torch.device("cpu"):
74
  self.mert = AutoModel.from_pretrained(
75
  config.mert_model_name,
76
  trust_remote_code = True,
 
126
  waveform, sr = sf.read(audio_path, dtype="float32")
127
  waveform = torch.from_numpy(waveform)
128
 
 
129
  if len(waveform.shape) > 1 and waveform.shape[1] > 1:
130
  waveform = waveform.mean(dim=1)
131
 
132
  waveform = waveform.to(self.device)
133
 
 
134
  if sr != self.target_sr:
135
  waveform = TAF.resample(waveform, sr, self.target_sr)
136
 
 
145
  if segment.numel() == 0:
146
  break
147
 
 
148
  if segment.shape[0] < segment_len:
149
  pad_len = segment_len - segment.shape[0]
150
  segment = torch.nn.functional.pad(segment, (0, pad_len))
151
 
 
152
  inputs = self.mert_processor(
153
  segment.cpu().numpy(),
154
  sampling_rate = self.target_sr,
 
159
  with torch.no_grad():
160
  outputs = self.mert(**inputs, output_hidden_states=True)
161
 
 
162
  all_hidden = torch.stack([
163
  outputs.hidden_states[i].mean(dim=1)
164
  for i in self.config.layer_indices
165
+ ])
166
+ all_hidden = all_hidden.squeeze(1)
167
 
 
168
  pooled = self.aggregator(
169
  all_hidden.unsqueeze(0)
170
+ ).squeeze()
171
 
172
  segment_embeddings.append(pooled)
173
 
174
  del segment, inputs, outputs, all_hidden, pooled
175
 
 
176
  song_embedding = torch.stack(segment_embeddings).mean(dim=0)
177
  return song_embedding
178