bdck commited on
Commit
70d6e03
·
verified ·
1 Parent(s): 9051cce

Upload vgnet.py

Browse files
Files changed (1) hide show
  1. vgnet.py +88 -0
vgnet.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VGNetwork — Vertex Generator Network (MLP-only, no PointTransformerV3).
3
+ Inputs: sample points + normals. Outputs: 3D displacement.
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+ import numpy as np
8
+ from .embedder import get_embedder
9
+
10
+
11
class VGNetwork(nn.Module):
    """MLP that maps sample points + normals to a 3D displacement.

    NeuS/IDR-style network: optional positional encoding of the input
    coordinates, a stack of (optionally weight-normed) linear layers with
    a skip connection, and a geometric weight-initialization scheme.
    The forward pass returns the input samples shifted by the predicted
    per-point offset.
    """

    def __init__(self,
                 d_in=3,
                 d_out=3,
                 d_hidden=256,
                 n_layers=8,
                 skip_in=(4,),
                 multires=8,
                 scale=1.0,
                 geometric_init=True,
                 weight_norm=True):
        super().__init__()

        layer_dims = [d_in] + [d_hidden] * n_layers + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            embed_fn, embed_dim = get_embedder(multires, input_dims=d_in)
            self.embed_fn_fine = embed_fn
            # first layer sees the positional encoding plus the 3 raw
            # normal components
            layer_dims[0] = embed_dim + 3
        else:
            # no encoding: raw xyz plus the 3 normal components
            layer_dims[0] += 3

        self.num_layers = len(layer_dims)
        self.skip_in = skip_in
        self.scale = scale

        for idx in range(self.num_layers - 1):
            # The layer feeding a skip connection emits fewer channels so
            # that, after the inputs are concatenated back in, the next
            # layer's input width is exactly layer_dims[idx + 1].
            out_dim = layer_dims[idx + 1]
            if idx + 1 in self.skip_in:
                out_dim -= layer_dims[0]

            layer = nn.Linear(layer_dims[idx], out_dim)

            if geometric_init:
                std = np.sqrt(2) / np.sqrt(out_dim)
                if multires > 0 and idx == 0:
                    # Zero the encoding columns; only the raw-xyz columns
                    # start with non-trivial weights.
                    torch.nn.init.constant_(layer.bias, 0.0)
                    torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(layer.weight[:, :3], 0.0, std)
                elif multires > 0 and idx in self.skip_in:
                    # Zero the columns that receive the concatenated
                    # encoding (everything past the raw-xyz part).
                    torch.nn.init.constant_(layer.bias, 0.0)
                    torch.nn.init.normal_(layer.weight, 0.0, std)
                    torch.nn.init.constant_(layer.weight[:, -(layer_dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(layer.bias, 0.0)
                    torch.nn.init.normal_(layer.weight, 0.0, std)

            if weight_norm:
                layer = nn.utils.weight_norm(layer)
            # Attribute names "lin0", "lin1", ... are part of the
            # state_dict layout — keep them stable for checkpoints.
            setattr(self, "lin" + str(idx), layer)

        self.activation = nn.ReLU()

    def forward(self, samples, normals):
        """Displace each sample point by a predicted 3D offset.

        Args:
            samples: (B, 3) query points
            normals: (B, 3) estimated normals at samples
        Returns:
            moving_pcd: (B, 3) displaced points = samples + delta
        """
        scaled = samples * self.scale
        if self.embed_fn_fine is not None:
            scaled = self.embed_fn_fine(scaled)
        net_in = torch.cat((scaled, normals), dim=-1)

        h = net_in
        for idx in range(self.num_layers - 1):
            layer = getattr(self, "lin" + str(idx))
            if idx in self.skip_in:
                # Re-inject the encoded inputs; divide to keep variance
                # roughly constant across the concat.
                h = torch.cat([h, net_in], 1) / np.sqrt(2)
            h = layer(h)
            if idx < self.num_layers - 2:
                h = self.activation(h)

        # Undo the input scaling so the offset lives in the original units.
        return samples + h / self.scale