| import torch | |
| from torch import nn | |
class ICNN(torch.nn.Module):
    """Input Convex Neural Network.

    The network keeps two stacks of linear layers:

    * ``Wzs`` -- the "z-path". ``Wzs[0]`` maps the input to the first hidden
      state (with bias); every later entry maps hidden state to hidden state
      (or to the scalar output) without bias.
    * ``Wxs`` -- the "x-path" skip connections. Each entry maps the raw input
      directly into one of the later layers; all but the last carry a bias.

    With ``a`` the softplus activation, the forward pass computes::

        z_1     = a(Wzs[0](x))
        z_{i+1} = a(Wzs[i](z_i) + Wxs[i-1](x))      for the hidden layers
        out     = Wzs[-1](z_last) + Wxs[-1](x)

    NOTE: nothing in this class constrains the ``Wzs[1:]`` weights to be
    non-negative, which the ICNN construction requires for the output to be
    convex in ``x``. Callers that rely on convexity must enforce that
    themselves (for example by clamping or penalising those weights during
    training).

    Args:
        dim: Size of the last dimension of the input.
        dimh: Width of every hidden layer.
        num_hidden_layers: Number of hidden layers. ``num_hidden_layers - 1``
            hidden-to-hidden layers (and matching input skip layers) are
            created after the initial input layer; values below 1 create
            none, leaving a single hidden layer.
    """

    def __init__(self, dim: int = 2, dimh: int = 64, num_hidden_layers: int = 4) -> None:
        super().__init__()
        num_inner = num_hidden_layers - 1

        # Layers are instantiated in the same order as before (all of Wzs,
        # then all of Wxs) so that seeded initialisation and state_dict keys
        # stay the same.
        self.Wzs = nn.ModuleList(
            [
                nn.Linear(dim, dimh),
                *(nn.Linear(dimh, dimh, bias=False) for _ in range(num_inner)),
                nn.Linear(dimh, 1, bias=False),
            ]
        )
        self.Wxs = nn.ModuleList(
            [
                *(nn.Linear(dim, dimh) for _ in range(num_inner)),
                nn.Linear(dim, 1, bias=False),
            ]
        )
        self.act = nn.Softplus()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the network on ``x``.

        Args:
            x: Tensor whose last dimension has size ``dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        z = self.act(self.Wzs[0](x))
        # Pair each hidden z-path layer with its input skip connection; the
        # final entry of each list is reserved for the output layer below.
        for Wz, Wx in zip(self.Wzs[1:-1], self.Wxs[:-1]):
            z = self.act(Wz(z) + Wx(x))
        return self.Wzs[-1](z) + self.Wxs[-1](x)