import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoFeatureExtractor
import torchaudio
from safetensors import safe_open
from typing import List, Dict

# Allow TF32 matmuls and route scaled-dot-product attention through the
# flash / memory-efficient kernels for faster GPU inference.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)


class WavLMForMusicDetection(nn.Module):
    """
    Music detection model based on WavLM.
    Uses attention pooling over frame embeddings followed by an MLP classification head.
    Outputs the probability that the input audio contains music.
    Supports batched inference with automatic batching and preprocessing.
    EER: 2.5-3%.
    """
    def __init__(
        self,
        base_model_name: str = 'microsoft/wavlm-base-plus',
        batch_size: int = 32,
        device: str = 'cuda'
    ) -> None:
        super().__init__()
        self.config = AutoConfig.from_pretrained(base_model_name)
        self.wavlm = AutoModel.from_pretrained(base_model_name, config=self.config)
        self.processor = AutoFeatureExtractor.from_pretrained(base_model_name)

        self.batch_size = batch_size
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # Default to fp32; convert_to_bf16() switches this to torch.bfloat16.
        # Without this attribute, forward() would fail when the model is used
        # without calling convert_to_bf16() first.
        self.dtype = torch.float32

        self.target_sample_rate = self.processor.sampling_rate

        # Attention pooling: scores each WavLM frame; softmax over time yields weights.
        self.pool_attention = nn.Sequential(
            nn.Linear(self.config.hidden_size, 256),
            nn.Tanh(),
            nn.Linear(256, 1)
        )

        # MLP classification head on top of the pooled utterance embedding.
        self.classifier = nn.Sequential(
            nn.Linear(self.config.hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64, 1)
        )

        self.to(self.device)

    def _attention_pool(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply attention-based pooling over the time dimension.
        Args:
            hidden_states (torch.Tensor): [batch_size, seq_len, hidden_size]
            attention_mask (torch.Tensor): [batch_size, seq_len] — mask to ignore padding
        Returns:
            torch.Tensor: [batch_size, hidden_size] — context vector
        """
        # Per-frame attention scores: [batch_size, seq_len, 1].
        attention_weights = self.pool_attention(hidden_states)
        # Push padded positions toward -inf so they get ~zero weight after softmax.
        attention_weights = attention_weights + (
            (1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
        )

        attention_weights = F.softmax(attention_weights, dim=1)

        # Weighted sum of frame embeddings: [batch_size, hidden_size].
        weighted_sum = torch.sum(hidden_states * attention_weights, dim=1)
        return weighted_sum

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass for inference.
        Args:
            input_values (torch.Tensor): [batch_size, audio_seq_len] — raw audio waveform
            attention_mask (torch.Tensor): [batch_size, audio_seq_len] — input mask (1 = real, 0 = pad)
        Returns:
            torch.Tensor: [batch_size, 1] — probability that the audio contains music
        """
        assert isinstance(input_values, torch.Tensor), f"Expected torch.Tensor, got {type(input_values)}"
        assert isinstance(attention_mask, torch.Tensor), f"Expected torch.Tensor, got {type(attention_mask)}"

        input_values = input_values.to(dtype=self.dtype, device=self.device)
        # Keep the mask integer-valued: WavLM derives unpadded lengths from
        # cumulative sums of the mask, which lose integer precision in bfloat16.
        attention_mask = attention_mask.to(device=self.device, dtype=torch.long)

        outputs = self.wavlm(input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        # Downsample the sample-level mask to the frame rate of the encoder
        # output (nearest-neighbor approximation of the conv feature stride).
        input_length = attention_mask.size(1)
        hidden_length = hidden_states.size(1)
        ratio = input_length / hidden_length
        indices = (torch.arange(hidden_length, device=attention_mask.device) * ratio).long()
        attention_mask = attention_mask[:, indices].bool()

        pooled = self._attention_pool(hidden_states, attention_mask)
        logits = self.classifier(pooled)

        probs = torch.sigmoid(logits)
        return probs

    def _prepare_batches(self, audio_paths: List[str]) -> List[List[str]]:
        """
        Split a list of audio paths into batches of size `self.batch_size`.
        Args:
            audio_paths (List[str]): List of paths to audio files.
        Returns:
            List[List[str]]: List of batches, each batch a list of paths.
        """
        return [
            audio_paths[i:i + self.batch_size]
            for i in range(0, len(audio_paths), self.batch_size)
        ]

    def _preprocess_audio_batch(self, audio_paths: List[str]) -> Dict[str, torch.Tensor]:
        """
        Load and preprocess a batch of audio files.
        Args:
            audio_paths (List[str]): List of file paths.
        Returns:
            Dict with keys:
                "input_values": tensor [B, T]
                "attention_mask": tensor [B, T]
        """
        waveforms = []

        for audio_path in audio_paths:
            waveform, sample_rate = torchaudio.load(audio_path)

            # Resample to the rate the feature extractor expects (16 kHz for WavLM).
            if sample_rate != self.target_sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=self.target_sample_rate
                )
                waveform = resampler(waveform)

            # Downmix multi-channel audio to mono.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            waveforms.append(waveform.squeeze())

        # Pad to the longest clip in the batch and build the attention mask.
        inputs = self.processor(
            [w.numpy() for w in waveforms],
            sampling_rate=self.target_sample_rate,
            return_tensors="pt",
            padding=True,
            truncation=False
        )
        return {k: v.to(self.device) for k, v in inputs.items()}

    def predict_proba(self, audio_paths: List[str]) -> torch.Tensor:
        """
        Predict music probability for a list of audio files.
        Args:
            audio_paths (List[str]): List of audio file paths.
        Returns:
            torch.Tensor: [N] — probabilities for each audio file.
        """
        all_probs = []

        for batch in self._prepare_batches(audio_paths):
            # _preprocess_audio_batch already moves tensors to self.device.
            inputs = self._preprocess_audio_batch(batch)

            with torch.no_grad():
                probs = self.forward(**inputs).squeeze(-1)
            all_probs.append(probs)

        return torch.cat(all_probs, dim=0)

    def convert_to_bf16(self):
        """Cast the encoder and both heads to bfloat16 for faster inference."""
        self.wavlm = self.wavlm.to(torch.bfloat16)
        self.pool_attention = self.pool_attention.to(torch.bfloat16)
        self.classifier = self.classifier.to(torch.bfloat16)
        self.dtype = torch.bfloat16
        return self

    def predict_proba_smart_batching(
        self,
        audio_paths: List[str],
        audio_lengths: List[float]
    ) -> torch.Tensor:
        """
        Predict music probability with length-sorted ("smart") batching.
        Sorting clips by duration before batching minimizes padding within
        each batch; results are returned in the original input order.
        Args:
            audio_paths (List[str]): List of audio file paths.
            audio_lengths (List[float]): Duration of each file.
        Returns:
            torch.Tensor: [N] — probabilities for each audio file.
        """
        assert len(audio_paths) == len(audio_lengths), \
            f"Mismatch: {len(audio_paths)} paths vs {len(audio_lengths)} lengths"

        was_training = self.training
        self.eval()

        try:
            indexed_audios = [
                (i, path, length)
                for i, (path, length) in enumerate(zip(audio_paths, audio_lengths))
            ]

            # Sort by duration so each batch contains similarly sized clips.
            sorted_audios = sorted(indexed_audios, key=lambda x: x[2])
            batches = []
            for i in range(0, len(sorted_audios), self.batch_size):
                batches.append(sorted_audios[i:i + self.batch_size])

            results = {}

            for batch in batches:
                batch_paths = [item[1] for item in batch]
                batch_indices = [item[0] for item in batch]

                inputs = self._preprocess_audio_batch(batch_paths)

                with torch.no_grad():
                    probs = self.forward(**inputs).squeeze(-1)

                if probs.dim() == 0:
                    probs = probs.unsqueeze(0)

                # Scatter results back to the original ordering.
                for idx, prob in zip(batch_indices, probs):
                    results[idx] = prob.cpu()

            all_probs = [results[i] for i in range(len(audio_paths))]
            return torch.stack(all_probs)
        finally:
            if was_training:
                self.train()

| if __name__ == "__main__": |
| device = 'cuda:0' |
| checkpoint_path = './music_detection.safetensors' |
| model = WavLMForMusicDetection('microsoft/wavlm-base-plus', batch_size=8, device=device) |
| model.convert_to_bf16() |
| model.eval() |
| with safe_open(checkpoint_path, framework="pt", device=device) as f: |
| state_dict = {key: f.get_tensor(key) for key in f.keys()} |
| model.load_state_dict(state_dict) |
|
|
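
    # Example inference run (a minimal sketch: the file paths and durations
    # below are hypothetical placeholders, not assets shipped with this code).
    example_paths = ['./samples/music_clip.wav', './samples/speech_clip.wav']
    probs = model.predict_proba(example_paths)
    for path, p in zip(example_paths, probs):
        print(f"{path}: P(music) = {p.item():.3f}")

    # When durations are known up front, length-sorted batching cuts padding
    # overhead; the lengths here are likewise illustrative.
    example_lengths = [12.4, 7.9]
    probs_sorted = model.predict_proba_smart_batching(example_paths, example_lengths)
    print(probs_sorted)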