from typing import Optional

import torch
from torch import nn


class Regularizer(nn.Module):
    def __init__(self):
        super().__init__()


def _batch_root_mean_squared(tensor):
    """Returns the batchwise root mean square over the flattened feature dims."""
    tensor = tensor.view(tensor.shape[0], -1)
    return torch.norm(tensor, p=2, dim=1) / tensor.shape[1] ** 0.5


class RegularizationFunc(nn.Module):
    def forward(self, t, x, dx, context) -> torch.Tensor:
        """Outputs a batch of scalar regularizations."""
        raise NotImplementedError


class L1Reg(RegularizationFunc):
    def forward(self, t, x, dx, context) -> torch.Tensor:
        return torch.mean(torch.abs(dx), dim=1)


class L2Reg(RegularizationFunc):
    def forward(self, t, x, dx, context) -> torch.Tensor:
        return _batch_root_mean_squared(dx)


class SquaredL2Reg(RegularizationFunc):
    def forward(self, t, x, dx, context) -> torch.Tensor:
        flat_dx = dx.view(dx.shape[0], -1)
        return torch.pow(torch.norm(flat_dx, p=2, dim=1), 2)


def _get_minibatch_jacobian(y, x, create_graph=True):
    """Computes the Jacobian of y wrt x assuming minibatch-mode.

    Args:
      y: (N, ...) with a total of D_y elements in ...
      x: (N, ...) with a total of D_x elements in ...
    Returns:
      The minibatch Jacobian matrix of shape (N, D_y, D_x)
    """
    y = y.view(y.shape[0], -1)

    # Build the Jacobian row by row: one autograd call per output dimension.
    jac = []
    for j in range(y.shape[1]):
        dy_j_dx = torch.autograd.grad(
            y[:, j],
            x,
            torch.ones_like(y[:, j]),
            retain_graph=True,
            create_graph=create_graph,
        )[0].view(x.shape[0], -1)
        jac.append(torch.unsqueeze(dy_j_dx, 1))
    jac = torch.cat(jac, 1)
    return jac


class JacobianFrobeniusReg(RegularizationFunc):
    def forward(self, t, x, dx, context) -> torch.Tensor:
        # Reuse the Jacobian if another regularizer already cached it on the context.
        if hasattr(context, "jac"):
            jac = context.jac
        else:
            jac = _get_minibatch_jacobian(dx, x)
            context.jac = jac
        return _batch_root_mean_squared(jac)


class JacobianDiagFrobeniusReg(RegularizationFunc):
    def forward(self, t, x, dx, context) -> torch.Tensor:
        if hasattr(context, "jac"):
            jac = context.jac
        else:
            jac = _get_minibatch_jacobian(dx, x)
            context.jac = jac
        # A stride of D + 1 over the flattened (D, D) Jacobian selects its diagonal.
        diagonal = jac.view(jac.shape[0], -1)[:, :: jac.shape[1] + 1]
        return _batch_root_mean_squared(diagonal)


class JacobianOffDiagFrobeniusReg(RegularizationFunc):
    def forward(self, t, x, dx, context) -> torch.Tensor:
        if hasattr(context, "jac"):
            jac = context.jac
        else:
            jac = _get_minibatch_jacobian(dx, x)
            context.jac = jac
        diagonal = jac.view(jac.shape[0], -1)[:, :: jac.shape[1] + 1]
        # Sum of squares of all entries minus the diagonal ones, averaged over
        # the D * (D - 1) off-diagonal entries.
        ss_offdiag = torch.sum(jac.view(jac.shape[0], -1) ** 2, dim=1) - torch.sum(
            diagonal**2, dim=1
        )
        ms_offdiag = ss_offdiag / (diagonal.shape[1] * (diagonal.shape[1] - 1))
        return ms_offdiag


def autograd_trace(x_out, x_in, **kwargs):
    """Standard brute-force means of obtaining trace of the Jacobian, O(d) calls to autograd."""
    trJ = 0.0
    for i in range(x_in.shape[1]):
        trJ += torch.autograd.grad(
            x_out[:, i].sum(), x_in, allow_unused=False, create_graph=True
        )[0][:, i]
    return trJ


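# A minimal sketch of a Hutchinson-style stochastic estimator that could be
# passed to CNFReg as `trace_estimator`; the source only ships the exact
# `autograd_trace`, so the name and interface here are assumptions. `noise`
# is expected to be a probe tensor shaped like `x_in`, e.g. Rademacher or
# Gaussian samples drawn from `noise_dist`.
def hutchinson_trace(x_out, x_in, noise=None, **kwargs):
    """Estimates tr(J) as E[eps^T J eps] with a single vector-Jacobian product."""
    if noise is None:
        noise = torch.randn_like(x_in)  # fallback: Gaussian probe vector
    vjp = torch.autograd.grad(x_out, x_in, noise, create_graph=True)[0]
    return torch.sum(vjp * noise, dim=tuple(range(1, vjp.dim())))

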
class CNFReg(RegularizationFunc):
    def __init__(self, trace_estimator=None, noise_dist=None):
        super().__init__()
        self.trace_estimator = trace_estimator if trace_estimator is not None else autograd_trace
        self.noise_dist, self.noise = noise_dist, None

    def forward(self, t, x, dx, context) -> torch.Tensor:
        # Negative Jacobian trace, per the CNF instantaneous change-of-variables formula.
        return -self.trace_estimator(dx, x, noise=self.noise)


class AugmentationModule(nn.Module):
    """Orchestrates the regularization augmentations and establishes their order."""

    def __init__(
        self,
        cnf_estimator: Optional[str] = None,
        l1_reg: float = 0.0,
        l2_reg: float = 0.0,
        squared_l2_reg: float = 0.0,
        jacobian_frobenius_reg: float = 0.0,
        jacobian_diag_frobenius_reg: float = 0.0,
        jacobian_off_diag_frobenius_reg: float = 0.0,
    ) -> None:
        super().__init__()
        coeffs = []
        regs = []
        if cnf_estimator == "exact":
            coeffs.append(1.0)
            regs.append(CNFReg(None, noise_dist=None))
        if l1_reg > 0.0:
            coeffs.append(l1_reg)
            regs.append(L1Reg())
        if l2_reg > 0.0:
            coeffs.append(l2_reg)
            regs.append(L2Reg())
        if squared_l2_reg > 0.0:
            coeffs.append(squared_l2_reg)
            regs.append(SquaredL2Reg())
        if jacobian_frobenius_reg > 0.0:
            coeffs.append(jacobian_frobenius_reg)
            regs.append(JacobianFrobeniusReg())
        if jacobian_diag_frobenius_reg > 0.0:
            coeffs.append(jacobian_diag_frobenius_reg)
            regs.append(JacobianDiagFrobeniusReg())
        if jacobian_off_diag_frobenius_reg > 0.0:
            coeffs.append(jacobian_off_diag_frobenius_reg)
            regs.append(JacobianOffDiagFrobeniusReg())

        self.coeffs = torch.tensor(coeffs)
        # nn.ModuleList (not torch.ModuleList, which does not exist) so the
        # regularizers are registered as submodules.
        self.regs = nn.ModuleList(regs)


if __name__ == "__main__":
    # Smoke test: every regularizer should return one scalar per batch element.
    for reg in [
        L1Reg,
        L2Reg,
        SquaredL2Reg,
        JacobianFrobeniusReg,
        JacobianDiagFrobeniusReg,
        JacobianOffDiagFrobeniusReg,
    ]:
        # Fresh context per regularizer so a Jacobian cached by one class is
        # not reused for the next regularizer's new inputs.
        class SharedContext:
            pass

        x = torch.ones(2, 3).requires_grad_(True)
        dx = x * 2
        out = reg().forward(torch.ones(1), x, dx, SharedContext)
        assert out.dim() == 1
        assert out.shape[0] == 2

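    # A minimal usage sketch for AugmentationModule. The module itself defines
    # no forward, so how coefficients and regularizer outputs combine
    # downstream is an assumption here; this shows one plausible weighted sum.
    class Context:
        pass

    aug = AugmentationModule(l1_reg=0.1, l2_reg=0.01, jacobian_frobenius_reg=0.05)
    x = torch.randn(4, 3).requires_grad_(True)
    dx = torch.tanh(x)  # stand-in for a dynamics output dx/dt = f(t, x)
    reg_values = torch.stack(
        [reg(torch.zeros(1), x, dx, Context) for reg in aug.regs], dim=1
    )  # shape (batch, n_regs)
    penalty = (reg_values * aug.coeffs).sum(dim=1).mean()
    assert penalty.dim() == 0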