import torch
from torch import nn


class ICNN(nn.Module):
    """Input Convex Neural Network: a scalar function ``f(x)`` convex in ``x``.

    Architecture (``L = num_hidden_layers``)::

        z_1     = softplus(A_0 x + b_0)
        z_{k+1} = softplus(W_k z_k + A_k x + b_k)    for k = 1 .. L-1
        f(x)    = W_L z_L + A_L x

    ``f`` is convex in ``x`` provided every z-path weight ``W_k`` (k >= 1) is
    non-negative: softplus is convex and non-decreasing, so a non-negative
    combination of convex ``z_k`` plus an affine term stays convex.  The raw
    weights live in ordinary ``nn.Linear`` modules (``self.Wzs``), and
    :meth:`forward` clamps ``Wzs[1:]`` at zero, so the convexity guarantee
    holds for every forward pass regardless of what an optimizer wrote into
    them.  The x-path layers ``self.Wxs`` (the ``A_k``, ``b_k``) are
    unconstrained.

    Args:
        dim: Size of the last dimension of the input.
        dimh: Width of every hidden layer.
        num_hidden_layers: Number of softplus hidden layers; must be >= 1.

    Raises:
        ValueError: If ``num_hidden_layers < 1``.
    """

    def __init__(self, dim=2, dimh=64, num_hidden_layers=4):
        super().__init__()
        if num_hidden_layers < 1:
            # Previously 0 (or negative) silently built a 1-hidden-layer net.
            raise ValueError(
                f"num_hidden_layers must be >= 1, got {num_hidden_layers}"
            )

        # z-path. Layer 0 reads x directly (unconstrained, carries b_0); every
        # later layer reads the previous z and is kept non-negative.
        Wzs = [nn.Linear(dim, dimh)]
        Wzs.extend(
            nn.Linear(dimh, dimh, bias=False) for _ in range(num_hidden_layers - 1)
        )
        Wzs.append(nn.Linear(dimh, 1, bias=False))
        self.Wzs = nn.ModuleList(Wzs)

        # x-path ("passthrough") layers: one per hidden layer after the first
        # (these carry the biases b_k), plus the final linear term A_L x.
        Wxs = [nn.Linear(dim, dimh) for _ in range(num_hidden_layers - 1)]
        Wxs.append(nn.Linear(dim, 1, bias=False))
        self.Wxs = nn.ModuleList(Wxs)

        self.act = nn.Softplus()

        # Start from feasible raw weights. forward() clamps anyway, so this
        # does not change the function computed at initialization.
        self.convexify()

    def forward(self, x):
        """Evaluate the convex potential.

        Args:
            x: Tensor whose last dimension has size ``dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        z = self.act(self.Wzs[0](x))
        for Wz, Wx in zip(self.Wzs[1:-1], self.Wxs[:-1]):
            z = self.act(self._nonneg_linear(Wz, z) + Wx(x))
        return self._nonneg_linear(self.Wzs[-1], z) + self.Wxs[-1](x)

    @staticmethod
    def _nonneg_linear(layer, z):
        """Apply ``layer`` to ``z`` using its weight clamped at zero."""
        # Clamping here (instead of trusting the stored weight) is what makes
        # the output provably convex; entries below zero contribute nothing
        # and receive no gradient.
        return nn.functional.linear(z, layer.weight.clamp(min=0), layer.bias)

    @torch.no_grad()
    def convexify(self):
        """Project the stored z-path weights ``Wzs[1:]`` onto ``w >= 0`` in place.

        Because forward() already clamps, this never changes the model's
        outputs.  Calling it after each optimizer step (projected gradient
        descent) resets entries that were pushed below zero back to zero,
        instead of leaving them stuck negative with no gradient.
        """
        for layer in self.Wzs[1:]:
            layer.weight.clamp_(min=0)