| import torch |
|
|
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.nn.functional as F |
| import torch.nn.init as init |
|
|
|
|
class MultiLatentEncoder(nn.Module):
    """Encode a (position, direction, importance) triple into a positional code.

    The three inputs are concatenated along the last axis (7 features total)
    and passed through a SIREN network that outputs ``opt.pos_encode_dim``
    features.
    """

    def __init__(self, opt):
        super(MultiLatentEncoder, self).__init__()

        # Sinusoidal MLP mapping the raw 7-D input to the encoding.
        # NOTE(review): `Siren` is provided elsewhere in the project — confirm
        # its constructor signature matches (dim_in, dim_out).
        self.neuron_input = Siren(
            dim_in = 7,
            dim_out = opt.pos_encode_dim
        )

    def forward(self, pos, direct, imp):
        """Concatenate the inputs on the last axis and run the encoder.

        Returns a tensor whose last dimension is ``opt.pos_encode_dim``.
        """
        input_encoded = torch.concat((pos, direct, imp), -1)
        output = self.neuron_input(input_encoded)
        return output

    def predict(self, pos, direct, imp):
        """Inference-time alias of :meth:`forward`.

        The original duplicated ``forward``'s body verbatim; delegating keeps
        the two code paths from silently drifting apart.
        """
        return self.forward(pos, direct, imp)
|
|
class AutoDecoder(nn.Module):
    """Auto-decoder mapping a (positional encoding, latent code) pair to a
    voxel grid, with an optional vector-quantization lookup (:meth:`Cook`).

    ``opt`` must provide: ``ndf`` (base channel count), ``data_shape``
    (edge length of the "Middle" output cube — assumed divisible by 16,
    TODO confirm with callers), ``pos_encode_dim``, ``z_latent_dim`` and
    ``train_dataset_size`` (number of learnable per-sample latent vectors
    and codebook entries).
    """

    def __init__(self, opt):
        super(AutoDecoder, self).__init__()

        self.ndf = opt.ndf
        self.data_shape = opt.data_shape

        # Edge length of the seed feature volume the decoder starts from.
        seed = self.data_shape // 16

        def block(in_feat, out_feat, normalize=True):
            # One 2x upsampling stage: transposed conv (+ BN) + LeakyReLU.
            layers = [nn.ConvTranspose3d(in_feat, out_feat, 4, 2, 1)]
            if normalize:
                # BUG FIX: the original passed 0.8 positionally, which set
                # BatchNorm's `eps` (numerical-stability constant) instead of
                # the intended running-stats `momentum`.
                layers.append(nn.BatchNorm3d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Project the concatenated (pos_encode, latent) vector into a
        # (ndf*8, seed, seed, seed) feature volume (flattened here).
        self.fc = nn.Sequential(
            nn.Linear(opt.pos_encode_dim + opt.z_latent_dim, self.ndf * 8 * seed ** 3),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Three 2x stages: seed -> 8*seed spatial resolution.
        self.decoder = nn.Sequential(
            *block(self.ndf * 8, self.ndf * 4),
            *block(self.ndf * 4, self.ndf * 2),
            *block(self.ndf * 2, self.ndf)
        )

        # Final 2x stage producing a single-channel grid in [-1, 1]
        # at edge length `data_shape`.
        self.toVoxelMd = nn.Sequential(
            nn.ConvTranspose3d(self.ndf, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

        # Two extra 2x stages: grid at edge length `2 * data_shape`.
        self.toVoxelBig = nn.Sequential(
            *block(self.ndf, self.ndf // 2),
            nn.ConvTranspose3d(self.ndf // 2, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

        # Learnable per-sample latent codes and the VQ codebook.
        self.latent_vectors = nn.Parameter(torch.FloatTensor(opt.train_dataset_size, opt.z_latent_dim))
        self.cookbook = nn.Parameter(torch.FloatTensor(opt.train_dataset_size, opt.pos_encode_dim + opt.z_latent_dim))

        init.xavier_normal_(self.latent_vectors)
        # BUG FIX: `cookbook` was left uninitialized — `torch.FloatTensor`
        # allocates raw memory that may contain NaN/inf garbage. Initialize it
        # the same way as `latent_vectors`.
        init.xavier_normal_(self.cookbook)

    def Cook(self, x, y):
        """Vector-quantize the embedded (x, y) input against the codebook.

        Computes squared Euclidean distances to every codebook row via the
        expanded ||a - b||^2 = ||a||^2 - 2ab + ||b||^2 identity, picks the
        nearest entry per input row, and returns:
        (quantized vectors, selected indices, mean squared distance between
        the input and the detached quantized output — a commitment-style
        term usable as a loss).
        """
        input_x = self.embedding(x, y)
        distances = (
            (input_x ** 2).sum(1, keepdim=True)
            - 2 * input_x @ self.cookbook.transpose(0, 1)
            + (self.cookbook.transpose(0, 1) ** 2).sum(0, keepdim=True)
        )
        encoding_indices = distances.argmin(1)
        output = F.embedding(encoding_indices.view(input_x.shape[0], *input_x.shape[2:]), self.cookbook)
        distance = ((input_x - output.detach()) ** 2).mean()

        return output, encoding_indices, distance

    def embedding(self, x, y):
        """Concatenate positional encoding x and latent code y on the last axis."""
        input_x = torch.concat((x, y), -1)
        return input_x

    def forward(self, x, y, t = "Middle"):
        """Decode to a voxel grid; `t="Middle"` -> data_shape cube, anything
        else -> the 2x-resolution "Big" cube."""
        input_x = self.embedding(x, y)
        if t == "Middle":
            return self.forwardMiddle(input_x)
        else:
            return self.forwardBig(input_x)

    def forwardMiddle(self, input_x):
        """Decode to a (1, 1, data_shape, data_shape, data_shape) grid.

        Note: the hard-coded leading 1 means this path assumes batch size 1.
        """
        seed = self.data_shape // 16
        feature = self.fc(input_x).reshape(1, self.ndf * 8, seed, seed, seed)
        output = self.decoder(feature)
        output = self.toVoxelMd(output)
        output = output.view(1, 1, self.data_shape, self.data_shape, self.data_shape)

        return output

    def forwardBig(self, input_x):
        """Decode to a (1, 1, 2*data_shape, 2*data_shape, 2*data_shape) grid.

        Note: the hard-coded leading 1 means this path assumes batch size 1.
        """
        seed = self.data_shape // 16
        feature = self.fc(input_x).reshape(1, self.ndf * 8, seed, seed, seed)
        output = self.decoder(feature)
        output = self.toVoxelBig(output)
        output = output.view(1, 1, self.data_shape * 2, self.data_shape * 2, self.data_shape * 2)

        return output

    def codes(self):
        """Return the learnable per-sample latent vectors."""
        return self.latent_vectors
|
|