"""
Wav2Vec 2.0-based classifier for deepfake audio detection.

Architecture:
    Raw waveform (16 kHz, 4 sec, 64000 samples)
    -> Wav2Vec 2.0 Base backbone (95M params, 12 transformer layers)
    -> mean pooling over time dimension
    -> linear classification head (768 -> 2)
    -> logits for [bonafide, spoof]

In Stage 1, the backbone is frozen and only the head is trained.
In Stage 2, the top N transformer layers are unfrozen and fine-tuned.
"""
import torch
import torch.nn as nn
from transformers import Wav2Vec2Model
class Wav2VecClassifier(nn.Module):
    """Wav2Vec 2.0 backbone + mean pooling over time + linear head.

    Stage 1: the backbone is frozen and only the linear head trains.
    Stage 2: call ``unfreeze_top_n_layers(n)`` to fine-tune the top ``n``
    transformer encoder layers (plus the final encoder LayerNorm).
    """

    def __init__(
        self,
        backbone_name: str = "facebook/wav2vec2-base",
        num_classes: int = 2,
        freeze_backbone: bool = True,
    ):
        """
        Args:
            backbone_name: HuggingFace model id of the Wav2Vec 2.0 backbone.
            num_classes: number of output classes (2 = [bonafide, spoof]).
            freeze_backbone: if True, freeze every backbone parameter
                (Stage 1 default).
        """
        super().__init__()
        # Load pretrained backbone from HuggingFace.
        self.backbone = Wav2Vec2Model.from_pretrained(backbone_name)
        # Hidden size comes from the config (768 for Base, 1024 for Large).
        hidden_size = self.backbone.config.hidden_size
        # Linear classification head on top of the pooled representation.
        self.classifier = nn.Linear(hidden_size, num_classes)
        # Freeze backbone if requested (Stage 1 default).
        self.freeze_backbone(freeze_backbone)

    def freeze_backbone(self, freeze: bool = True):
        """Freeze (or, with ``freeze=False``, unfreeze) the whole backbone."""
        for param in self.backbone.parameters():
            param.requires_grad = not freeze

    def unfreeze_top_n_layers(self, n: int):
        """Unfreeze only the top ``n`` transformer encoder layers (Stage 2).

        The final encoder LayerNorm is unfrozen as well; everything else in
        the backbone stays frozen. ``n`` is clamped to the number of encoder
        layers, so a larger value unfreezes the entire encoder stack.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError(f"n must be non-negative, got {n}")
        # Start from a fully frozen backbone so repeated calls are idempotent.
        self.freeze_backbone(True)
        total_layers = len(self.backbone.encoder.layers)
        # Clamp: without this, n > total_layers yields a negative range start,
        # and negative ModuleList indices either raise IndexError or silently
        # double-visit layers via Python's negative-index wraparound.
        n = min(n, total_layers)
        for i in range(total_layers - n, total_layers):
            for param in self.backbone.encoder.layers[i].parameters():
                param.requires_grad = True
        # Also unfreeze the layer norm at the end (small but matters).
        for param in self.backbone.encoder.layer_norm.parameters():
            param.requires_grad = True

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """
        Args:
            waveforms: (batch_size, num_samples) tensor of raw audio at 16 kHz.

        Returns:
            logits: (batch_size, num_classes) tensor of unnormalized scores.
        """
        # Backbone produces (batch, time_frames, hidden_size).
        outputs = self.backbone(waveforms)
        hidden_states = outputs.last_hidden_state  # (B, T, H)
        # Mean pool over the time dimension -> (B, H).
        pooled = hidden_states.mean(dim=1)
        # Classification head -> (B, num_classes).
        return self.classifier(pooled)

    def count_trainable_params(self) -> int:
        """Number of parameters with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_total_params(self) -> int:
        """Total number of parameters (trainable + frozen)."""
        return sum(p.numel() for p in self.parameters())