File size: 1,836 Bytes
16d6869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""
Gradient Reversal Layer (Ganin et al. 2016, DANN).

During the forward pass the GRL acts as an identity.
During the backward pass it multiplies the incoming gradient by -alpha,
which forces the upstream encoder to *maximise* whatever loss is downstream
of the GRL — i.e. to learn features that confuse the site classifier.

Alpha is set externally (typically annealed from 0 → 1 using the Ganin
schedule: alpha = 2/(1+exp(-10*p)) - 1, where p ∈ [0,1] is training progress).
"""

from __future__ import annotations

import math

import torch
from torch.autograd import Function


class _GRLFunction(Function):
    """Autograd primitive behind the gradient reversal layer.

    Forward: returns a copy of ``x`` (values unchanged).
    Backward: returns the incoming gradient multiplied by ``-alpha``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        # Remember the scale for the backward pass.
        ctx.alpha = alpha
        # A clone (not ``x`` itself) is handed back as the output tensor.
        identity_copy = x.clone()
        return identity_copy

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        reversed_grad = grad_output.mul(-ctx.alpha)
        # One gradient per forward input: ``x`` gets the reversed gradient,
        # ``alpha`` is a plain scalar and therefore gets ``None``.
        return reversed_grad, None


class GradientReversal(torch.nn.Module):
    """``nn.Module`` front-end for the gradient reversal autograd function.

    ``alpha`` is stored as a plain Python attribute (not a parameter or
    buffer) and is read on every forward call, so the training loop can
    reassign it between epochs without rebuilding the model.
    """

    def __init__(self, alpha: float = 0.0):
        super().__init__()
        # Reversal strength; per the original note, updated externally by the
        # Lightning task.
        self.alpha = alpha

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Values pass through unchanged; gradients come back scaled by -alpha.
        current_alpha = self.alpha
        return _GRLFunction.apply(x, current_alpha)

    def __repr__(self) -> str:
        return f"GradientReversal(alpha={self.alpha:.4f})"


def ganin_alpha(epoch: int, max_epochs: int, gamma: float = 10.0) -> float:
    """Return the GRL coefficient for ``epoch`` under the Ganin et al. (2016) schedule.

    alpha(p) = 2 / (1 + exp(-gamma * p)) - 1, where the training progress
    p = epoch / (max_epochs - 1) is clamped to [0, 1].

    The default ``gamma=10`` is the value used in the original paper; it is
    not a steeper variant. With it, alpha is 0 at the first epoch, ~0.905 at
    30% of training, ~0.987 at the midpoint and ~0.9999 at the last epoch.

    Args:
        epoch: Zero-based index of the current epoch.
        max_epochs: Total number of epochs. Values below 2 are treated as 2
            for the progress denominator, so the division is always defined.
        gamma: Steepness of the ramp; larger values saturate sooner.

    Returns:
        The reversal coefficient, in [0, 1] for non-negative ``gamma``.
    """
    # The last epoch (max_epochs - 1) maps to p = 1; the max() avoids a zero
    # denominator when max_epochs <= 1.
    p = epoch / max(max_epochs - 1, 1)
    # Clamp so out-of-range epochs cannot leave the schedule: a negative p
    # would give a negative alpha (turning the GRL into a gradient amplifier),
    # and p > 1 exceeds the documented progress range.
    p = min(max(p, 0.0), 1.0)
    return 2.0 / (1.0 + math.exp(-gamma * p)) - 1.0