"""
VGNetwork — Vertex Generator Network (MLP-only, no PointTransformerV3).
Inputs: sample points + normals. Outputs: 3D displacement.
"""
import torch
import torch.nn as nn
import numpy as np

from .embedder import get_embedder


class VGNetwork(nn.Module):
    """MLP that predicts a displacement for each query point from its position and normal.

    The (optionally positionally-encoded) points are concatenated with their normals.
    The result is passed through a stack of Linear layers with ReLU between them.
    At the layers listed in ``skip_in``, that network input is re-concatenated onto
    the hidden features. The final layer's output is scaled back and added to the
    original points.
    """

    def __init__(self, d_in=3, d_out=3, d_hidden=256, n_layers=8, skip_in=(4,),
                 multires=8, scale=1.0, geometric_init=True, weight_norm=True):
        """Build the MLP.

        Args:
            d_in: Number of channels in ``samples`` (passed to the embedder as
                ``input_dims``).
            d_out: Number of channels produced by the last layer. ``forward`` adds
                this to ``samples``, so it must broadcast against them (normally
                ``d_out == d_in``).
            d_hidden: Width of every hidden layer.
            n_layers: Number of hidden layers. There are ``n_layers + 1`` Linear
                layers in total.
            skip_in: Indices of Linear layers whose input gets the network input
                concatenated onto it.
            multires: Passed to ``get_embedder``. Values ``<= 0`` disable the
                embedding and feed raw points.
            scale: Multiplier applied to ``samples`` before the network. The
                predicted offset is divided by it again before being added back.
            geometric_init: If True, zero the biases and draw weights from
                N(0, sqrt(2 / out_dim)), zeroing selected input columns (see below).
            weight_norm: If True, wrap every Linear layer with
                ``nn.utils.weight_norm``.
        """
        super().__init__()

        dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
            self.embed_fn_fine = embed_fn
            # Embedded point features followed by the 3 normal channels.
            dims[0] = input_ch + 3
        else:
            dims[0] += 3  # raw point channels followed by the 3 normal channels

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        for layer in range(0, self.num_layers - 1):
            # The layer feeding a skip layer emits fewer features, so that after
            # concatenating the dims[0]-wide input the width is dims[layer + 1] again.
            if layer + 1 in self.skip_in:
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]

            lin = nn.Linear(dims[layer], out_dim)

            if geometric_init:
                std = np.sqrt(2) / np.sqrt(out_dim)
                if multires > 0 and layer == 0:
                    # Only the first 3 input columns start with nonzero weights.
                    # The remaining encoded features and the normals start at zero.
                    # NOTE(review): assumes the embedder keeps raw xyz in its first
                    # 3 output channels — confirm against get_embedder.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, std)
                elif multires > 0 and layer in self.skip_in:
                    # The re-injected input occupies the last dims[0] columns.
                    # All of it except its first 3 channels starts at zero.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, std)
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    # NOTE(review): this branch also initializes the output layer,
                    # so the initial displacement is not near zero. Confirm that
                    # this is intended.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, std)

            if weight_norm:
                # NOTE(review): nn.utils.weight_norm is deprecated in recent PyTorch
                # in favour of nn.utils.parametrizations.weight_norm. Switching would
                # change state_dict keys and break existing checkpoints.
                lin = nn.utils.weight_norm(lin)

            setattr(self, f"lin{layer}", lin)

        self.activation = nn.ReLU()

    def forward(self, samples, normals):
        """
        Args:
            samples: (..., d_in) query points, e.g. (B, 3)
            normals: (..., 3) estimated normals at samples, same leading dims
        Returns:
            moving_pcd: (..., d_out) displaced points = samples + delta
        """
        inputs = samples * self.scale
        if self.embed_fn_fine is not None:
            # NOTE(review): assumes embed_fn_fine preserves leading dims.
            inputs = self.embed_fn_fine(inputs)
        inputs = torch.cat((inputs, normals), dim=-1)

        x = inputs
        for layer in range(0, self.num_layers - 1):
            lin = getattr(self, f"lin{layer}")

            if layer in self.skip_in:
                # Concatenate on the feature (last) axis, matching the input
                # concatenation above, so any leading batch shape works. The
                # division by sqrt(2) presumably keeps the activation scale
                # comparable after concatenation.
                x = torch.cat([x, inputs], dim=-1) / np.sqrt(2)

            x = lin(x)

            # No activation after the final (output) layer.
            if layer < self.num_layers - 2:
                x = self.activation(x)

        # Undo the input scaling so the offset is in the units of `samples`.
        moving_pcd = samples + x / self.scale
        return moving_pcd