from typing import Optional

import torch
from torch import nn


class Regularizer(nn.Module):
    """Bare ``nn.Module`` subclass; it defines no parameters or forward pass."""

    def __init__(self):
        # nn.Module sets up its internal registries in __init__; skipping it
        # leaves the instance unusable (e.g. ``.parameters()`` raises).
        super().__init__()


def _batch_root_mean_squared(tensor: torch.Tensor) -> torch.Tensor:
    """Return the root mean square of each sample in a batch.

    Args:
        tensor: tensor whose first dimension is the batch dimension; all
            remaining dimensions are flattened together.

    Returns:
        A 1-D tensor with one value per batch entry.
    """
    flat = tensor.reshape(tensor.shape[0], -1)
    # ||v||_2 / sqrt(len(v)) == sqrt(mean(v ** 2))
    return torch.norm(flat, p=2, dim=1) / flat.shape[1] ** 0.5


class RegularizationFunc(nn.Module):
    """Interface for regularizers evaluated as ``forward(t, x, dx, context)``."""

    def forward(self, t, x, dx, context) -> torch.Tensor:
        """Outputs a batch of scalar regularizations."""
        raise NotImplementedError


class L1Reg(RegularizationFunc):
    """Mean absolute value of ``dx`` for each sample."""

    def forward(self, t, x, dx, context) -> torch.Tensor:
        # Flatten per sample so that inputs with more than two dimensions
        # still yield one scalar per batch entry, like the other regularizers.
        flat = dx.reshape(dx.shape[0], -1)
        return torch.mean(torch.abs(flat), dim=1)


class L2Reg(RegularizationFunc):
    """Root mean square of ``dx`` for each sample."""

    def forward(self, t, x, dx, context) -> torch.Tensor:
        return _batch_root_mean_squared(dx)


class SquaredL2Reg(RegularizationFunc):
    """Squared Euclidean norm of ``dx`` for each sample (a sum, not a mean)."""

    def forward(self, t, x, dx, context) -> torch.Tensor:
        flat = dx.reshape(dx.shape[0], -1)
        return torch.pow(torch.norm(flat, p=2, dim=1), 2)


def _get_minibatch_jacobian(y, x, create_graph=True):
    """Computes the Jacobian of y wrt x assuming minibatch-mode.

    One ``torch.autograd.grad`` call is made per flattened output component.
    Each call differentiates the batch sum of ``y[:, j]``, so the result is a
    per-sample Jacobian only if sample ``n`` of ``y`` depends solely on
    sample ``n`` of ``x``.

    Args:
        y: (N, ...) with a total of D_y elements in ...
        x: (N, ...) with a total of D_x elements in ...
        create_graph: forwarded to ``torch.autograd.grad``; keep ``True`` when
            the Jacobian itself must be differentiable.

    Returns:
        A tensor of shape ``(*x.shape, D_y)``, i.e. ``(N, D_x, D_y)`` for a
        2-D ``x``; entry ``[n, i, j]`` holds ``d y[n, j] / d x[n, i]``.
    """
    y = y.reshape(y.shape[0], -1)
    # One backward pass per output component; the component index becomes the
    # trailing axis of the result.
    columns = []
    for j in range(y.shape[1]):
        y_j = y[:, j]
        (dy_j_dx,) = torch.autograd.grad(
            y_j,
            x,
            torch.ones_like(y_j),
            retain_graph=True,
            create_graph=create_graph,
        )
        columns.append(torch.unsqueeze(dy_j_dx, -1))
    return torch.cat(columns, -1)


def _cached_minibatch_jacobian(x, dx, context):
    """Return ``context.jac``, computing and storing it on first use."""
    if not hasattr(context, "jac"):
        context.jac = _get_minibatch_jacobian(dx, x)
    return context.jac


def _jacobian_diagonal(jac: torch.Tensor) -> torch.Tensor:
    """Return the main diagonal of each Jacobian in the batch, shape (N, M)."""
    # Collapse any extra input dimensions so the tensor is (N, D_x, D_y).
    per_sample = jac.reshape(jac.shape[0], -1, jac.shape[-1])
    return torch.diagonal(per_sample, dim1=1, dim2=2)


class JacobianFrobeniusReg(RegularizationFunc):
    """Root mean square over all entries of the per-sample Jacobian."""

    def forward(self, t, x, dx, context) -> torch.Tensor:
        # Reuse a Jacobian already stored on the context rather than paying
        # for another full set of autograd calls.
        jac = _cached_minibatch_jacobian(x, dx, context)
        return _batch_root_mean_squared(jac)


class JacobianDiagFrobeniusReg(RegularizationFunc):
    """Root mean square of the diagonal entries of the per-sample Jacobian."""

    def forward(self, t, x, dx, context) -> torch.Tensor:
        jac = _cached_minibatch_jacobian(x, dx, context)
        # assumes jac is minibatch square, ie. (N, M, M).
        diagonal = _jacobian_diagonal(jac)
        return _batch_root_mean_squared(diagonal)


