bdck commited on
Commit
9ef936d
·
verified ·
1 Parent(s): 5b1c4a3

Upload sdfnet.py

Browse files
Files changed (1) hide show
  1. sdfnet.py +102 -0
sdfnet.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SDFNetwork — 8-layer MLP with positional encoding and geometric initialization.
3
+ Lightweight reimplementation without CUDA hash encoders.
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import numpy as np
8
+ from .embedder import get_embedder
9
+
10
+
11
+ class SDFNetwork(nn.Module):
12
+ def __init__(self,
13
+ d_in=3,
14
+ d_out=1,
15
+ d_hidden=256,
16
+ n_layers=8,
17
+ skip_in=(4,),
18
+ multires=8,
19
+ bias=0.5,
20
+ scale=1.0,
21
+ geometric_init=True,
22
+ weight_norm=True,
23
+ inside_outside=False):
24
+ super(SDFNetwork, self).__init__()
25
+
26
+ dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
27
+
28
+ self.embed_fn_fine = None
29
+ if multires > 0:
30
+ embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
31
+ self.embed_fn_fine = embed_fn
32
+ dims[0] = input_ch
33
+
34
+ self.num_layers = len(dims)
35
+ self.skip_in = skip_in
36
+ self.scale = scale
37
+
38
+ for l in range(0, self.num_layers - 1):
39
+ if l + 1 in self.skip_in:
40
+ out_dim = dims[l + 1] - dims[0]
41
+ else:
42
+ out_dim = dims[l + 1]
43
+
44
+ lin = nn.Linear(dims[l], out_dim)
45
+
46
+ if geometric_init:
47
+ if l == self.num_layers - 2:
48
+ if not inside_outside:
49
+ torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
50
+ torch.nn.init.constant_(lin.bias, -bias)
51
+ else:
52
+ torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
53
+ torch.nn.init.constant_(lin.bias, bias)
54
+ elif multires > 0 and l == 0:
55
+ torch.nn.init.constant_(lin.bias, 0.0)
56
+ torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
57
+ torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
58
+ elif multires > 0 and l in self.skip_in:
59
+ torch.nn.init.constant_(lin.bias, 0.0)
60
+ torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
61
+ torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
62
+ else:
63
+ torch.nn.init.constant_(lin.bias, 0.0)
64
+ torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
65
+
66
+ if weight_norm:
67
+ lin = nn.utils.weight_norm(lin)
68
+ setattr(self, "lin" + str(l), lin)
69
+
70
+ self.activation = nn.ReLU()
71
+
72
+ def forward(self, inputs):
73
+ inputs = inputs * self.scale
74
+ if self.embed_fn_fine is not None:
75
+ inputs = self.embed_fn_fine(inputs)
76
+
77
+ x = inputs
78
+ for l in range(0, self.num_layers - 1):
79
+ lin = getattr(self, "lin" + str(l))
80
+ if l in self.skip_in:
81
+ x = torch.cat([x, inputs], 1) / np.sqrt(2)
82
+ x = lin(x)
83
+ if l < self.num_layers - 2:
84
+ x = self.activation(x)
85
+
86
+ return x / self.scale
87
+
88
+ def sdf(self, x):
89
+ return self.forward(x)
90
+
91
+ def gradient(self, x):
92
+ x.requires_grad_(True)
93
+ y = self.sdf(x)
94
+ d_output = torch.ones_like(y, requires_grad=False, device=y.device)
95
+ gradients = torch.autograd.grad(
96
+ outputs=y,
97
+ inputs=x,
98
+ grad_outputs=d_output,
99
+ create_graph=True,
100
+ retain_graph=True,
101
+ only_inputs=True)[0]
102
+ return gradients, y