# xiangzai's picture
# Add files using upload-large-folder tool
# 3e4f775 verified
import torch
from torch import nn
class ICNN(nn.Module):
    """Input Convex Neural Network.

    The network keeps two parallel stacks of linear layers:

    * ``Wzs`` act on the running hidden state ``z`` (the first one maps the
      raw input into the hidden width and is the only ``Wzs`` layer with a
      bias).
    * ``Wxs`` are skip connections that re-inject the raw input ``x`` at
      every layer after the first.

    NOTE(review): this module does not itself constrain the ``Wzs`` weights
    (no clamping or reparameterisation happens here); if non-negativity is
    required for convexity it must be enforced by the caller -- confirm
    against the training code.

    Args:
        dim: Size of the last dimension of the input.
        dimh: Width of every hidden layer.
        num_hidden_layers: Number of hidden activations.  The first one is
            produced by ``Wzs[0]``; each of the remaining
            ``num_hidden_layers - 1`` uses one ``Wzs``/``Wxs`` pair.  Values
            of 1 or less create no extra pairs.
    """

    def __init__(self, dim: int = 2, dimh: int = 64, num_hidden_layers: int = 4):
        super().__init__()

        # Layers after the first each get a hidden-to-hidden map and an
        # input skip connection.  range() of a non-positive count is empty.
        num_skip_layers = num_hidden_layers - 1

        # Construction order (all Wzs first, then all Wxs) is kept stable so
        # that seeded initialisation and state_dict keys do not change.
        self.Wzs = nn.ModuleList(
            [
                nn.Linear(dim, dimh),
                *(
                    nn.Linear(dimh, dimh, bias=False)
                    for _ in range(num_skip_layers)
                ),
                nn.Linear(dimh, 1, bias=False),
            ]
        )
        self.Wxs = nn.ModuleList(
            [
                # The bias for each hidden layer lives on the skip branch.
                *(nn.Linear(dim, dimh) for _ in range(num_skip_layers)),
                nn.Linear(dim, 1, bias=False),
            ]
        )
        self.act = nn.Softplus()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the network on ``x``.

        Args:
            x: Tensor whose last dimension has size ``dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.  No activation is applied to the output.
        """
        z = self.act(self.Wzs[0](x))
        # Pair every interior Wz with its matching input skip connection.
        for Wz, Wx in zip(self.Wzs[1:-1], self.Wxs[:-1]):
            z = self.act(Wz(z) + Wx(x))
        # Final scalar head: hidden-state projection plus linear skip term.
        return self.Wzs[-1](z) + self.Wxs[-1](x)