class JacobianOffDiagFrobeniusReg(RegularizationFunc):
    """Mean squared off-diagonal entry of the per-sample Jacobian.

    Unlike the other Jacobian regularizers, no square root is taken. For a
    1x1 Jacobian the denominator ``M * (M - 1)`` is zero.
    """

    def forward(self, t, x, dx, context) -> torch.Tensor:
        jac = _cached_minibatch_jacobian(x, dx, context)
        # assumes jac is minibatch square, ie. (N, M, M).
        diagonal = _jacobian_diagonal(jac)
        ss_total = torch.sum(jac.reshape(jac.shape[0], -1) ** 2, dim=1)
        ss_offdiag = ss_total - torch.sum(diagonal**2, dim=1)
        size = diagonal.shape[1]
        ms_offdiag = ss_offdiag / (size * (size - 1))
        return ms_offdiag


def autograd_trace(x_out, x_in, **kwargs):
    """Standard brute-force means of obtaining trace of the Jacobian, O(d) calls to autograd.

    Args:
        x_out: output tensor, indexed as ``x_out[:, i]``.
        x_in: input tensor that ``x_out`` was computed from; its second
            dimension sets the number of autograd calls.
        **kwargs: accepted and ignored (callers pass e.g. ``noise=``).

    Returns:
        The per-sample trace; the float ``0.0`` if ``x_in.shape[1]`` is zero.
    """
    trJ = 0.0
    for i in range(x_in.shape[1]):
        (grad_i,) = torch.autograd.grad(
            x_out[:, i].sum(), x_in, allow_unused=False, create_graph=True
        )
        # Only the i-th column contributes to the trace.
        trJ += grad_i[:, i]
    return trJ


class CNFReg(RegularizationFunc):
    """Negative Jacobian trace of ``dx`` with respect to ``x``.

    Args:
        trace_estimator: callable invoked as ``trace_estimator(dx, x, noise=...)``;
            defaults to ``autograd_trace``.
        noise_dist: stored on the instance; not used by this class itself.
    """

    def __init__(self, trace_estimator=None, noise_dist=None):
        super().__init__()
        self.trace_estimator = trace_estimator if trace_estimator is not None else autograd_trace
        self.noise_dist, self.noise = noise_dist, None

    def forward(self, t, x, dx, context) -> torch.Tensor:
        # TODO we could check if jac is in the context to speed up
        return -self.trace_estimator(dx, x, noise=self.noise)


class AugmentationModule(nn.Module):
    """Class orchestrating augmentations.

    Also establishes order: the exact CNF term (if enabled) comes first,
    followed by the enabled regularizers in the order of the constructor
    arguments. ``self.coeffs[i]`` is the weight belonging to ``self.regs[i]``.

    Args:
        cnf_estimator: pass ``"exact"`` to add a ``CNFReg`` term with weight 1;
            any other value adds none.
        l1_reg, l2_reg, squared_l2_reg, jacobian_frobenius_reg,
        jacobian_diag_frobenius_reg, jacobian_off_diag_frobenius_reg:
            weight of the corresponding regularizer; it is only enabled when
            the weight is strictly positive.
    """

    def __init__(
        self,
        cnf_estimator: Optional[str] = None,
        l1_reg: float = 0.0,
        l2_reg: float = 0.0,
        squared_l2_reg: float = 0.0,
        jacobian_frobenius_reg: float = 0.0,
        jacobian_diag_frobenius_reg: float = 0.0,
        jacobian_off_diag_frobenius_reg: float = 0.0,
    ) -> None:
        super().__init__()
        coeffs = []
        regs = []
        if cnf_estimator == "exact":
            coeffs.append(1)
            regs.append(CNFReg(None, noise_dist=None))

        # (weight, regularizer class) pairs; list order fixes the output order.
        weighted_regs = [
            (l1_reg, L1Reg),
            (l2_reg, L2Reg),
            (squared_l2_reg, SquaredL2Reg),
            (jacobian_frobenius_reg, JacobianFrobeniusReg),
            (jacobian_diag_frobenius_reg, JacobianDiagFrobeniusReg),
            (jacobian_off_diag_frobenius_reg, JacobianOffDiagFrobeniusReg),
        ]
        for weight, reg_cls in weighted_regs:
            if weight > 0.0:
                coeffs.append(weight)
                regs.append(reg_cls())

        # NOTE(review): a plain tensor attribute, not a registered buffer, so
        # it does not follow ``.to(device)`` and is absent from state_dict.
        self.coeffs = torch.tensor(coeffs)
        # ModuleList lives in torch.nn; registering the regularizers here
        # makes them proper sub-modules.
        self.regs = nn.ModuleList(regs)


if __name__ == "__main__":
    # Test Shapes

    class SharedContext:
        pass

    for reg in [
        L1Reg,
        L2Reg,
        SquaredL2Reg,
        JacobianFrobeniusReg,
        JacobianDiagFrobeniusReg,
        JacobianOffDiagFrobeniusReg,
    ]:
        x = torch.ones(2, 3).requires_grad_(True)
        dx = x * 2
        # Fresh context per regularizer so a Jacobian cached for a previous
        # ``x`` is never reused.
        out = reg().forward(torch.ones(1), x, dx, SharedContext())
        assert out.dim() == 1
        assert out.shape[0] == 2