File size: 3,019 Bytes
f35a6e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Wav2Vec 2.0-based classifier for deepfake audio detection.

Architecture:
    Raw waveform (16 kHz, 4 sec, 64000 samples)
        β†’ Wav2Vec 2.0 Base backbone (95M params, 12 transformer layers)
        β†’ mean pooling over time dimension
        β†’ linear classification head (768 β†’ 2)
        β†’ logits for [bonafide, spoof]

In Stage 1, the backbone is frozen and only the head is trained.
In Stage 2, the top N transformer layers are unfrozen and fine-tuned.
"""

import torch
import torch.nn as nn
from transformers import Wav2Vec2Model


class Wav2VecClassifier(nn.Module):
    """Wav2Vec 2.0 backbone + mean pooling over time + linear classification head.

    Raw 16 kHz waveforms are encoded by the pretrained backbone, the
    frame-level hidden states are mean-pooled over time, and a single
    linear layer maps the pooled embedding to class logits.
    """

    def __init__(
        self,
        backbone_name: str = "facebook/wav2vec2-base",
        num_classes: int = 2,
        freeze_backbone: bool = True,
    ):
        """
        Args:
            backbone_name: HuggingFace model id of the Wav2Vec 2.0 checkpoint.
            num_classes: Number of output classes (default 2: [bonafide, spoof]).
            freeze_backbone: If True (Stage 1 default), freeze the whole
                backbone so only the classification head is trained.
        """
        super().__init__()

        # Load pretrained backbone weights from the HuggingFace hub.
        self.backbone = Wav2Vec2Model.from_pretrained(backbone_name)

        # Hidden size comes from the config (768 for Base, 1024 for Large),
        # so the head adapts automatically to the chosen checkpoint.
        hidden_size = self.backbone.config.hidden_size

        # Classification head: pooled embedding -> class logits.
        self.classifier = nn.Linear(hidden_size, num_classes)

        # Freeze backbone if requested (Stage 1 default).
        self.freeze_backbone(freeze_backbone)

    def freeze_backbone(self, freeze: bool = True):
        """Freeze (freeze=True) or unfreeze (freeze=False) the entire backbone."""
        for param in self.backbone.parameters():
            param.requires_grad = not freeze

    def unfreeze_top_n_layers(self, n: int):
        """Unfreeze only the top ``n`` transformer encoder layers (Stage 2).

        Everything else in the backbone stays frozen. The encoder's final
        layer norm is also unfrozen, since it sits after the tuned layers.

        Args:
            n: Number of layers to unfreeze, counted from the top of the
                encoder stack. Values outside [0, num_layers] are clamped.
        """
        # Start from a fully frozen backbone so repeated calls are idempotent.
        self.freeze_backbone(True)

        total_layers = len(self.backbone.encoder.layers)

        # Clamp to the valid range: n > total_layers would otherwise produce
        # negative range indices that wrap around and touch layers twice;
        # n < 0 would silently be a no-op.
        n = max(0, min(n, total_layers))

        # Unfreeze the top n transformer encoder layers.
        for i in range(total_layers - n, total_layers):
            for param in self.backbone.encoder.layers[i].parameters():
                param.requires_grad = True

        # The final encoder layer norm is small but rescales every unfrozen
        # layer's output, so fine-tune it as well.
        for param in self.backbone.encoder.layer_norm.parameters():
            param.requires_grad = True

    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        """Classify raw waveforms.

        Args:
            waveforms: (batch_size, num_samples) float tensor of raw audio
                sampled at 16 kHz.

        Returns:
            logits: (batch_size, num_classes) tensor of unnormalized scores.
        """
        # Backbone output: (batch, time_frames, hidden_size).
        outputs = self.backbone(waveforms)
        hidden_states = outputs.last_hidden_state  # (B, T, H)

        # Mean pool over the time dimension -> (B, H).
        pooled = hidden_states.mean(dim=1)

        # Classification head -> (B, num_classes).
        logits = self.classifier(pooled)
        return logits

    def count_trainable_params(self) -> int:
        """Number of parameters with requires_grad=True (head + any unfrozen layers)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def count_total_params(self) -> int:
        """Total parameter count, trainable or not."""
        return sum(p.numel() for p in self.parameters())