"""
SDFNetwork — 8-layer MLP with positional encoding and geometric initialization.
Lightweight reimplementation without CUDA hash encoders.
"""
import torch
import torch.nn as nn
import numpy as np
from .embedder import get_embedder
class SDFNetwork(nn.Module):
    """MLP mapping points to signed-distance values (plus any extra output channels).

    The network has ``n_layers`` hidden layers of width ``d_hidden`` followed by a
    linear output layer of width ``d_out``. Inputs are multiplied by ``scale``,
    optionally passed through the embedding function returned by ``get_embedder``,
    and re-injected (concatenated, then multiplied by 1/sqrt(2)) before every linear
    layer whose index is listed in ``skip_in``. ReLU follows every layer except the
    last, and the whole output is divided by ``scale``.

    With ``geometric_init`` the weights follow the SAL/IDR-style geometric
    initialization, whose intent is that the untrained network approximates the SDF
    of a sphere of radius ``bias`` (sign flipped when ``inside_outside`` is True).

    Args:
        d_in: Width of the last dimension of the raw input points.
        d_out: Number of output channels.
        d_hidden: Width of each hidden layer.
        n_layers: Number of hidden layers. There are ``n_layers + 1`` linear layers,
            stored as attributes ``lin0`` ... ``lin{n_layers}``.
        skip_in: Indices of linear layers whose input is the previous activation
            concatenated with the (embedded) network input.
        multires: Passed to ``get_embedder``. Values <= 0 disable the embedding.
        bias: Magnitude of the output-layer bias under geometric init.
        scale: Inputs are multiplied by it and outputs divided by it.
        geometric_init: Whether to apply the geometric initialization.
        weight_norm: Whether to wrap each linear layer with weight normalization.
        inside_outside: Flip the sign of the output-layer initialization.
    """

    def __init__(self,
                 d_in=3,
                 d_out=1,
                 d_hidden=256,
                 n_layers=8,
                 skip_in=(4,),
                 multires=8,
                 bias=0.5,
                 scale=1.0,
                 geometric_init=True,
                 weight_norm=True,
                 inside_outside=False):
        super().__init__()

        dims = [d_in] + [d_hidden] * n_layers + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
            self.embed_fn_fine = embed_fn
            dims[0] = input_ch

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.scale = scale

        # Columns of the (embedded) input beyond the raw coordinates. This assumes
        # the embedder places the raw d_in coordinates first in its output, the same
        # assumption the former hard-coded ``3`` made for d_in=3. Confirm against
        # the embedder.
        n_encoded = dims[0] - d_in

        for idx in range(self.num_layers - 1):
            if idx + 1 in self.skip_in:
                # The next layer receives this output concatenated with the network
                # input, so shrink it to keep that layer's input width at dims[idx + 1].
                out_dim = dims[idx + 1] - dims[0]
            else:
                out_dim = dims[idx + 1]

            lin = nn.Linear(dims[idx], out_dim)

            if geometric_init:
                if idx == self.num_layers - 2:
                    # Output layer: near-constant weights and a +/-bias offset.
                    sign = -1.0 if inside_outside else 1.0
                    torch.nn.init.normal_(lin.weight, mean=sign * np.sqrt(np.pi) / np.sqrt(dims[idx]), std=0.0001)
                    torch.nn.init.constant_(lin.bias, -sign * bias)
                elif multires > 0 and idx == 0:
                    # First layer: only the raw-coordinate columns get random weights;
                    # the encoding columns start at zero.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, d_in:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :d_in], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif multires > 0 and idx in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    # Zero the trailing columns fed by the re-injected encoding.
                    # A positive start index keeps the slice empty (rather than the
                    # whole matrix, as ``[:, -0:]`` would be) when n_encoded == 0.
                    torch.nn.init.constant_(lin.weight[:, lin.in_features - n_encoded:], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                # NOTE: the legacy ``nn.utils.weight_norm`` is kept on purpose. It
                # registers ``weight_g``/``weight_v``, so existing checkpoints'
                # state_dict keys stay valid.
                lin = nn.utils.weight_norm(lin)

            setattr(self, f"lin{idx}", lin)

        self.activation = nn.ReLU()

    def forward(self, inputs):
        """Evaluate the network.

        Args:
            inputs: Tensor whose last dimension has size ``d_in``. Leading dimensions
                are treated as batch dimensions by the layers here. Whether the
                embedding function accepts more than 2-D input is not visible
                here (TODO: confirm).

        Returns:
            Tensor with the same leading dimensions and last dimension ``d_out``,
            divided by ``self.scale``.
        """
        inputs = inputs * self.scale
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        x = inputs
        for idx in range(self.num_layers - 1):
            lin = getattr(self, f"lin{idx}")
            if idx in self.skip_in:
                # Re-inject the input along the feature axis. The 1/sqrt(2) factor is
                # presumably there to limit activation growth from the concatenation.
                x = torch.cat([x, inputs], dim=-1) / np.sqrt(2)
            x = lin(x)
            if idx < self.num_layers - 2:
                x = self.activation(x)
        return x / self.scale

    def sdf(self, x):
        """Return ``self.forward(x)``. Note this is all ``d_out`` channels, not only the first."""
        return self.forward(x)

    def gradient(self, x):
        """Differentiate the network output with respect to its input.

        Calls ``x.requires_grad_(True)`` on the caller's tensor. The gradient is
        built with ``create_graph=True`` so it can itself be differentiated. Autograd
        is enabled locally, so this also works when called under ``torch.no_grad()``.

        The upstream gradient is all ones. When ``d_out > 1`` the result is therefore
        the gradient of the sum over all output channels.

        Returns:
            ``(gradients, y)``: ``gradients`` has the same shape as ``x``, and ``y``
            is ``self.sdf(x)``.
        """
        with torch.enable_grad():
            x.requires_grad_(True)
            y = self.sdf(x)
            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True)[0]
        return gradients, y
|