| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def mlp(
    input_dim,
    hidden_dim,
    output_dim,
    hidden_depth,
    output_mod=None,
    batchnorm=False,
    activation=nn.ReLU,
):
    """Build a fully connected network as an ``nn.Sequential``.

    With ``hidden_depth == 0`` the result is a single
    ``nn.Linear(input_dim, output_dim)``. Otherwise it is a stack of hidden
    blocks, each ``Linear -> [BatchNorm1d] -> activation``, followed by a
    final ``nn.Linear(hidden_dim, output_dim)``.

    Args:
        input_dim: Number of input features of the first linear layer.
        hidden_dim: Number of features of every hidden layer.
        output_dim: Number of output features of the last linear layer.
        hidden_depth: Number of hidden blocks; ``0`` means no hidden layer.
        output_mod: Optional module appended after the last linear layer.
        batchnorm: If truthy, insert ``nn.BatchNorm1d(hidden_dim)`` between
            each hidden linear layer and its activation.
        activation: Callable returning an activation module. It is called
            with ``inplace=True`` when it accepts that argument and with no
            arguments otherwise.

    Returns:
        The assembled ``nn.Sequential``.
    """

    def make_activation():
        # Many activations (nn.Tanh, nn.Sigmoid, nn.GELU, ...) have no
        # ``inplace`` argument; fall back to the default constructor instead
        # of failing with a TypeError.
        try:
            return activation(inplace=True)
        except TypeError:
            return activation()

    # Negative depths keep their historical behaviour of producing a single
    # hidden block; only an exact 0 skips the hidden layers entirely.
    num_hidden = 0 if hidden_depth == 0 else max(hidden_depth, 1)

    layers = []
    in_features = input_dim
    for _ in range(num_hidden):
        layers.append(nn.Linear(in_features, hidden_dim))
        if batchnorm:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(make_activation())
        in_features = hidden_dim

    layers.append(nn.Linear(in_features, output_dim))
    if output_mod is not None:
        layers.append(output_mod)
    return nn.Sequential(*layers)
|
|
|
|
def weight_init(m):
    """Initialise an ``nn.Linear`` module in place; ignore anything else.

    The weight is filled with an orthogonal matrix and the bias, when the
    layer has one, is set to zero. Modules that are not ``nn.Linear``
    instances are left untouched. Intended to be passed to
    ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return

    nn.init.orthogonal_(m.weight.data)

    # Layers built with bias=False expose ``bias`` as None, which has no
    # ``data`` attribute to reset.
    bias_values = getattr(m.bias, "data", None)
    if bias_values is not None:
        bias_values.fill_(0.0)
|
|
|
|
class MLP(nn.Module):
    """``nn.Module`` wrapper around the network produced by ``mlp``.

    The constructor arguments are forwarded unchanged to ``mlp``; the
    resulting network is stored as ``self.trunk`` and ``weight_init`` is then
    applied to this module and all of its submodules.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        hidden_depth,
        output_mod=None,
        batchnorm=False,
        activation=nn.ReLU,
    ):
        super().__init__()
        trunk_options = {
            "output_mod": output_mod,
            "batchnorm": batchnorm,
            "activation": activation,
        }
        self.trunk = mlp(
            input_dim, hidden_dim, output_dim, hidden_depth, **trunk_options
        )
        # apply() visits every submodule, so this also reaches any linear
        # layers contained in ``output_mod``.
        self.apply(weight_init)

    def forward(self, x):
        """Return the output of the trunk network for input ``x``."""
        out = self.trunk(x)
        return out
|
|