| from typing import List, Optional | |
| import torch | |
| from torch import nn | |
# Lookup table from an activation name to the torch.nn module class that
# implements it. The values are classes, not instances, so a fresh module can
# be instantiated for every layer that needs one.
ACTIVATION_MAP = dict(
    relu=nn.ReLU,
    sigmoid=nn.Sigmoid,
    tanh=nn.Tanh,
    selu=nn.SELU,
    elu=nn.ELU,
    lrelu=nn.LeakyReLU,
    softplus=nn.Softplus,
    silu=nn.SiLU,
)
class SimpleDenseNet(nn.Module):
    """Fully connected feed-forward network assembled as an ``nn.Sequential``.

    Every hidden layer is ``Linear`` -> optional ``BatchNorm1d`` -> activation;
    the output layer is a bare ``Linear`` with no normalisation or activation.

    Args:
        input_size: Number of input features.
        target_size: Number of output features.
        activation: Key into ``ACTIVATION_MAP`` naming the hidden activation.
        batch_norm: If true, add ``BatchNorm1d`` after each hidden linear layer.
        hidden_dims: Widths of the hidden layers; ``None`` selects three
            hidden layers of width 256.
    """

    def __init__(
        self,
        input_size: int,
        target_size: int,
        activation: str,
        batch_norm: bool = True,
        hidden_dims: Optional[List[int]] = None,
    ):
        super().__init__()
        widths = [256, 256, 256] if hidden_dims is None else hidden_dims
        sizes = [input_size, *widths, target_size]

        stack = []
        # Pair each hidden layer's input width with its output width; the
        # final (output) projection is appended separately below.
        for fan_in, fan_out in zip(sizes[:-2], sizes[1:-1]):
            stack.append(nn.Linear(fan_in, fan_out))
            if batch_norm:
                stack.append(nn.BatchNorm1d(fan_out))
            # Looked up per layer so that each one gets its own module instance.
            stack.append(ACTIVATION_MAP[activation]())
        stack.append(nn.Linear(sizes[-2], sizes[-1]))

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the output of the sequential stack applied to ``x``."""
        return self.model(x)
class DivergenceFreeNet(SimpleDenseNet):
    """Implements a divergence free network as the gradient of a scalar potential function.

    The underlying MLP takes the concatenation ``[t, x]`` (``dim + 1``
    features) and produces one scalar per row; ``forward`` returns the
    gradient of that scalar with respect to the ``x`` columns.

    Args:
        dim: Number of features in ``x``; one extra input column is reserved
            for the time value.
        *args: Forwarded to ``SimpleDenseNet`` after the two sizes.
        **kwargs: Forwarded to ``SimpleDenseNet``.
    """

    def __init__(self, dim: int, *args, **kwargs):
        # The sizes go first positionally: mixing ``input_size=...`` keywords
        # with ``*args`` would bind the first extra positional argument (for
        # example the activation name) to ``input_size`` and raise TypeError.
        super().__init__(dim + 1, 1, *args, **kwargs)

    def energy(self, x):
        """Return the scalar potential for an already concatenated ``[t, x]`` input."""
        return self.model(x)

    def forward(self, t, x, *args, **kwargs):
        """Return the gradient of the potential with respect to ``x`` at time ``t``.

        Args:
            t: Time tensor. A 0-d or one-element value is broadcast over the
                batch; a 1-d tensor of batch length is used row by row; a
                2-d tensor is used as given.
            x: 2-d tensor whose first dimension is the batch.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            Tensor with the same shape as ``x`` holding the derivatives of the
            summed potential with respect to each ``x`` entry. The result
            stays attached to the autograd graph (``create_graph=True``).
        """
        if t.dim() < 2:
            # Only replicate ``t`` when it is not already one value per row;
            # repeating a batch-length vector would produce batch**2 entries.
            if t.dim() == 0 or t.shape[0] != x.shape[0]:
                t = t.repeat(x.shape[0])
            t = t[:, None]
        tx = torch.cat([t, x], dim=-1)
        tx = tx.requires_grad_(True)
        grad = torch.autograd.grad(torch.sum(self.model(tx)), tx, create_graph=True)[0]
        # Column 0 is the derivative with respect to t (t was concatenated
        # first), so drop it and keep the derivatives with respect to x.
        return grad[:, 1:]
class TimeInvariantVelocityNet(SimpleDenseNet):
    """Dense network mapping ``x`` to an output of the same width, ignoring ``t``.

    Args:
        dim: Width of both the input and the output.
        *args: Forwarded to ``SimpleDenseNet`` after the two sizes.
        **kwargs: Forwarded to ``SimpleDenseNet``.
    """

    def __init__(self, dim: int, *args, **kwargs):
        # The sizes go first positionally: mixing ``input_size=...`` keywords
        # with ``*args`` would bind the first extra positional argument (for
        # example the activation name) to ``input_size`` and raise TypeError.
        super().__init__(dim, dim, *args, **kwargs)

    def forward(self, t, x, *args, **kwargs):
        """Ignore ``t`` and run the model on ``x`` alone."""
        del t  # accepted only so the call signature matches the time-dependent nets
        return self.model(x)
class VelocityNet(SimpleDenseNet):
    """Dense network that takes ``[t, x]`` and returns an output as wide as ``x``.

    Args:
        dim: Number of features in ``x``; one extra input column is reserved
            for the time value.
        *args: Forwarded to ``SimpleDenseNet`` after the two sizes.
        **kwargs: Forwarded to ``SimpleDenseNet``.
    """

    def __init__(self, dim: int, *args, **kwargs):
        # The sizes go first positionally: mixing ``input_size=...`` keywords
        # with ``*args`` would bind the first extra positional argument (for
        # example the activation name) to ``input_size`` and raise TypeError.
        super().__init__(dim + 1, dim, *args, **kwargs)

    def forward(self, t, x, *args, **kwargs):
        """Prepend ``t`` to ``x`` as a time column and run the model.

        Args:
            t: Time tensor. A 0-d value, or a tensor whose first dimension
                differs from the batch size, is repeated once per row; a 1-d
                tensor of batch length is reshaped into a column.
            x: Tensor whose first dimension is the batch.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The model output for the concatenated ``[t, x]`` input.
        """
        # Broadcast a single time value over the batch.
        if t.dim() < 1 or t.shape[0] != x.shape[0]:
            t = t.repeat(x.shape[0])[:, None]
        # A per-row 1-d time vector still needs a trailing feature axis.
        if t.dim() < 2:
            t = t[:, None]
        x = torch.cat([t, x], dim=-1)
        return self.model(x)
if __name__ == "__main__":
    # Smoke test: the constructors have required parameters, so calling them
    # with no arguments raised TypeError. Build each network with small,
    # explicit sizes instead.
    _ = SimpleDenseNet(input_size=2, target_size=2, activation="relu")
    _ = TimeInvariantVelocityNet(dim=2, activation="relu")