bdck commited on
Commit
74ead5e
·
verified ·
1 Parent(s): 8e3ab85

Upload embedder.py

Browse files
Files changed (1) hide show
  1. embedder.py +56 -0
embedder.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Positional encoding embedding.
3
+ Based on NeRF / NeuS implementation (same as original).
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ class Embedder:
9
+ def __init__(self, **kwargs):
10
+ self.kwargs = kwargs
11
+ self.create_embedding_fn()
12
+
13
+ def create_embedding_fn(self):
14
+ embed_fns = []
15
+ d = self.kwargs['input_dims']
16
+ out_dim = 0
17
+ if self.kwargs['include_input']:
18
+ embed_fns.append(lambda x: x)
19
+ out_dim += d
20
+
21
+ max_freq = self.kwargs['max_freq_log2']
22
+ N_freqs = self.kwargs['num_freqs']
23
+
24
+ if self.kwargs['log_sampling']:
25
+ freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
26
+ else:
27
+ freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
28
+
29
+ for freq in freq_bands:
30
+ for p_fn in self.kwargs['periodic_fns']:
31
+ embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
32
+ out_dim += d
33
+
34
+ self.embed_fns = embed_fns
35
+ self.out_dim = out_dim
36
+
37
+ def embed(self, inputs):
38
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
39
+
40
+
41
+ def get_embedder(multires, input_dims=3):
42
+ embed_kwargs = {
43
+ 'include_input': True,
44
+ 'input_dims': input_dims,
45
+ 'max_freq_log2': multires - 1,
46
+ 'num_freqs': multires,
47
+ 'log_sampling': True,
48
+ 'periodic_fns': [torch.sin, torch.cos],
49
+ }
50
+
51
+ embedder_obj = Embedder(**embed_kwargs)
52
+
53
+ def embed(x, eo=embedder_obj):
54
+ return eo.embed(x)
55
+
56
+ return embed, embedder_obj.out_dim