xiangzai's picture
Add files using upload-large-folder tool
3e4f775 verified
from typing import List, Optional
import torch
from torch import nn
# Lookup table from the short activation name accepted by the network
# constructors to the ``torch.nn`` module class that is instantiated for it.
# Values are classes, not instances: each use site calls the class to get a
# fresh module.
ACTIVATION_MAP = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "selu": nn.SELU,
    "elu": nn.ELU,
    "lrelu": nn.LeakyReLU,
    "softplus": nn.Softplus,
    "silu": nn.SiLU,
}
class SimpleDenseNet(nn.Module):
    """Plain fully connected network.

    Every hidden layer is ``nn.Linear`` followed by an optional
    ``nn.BatchNorm1d`` and then the chosen activation. A final ``nn.Linear``
    maps to ``target_size`` with no normalisation or activation after it.

    Args:
        input_size: Number of input features of the first linear layer.
        target_size: Number of output features of the last linear layer.
        activation: Key into ``ACTIVATION_MAP`` selecting the hidden-layer
            activation; the lookup happens once per hidden layer.
        batch_norm: When true, add ``nn.BatchNorm1d`` after each hidden
            linear layer.
        hidden_dims: Widths of the hidden layers. ``None`` selects
            ``[256, 256, 256]``; an empty sequence yields a single linear map.
    """

    def __init__(
        self,
        input_size: int,
        target_size: int,
        activation: str,
        batch_norm: bool = True,
        hidden_dims: Optional[List[int]] = None,
    ):
        super().__init__()
        widths = [256, 256, 256] if hidden_dims is None else hidden_dims

        modules = []
        fan_in = input_size  # width feeding the next linear layer
        for fan_out in widths:
            modules.append(nn.Linear(fan_in, fan_out))
            if batch_norm:
                modules.append(nn.BatchNorm1d(fan_out))
            modules.append(ACTIVATION_MAP[activation]())
            fan_in = fan_out

        # Output projection: deliberately left without norm/activation.
        modules.append(nn.Linear(fan_in, target_size))
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        return self.model(x)
class DivergenceFreeNet(SimpleDenseNet):
    """Vector field obtained as the gradient of a learned scalar potential.

    The underlying MLP takes the concatenation ``[t, x]`` (``dim + 1``
    features) and produces one scalar per sample. ``forward`` differentiates
    that scalar with respect to the ``x`` features.

    NOTE(review): the gradient of a scalar potential is a conservative field,
    which is not generally divergence free; the class name is kept unchanged
    for backward compatibility.
    """

    def __init__(self, dim: int, *args, **kwargs):
        """Create the potential network.

        Args:
            dim: Number of features in the state ``x``.
            *args: Extra positional arguments for ``SimpleDenseNet``; they
                fill the parameters that follow ``target_size``.
            **kwargs: Extra keyword arguments for ``SimpleDenseNet``.
        """
        # The sizes are passed positionally: passing them by keyword ahead of
        # ``*args`` made any positional extra argument collide with
        # ``input_size`` and raise ``TypeError``.
        super().__init__(dim + 1, 1, *args, **kwargs)

    def energy(self, x):
        """Return the scalar potential.

        ``x`` is fed to the network as is, so it must already contain the
        time column followed by the state (``dim + 1`` features in total).
        """
        return self.model(x)

    def forward(self, t, x, *args, **kwargs):
        """Return the gradient of the potential with respect to ``x``.

        Args:
            t: Time tensor. A scalar or single-element 1-D tensor is repeated
                for every row of ``x``; a 1-D tensor with one entry per row
                is used as a column; a 2-D column is used unchanged.
            x: State tensor with one sample per row.
            *args, **kwargs: Accepted and ignored.

        Returns:
            Tensor with the same number of columns as ``x`` holding the
            derivative of the summed potential w.r.t. each state feature.
        """
        # Same broadcasting rules as VelocityNet.forward: the previous
        # ``t.repeat(batch)`` turned a per-sample 1-D ``t`` into batch**2 rows.
        if t.dim() < 1 or t.shape[0] != x.shape[0]:
            t = t.repeat(x.shape[0])[:, None]
        if t.dim() < 2:
            t = t[:, None]

        # Force gradient tracking so the field can also be evaluated by a
        # caller that runs inside ``torch.no_grad()``.
        with torch.enable_grad():
            inputs = torch.cat([t, x], dim=-1).requires_grad_(True)
            potential = torch.sum(self.model(inputs))
            grad = torch.autograd.grad(potential, inputs, create_graph=True)[0]

        # Column 0 is dE/dt because ``t`` is concatenated first. The previous
        # ``[:, :-1]`` slice kept it and dropped the last state derivative.
        return grad[:, 1:]
class TimeInvariantVelocityNet(SimpleDenseNet):
    """Velocity field that is a function of the state only.

    The network maps ``dim`` input features to ``dim`` output features; the
    time argument of ``forward`` is accepted but never used.
    """

    def __init__(self, dim: int, *args, **kwargs):
        """Create the network.

        Args:
            dim: Number of features in the state ``x`` (also the output size).
            *args: Extra positional arguments for ``SimpleDenseNet``; they
                fill the parameters that follow ``target_size``.
            **kwargs: Extra keyword arguments for ``SimpleDenseNet``.
        """
        # Positional sizes keep ``*args`` usable: with keyword sizes the first
        # extra positional argument was bound to ``input_size`` a second time
        # and raised ``TypeError``.
        super().__init__(dim, dim, *args, **kwargs)

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the network on ``x``; ``t`` and any extras are ignored."""
        del t
        return self.model(x)
class VelocityNet(SimpleDenseNet):
    """Time-dependent velocity field.

    The network receives the concatenation ``[t, x]`` (``dim + 1`` features)
    and returns ``dim`` output features.
    """

    def __init__(self, dim: int, *args, **kwargs):
        """Create the network.

        Args:
            dim: Number of features in the state ``x`` (also the output size).
            *args: Extra positional arguments for ``SimpleDenseNet``; they
                fill the parameters that follow ``target_size``.
            **kwargs: Extra keyword arguments for ``SimpleDenseNet``.
        """
        # Positional sizes keep ``*args`` usable: with keyword sizes the first
        # extra positional argument was bound to ``input_size`` a second time
        # and raised ``TypeError``.
        super().__init__(dim + 1, dim, *args, **kwargs)

    def forward(self, t, x, *args, **kwargs):
        """Prepend ``t`` to ``x`` as a column and run the network.

        Args:
            t: Time tensor. If it is 0-D, or its first dimension differs from
                the number of rows of ``x``, it is repeated once per row; a
                1-D tensor with one entry per row is turned into a column.
            x: State tensor with one sample per row.
            *args, **kwargs: Accepted and ignored.

        Returns:
            The network output for the concatenated ``[t, x]`` input.
        """
        batch = x.shape[0]
        if t.dim() < 1 or t.shape[0] != batch:
            # Shared time value: replicate it for every sample.
            t = t.repeat(batch)[:, None]
        if t.dim() < 2:
            # Per-sample times given as a flat vector: make them a column.
            t = t[:, None]
        x = torch.cat([t, x], dim=-1)
        return self.model(x)
if __name__ == "__main__":
    # Construction smoke test. Both constructors have required arguments
    # (sizes and an activation name), so calling them with none raised
    # TypeError; small explicit values are supplied instead.
    _ = SimpleDenseNet(input_size=2, target_size=2, activation="relu")
    _ = TimeInvariantVelocityNet(dim=2, activation="relu")