"""
Gradient Reversal Layer (Ganin et al. 2016, DANN).
During the forward pass the GRL acts as an identity.
During the backward pass it multiplies the incoming gradient by -alpha,
which forces the upstream encoder to *maximise* whatever loss is downstream
of the GRL — i.e. to learn features that confuse the site classifier.
Alpha is set externally (typically annealed from 0 → 1 using the Ganin
schedule: alpha = 2/(1+exp(-10*p)) - 1, where p ∈ [0,1] is training progress).
"""
from __future__ import annotations
import math
import torch
from torch.autograd import Function


class _GRLFunction(Function):
    """Autograd function implementing the gradient reversal."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, alpha: float) -> torch.Tensor:
        # Identity in the forward pass; stash alpha for the backward pass.
        ctx.alpha = alpha
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Flip and scale gradients; None for the alpha grad (not a tensor).
        return -ctx.alpha * grad_output, None


class GradientReversal(torch.nn.Module):
    """Wraps _GRLFunction as a stateful nn.Module so alpha can be updated
    between epochs without re-building the model."""

    def __init__(self, alpha: float = 0.0):
        super().__init__()
        self.alpha = alpha  # updated externally by the Lightning task

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _GRLFunction.apply(x, self.alpha)

    def __repr__(self) -> str:
        return f"GradientReversal(alpha={self.alpha:.4f})"


def ganin_alpha(epoch: int, max_epochs: int) -> float:
    """Ganin et al. (2016) annealing schedule with the paper's gamma = 10.

    Starts at 0 (the GRL has no effect) and saturates towards 1:
    alpha reaches ~0.9 at roughly 30% of training progress and
    exceeds 0.98 by the midpoint.
    """
    # p = training progress in [0, 1]; the max() guards against division
    # by zero when max_epochs == 1.
    p = epoch / max(max_epochs - 1, 1)
    return 2.0 / (1.0 + math.exp(-10.0 * p)) - 1.0
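

if __name__ == "__main__":
    # Quick illustration (assumed usage, not from the original file): print
    # the schedule for a hypothetical 20-epoch run, then run the sanity
    # check above. Alpha should already be near 1 by the midpoint.
    for epoch in range(0, 20, 4):
        print(f"epoch {epoch:2d}: alpha = {ganin_alpha(epoch, 20):.4f}")
    _check_gradient_flip()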