| """ |
| Wav2Vec 2.0-based classifier for deepfake audio detection. |
| |
| Architecture: |
| Raw waveform (16 kHz, 4 sec, 64000 samples) |
| β Wav2Vec 2.0 Base backbone (95M params, 12 transformer layers) |
| β mean pooling over time dimension |
| β linear classification head (768 β 2) |
| β logits for [bonafide, spoof] |
| |
| In Stage 1, the backbone is frozen and only the head is trained. |
| In Stage 2, the top N transformer layers are unfrozen and fine-tuned. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import Wav2Vec2Model |
|
|
|
|
class Wav2VecClassifier(nn.Module):
    """Wav2Vec 2.0 backbone + mean pooling over time + linear head.

    Intended for two-stage training: Stage 1 trains only the linear head
    with the backbone frozen; Stage 2 unfreezes the top N transformer
    layers via :meth:`unfreeze_top_n_layers`.
    """

    def __init__(
        self,
        backbone_name: str = "facebook/wav2vec2-base",
        num_classes: int = 2,
        freeze_backbone: bool = True,
    ):
        """
        Args:
            backbone_name: HuggingFace model id for the Wav2Vec 2.0 backbone.
            num_classes: Number of output classes (default 2: bonafide/spoof).
            freeze_backbone: If True, freeze all backbone parameters (Stage 1).
        """
        super().__init__()

        # Pretrained feature extractor + transformer encoder.
        self.backbone = Wav2Vec2Model.from_pretrained(backbone_name)

        # Hidden size comes from the backbone config (768 for wav2vec2-base).
        hidden_size = self.backbone.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_classes)

        self.freeze_backbone(freeze_backbone)

    def freeze_backbone(self, freeze: bool = True) -> None:
        """Freeze (``freeze=True``) or unfreeze the entire Wav2Vec backbone."""
        for param in self.backbone.parameters():
            param.requires_grad = not freeze

    def unfreeze_top_n_layers(self, n: int) -> None:
        """Unfreeze only the top ``n`` transformer layers (Stage 2).

        The rest of the backbone stays frozen. The encoder's final
        ``layer_norm`` is always unfrozen alongside the layers so the
        trainable stack is normalized consistently.

        Args:
            n: Number of topmost encoder layers to unfreeze. Values larger
               than the layer count unfreeze all layers.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")

        # Start from a fully frozen backbone so repeated calls are idempotent.
        self.freeze_backbone(True)

        total_layers = len(self.backbone.encoder.layers)
        # Clamp: the original range(total_layers - n, ...) wrapped to negative
        # indices for n > total_layers, touching the wrong layers.
        n = min(n, total_layers)

        for layer in self.backbone.encoder.layers[total_layers - n:]:
            for param in layer.parameters():
                param.requires_grad = True

        # Keep the encoder's final layer norm trainable as well.
        for param in self.backbone.encoder.layer_norm.parameters():
            param.requires_grad = True

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveforms: (batch_size, num_samples) tensor of raw audio at 16 kHz.

        Returns:
            logits: (batch_size, num_classes) tensor of unnormalized scores.
        """
        outputs = self.backbone(waveforms)
        # (batch, time_frames, hidden_size) → mean over the time dimension.
        hidden_states = outputs.last_hidden_state
        pooled = hidden_states.mean(dim=1)
        return self.classifier(pooled)

    def count_trainable_params(self) -> int:
        """Number of parameters with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_total_params(self) -> int:
        """Total number of parameters, frozen or not."""
        return sum(p.numel() for p in self.parameters())
|
|