luh1124 committed on
Commit
c024a94
·
1 Parent(s): 9aa801c

Docs + runtime warning for LFS example assets

Browse files
extensions/vox2seq/benchmark.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import time
import torch
import vox2seq


if __name__ == "__main__":
    stats = {
        'z_order_cuda': [],
        'z_order_pytorch': [],
        'hilbert_cuda': [],
        'hilbert_pytorch': [],
    }
    RES = [16, 32, 64, 128, 256]
    N_RUNS = 100  # iterations per measurement; reported numbers are per-call averages

    for res in RES:
        # Dense grid of every voxel coordinate at this resolution: [res^3, 3], int32 on GPU.
        coords = torch.meshgrid(
            torch.arange(res), torch.arange(res), torch.arange(res), indexing='ij'
        )
        coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda()

        # One timing stanza per encoder; the table is built from stats[key].
        encoders = (
            ('z_order_cuda', lambda: vox2seq.encode(coords, mode='z_order')),
            ('z_order_pytorch', lambda: vox2seq.pytorch.encode(coords, mode='z_order')),
            ('hilbert_cuda', lambda: vox2seq.encode(coords, mode='hilbert')),
            ('hilbert_pytorch', lambda: vox2seq.pytorch.encode(coords, mode='hilbert')),
        )
        for key, encode_fn in encoders:
            # Drain any pending GPU work (e.g. the coords transfer or the previous
            # encoder) so the timer measures only this encoder's kernels.
            torch.cuda.synchronize()
            start = time.time()
            for _ in range(N_RUNS):
                encode_fn()
            torch.cuda.synchronize()
            stats[key].append((time.time() - start) / N_RUNS)

    print(f"{'Resolution':<12}{'Z-Order (CUDA)':<24}{'Z-Order (PyTorch)':<24}{'Hilbert (CUDA)':<24}{'Hilbert (PyTorch)':<24}")
    for res, z_order_cuda, z_order_pytorch, hilbert_cuda, hilbert_pytorch in zip(RES, stats['z_order_cuda'], stats['z_order_pytorch'], stats['hilbert_cuda'], stats['hilbert_pytorch']):
        print(f"{res:<12}{z_order_cuda:<24.6f}{z_order_pytorch:<24.6f}{hilbert_cuda:<24.6f}{hilbert_pytorch:<24.6f}")
extensions/vox2seq/setup.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#

from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os

# Anchor source paths at this file's directory so the build works regardless
# of the directory pip/setuptools invokes setup.py from. (The original
# computed this value but discarded it.)
ROOT = os.path.dirname(os.path.abspath(__file__))

setup(
    name="vox2seq",
    packages=['vox2seq', 'vox2seq.pytorch'],
    ext_modules=[
        CUDAExtension(
            name="vox2seq._C",
            sources=[
                os.path.join(ROOT, "src/api.cu"),
                os.path.join(ROOT, "src/z_order.cu"),
                os.path.join(ROOT, "src/hilbert.cu"),
                os.path.join(ROOT, "src/ext.cpp"),
            ],
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)
extensions/vox2seq/src/api.cu ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include "api.h"
3
+ #include "z_order.h"
4
+ #include "hilbert.h"
5
+
6
+
7
// Host wrapper: Morton-encode N points given as three int32 CUDA tensors.
// Returns an int32 tensor of the same shape as x holding 30-bit codes.
torch::Tensor
z_order_encode(
    const torch::Tensor& x,
    const torch::Tensor& y,
    const torch::Tensor& z
) {
    // Allocate output tensor (same shape/dtype/device as x).
    torch::Tensor codes = torch::empty_like(x);

    // One thread per point; kernel is launched on the default stream.
    // NOTE(review): inputs are reinterpreted as uint32 — assumes int32 CUDA
    // tensors with non-negative coordinates that fit in 10 bits; verify at
    // the Python call site. No device-tensor or dtype check is done here.
    z_order_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
        x.size(0),
        reinterpret_cast<uint32_t*>(x.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(y.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(z.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(codes.data_ptr<int>())
    );

    return codes;
}
27
+
28
+
29
// Host wrapper: decode N Morton codes into per-axis coordinate tensors.
// Returns (x, y, z), each with the codes' shape/dtype (int32).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
z_order_decode(
    const torch::Tensor& codes
) {
    // Allocate one output tensor per axis.
    torch::Tensor x = torch::empty_like(codes);
    torch::Tensor y = torch::empty_like(codes);
    torch::Tensor z = torch::empty_like(codes);

    // One thread per code; default-stream launch.
    // NOTE(review): assumes int32 CUDA input; only the low 30 bits of each
    // code are meaningful.
    z_order_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
        codes.size(0),
        reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(x.data_ptr<int>()),
        reinterpret_cast<uint32_t*>(y.data_ptr<int>()),
        reinterpret_cast<uint32_t*>(z.data_ptr<int>())
    );

    return std::make_tuple(x, y, z);
}
49
+
50
+
51
// Host wrapper: Hilbert-encode N points given as three int32 CUDA tensors.
// Mirrors z_order_encode but dispatches the Hilbert kernel.
torch::Tensor
hilbert_encode(
    const torch::Tensor& x,
    const torch::Tensor& y,
    const torch::Tensor& z
) {
    // Allocate output tensor (same shape/dtype/device as x).
    torch::Tensor codes = torch::empty_like(x);

    // One thread per point; default-stream launch.
    // NOTE(review): same uint32/int32 reinterpretation assumption as
    // z_order_encode — coordinates must fit in 10 bits.
    hilbert_encode_cuda<<<(x.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
        x.size(0),
        reinterpret_cast<uint32_t*>(x.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(y.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(z.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(codes.data_ptr<int>())
    );

    return codes;
}
71
+
72
+
73
// Host wrapper: decode N Hilbert codes into per-axis coordinate tensors.
// Returns (x, y, z), each with the codes' shape/dtype (int32).
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
hilbert_decode(
    const torch::Tensor& codes
) {
    // Allocate one output tensor per axis.
    torch::Tensor x = torch::empty_like(codes);
    torch::Tensor y = torch::empty_like(codes);
    torch::Tensor z = torch::empty_like(codes);

    // One thread per code; default-stream launch.
    // NOTE(review): assumes int32 CUDA input; only the low 30 bits are used.
    hilbert_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
        codes.size(0),
        reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(x.data_ptr<int>()),
        reinterpret_cast<uint32_t*>(y.data_ptr<int>()),
        reinterpret_cast<uint32_t*>(z.data_ptr<int>())
    );

    return std::make_tuple(x, y, z);
}
extensions/vox2seq/src/api.h ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Serialize a voxel grid
3
+ *
4
+ * Copyright (C) 2024, Jianfeng XIANG <belljig@outlook.com>
5
+ * All rights reserved.
6
+ *
7
+ * Licensed under The MIT License [see LICENSE for details]
8
+ *
9
+ * Written by Jianfeng XIANG
10
+ */
11
+
12
+ #pragma once
13
+ #include <torch/extension.h>
14
+
15
+
16
+ #define BLOCK_SIZE 256
17
+
18
+
19
+ /**
20
+ * Z-order encode 3D points
21
+ *
22
+ * @param x [N] tensor containing the x coordinates
23
+ * @param y [N] tensor containing the y coordinates
24
+ * @param z [N] tensor containing the z coordinates
25
+ *
26
+ * @return [N] tensor containing the z-order encoded values
27
+ */
28
+ torch::Tensor
29
+ z_order_encode(
30
+ const torch::Tensor& x,
31
+ const torch::Tensor& y,
32
+ const torch::Tensor& z
33
+ );
34
+
35
+
36
+ /**
37
+ * Z-order decode 3D points
38
+ *
39
+ * @param codes [N] tensor containing the z-order encoded values
40
+ *
41
+ * @return 3 tensors [N] containing the x, y, z coordinates
42
+ */
43
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
44
+ z_order_decode(
45
+ const torch::Tensor& codes
46
+ );
47
+
48
+
49
+ /**
50
+ * Hilbert encode 3D points
51
+ *
52
+ * @param x [N] tensor containing the x coordinates
53
+ * @param y [N] tensor containing the y coordinates
54
+ * @param z [N] tensor containing the z coordinates
55
+ *
56
+ * @return [N] tensor containing the Hilbert encoded values
57
+ */
58
+ torch::Tensor
59
+ hilbert_encode(
60
+ const torch::Tensor& x,
61
+ const torch::Tensor& y,
62
+ const torch::Tensor& z
63
+ );
64
+
65
+
66
+ /**
67
+ * Hilbert decode 3D points
68
+ *
69
+ * @param codes [N] tensor containing the Hilbert encoded values
70
+ *
71
+ * @return 3 tensors [N] containing the x, y, z coordinates
72
+ */
73
+ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
74
+ hilbert_decode(
75
+ const torch::Tensor& codes
76
+ );
extensions/vox2seq/src/ext.cpp ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
#include <torch/extension.h>
#include "api.h"


// Python bindings: expose the four encode/decode entry points declared in
// api.h as vox2seq._C.<name>.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("z_order_encode", &z_order_encode);
    m.def("z_order_decode", &z_order_decode);
    m.def("hilbert_encode", &hilbert_encode);
    m.def("hilbert_decode", &hilbert_decode);
}
extensions/vox2seq/src/hilbert.cu ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <cuda.h>
2
+ #include <cuda_runtime.h>
3
+ #include <device_launch_parameters.h>
4
+
5
+ #include <cooperative_groups.h>
6
+ #include <cooperative_groups/memcpy_async.h>
7
+ namespace cg = cooperative_groups;
8
+
9
+ #include "hilbert.h"
10
+
11
+
12
// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
// Classic multiply-and-mask bit-spreading used for 3-way interleaving;
// the input must fit in 10 bits or high bits are silently lost.
static __device__ uint32_t expandBits(uint32_t v)
{
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}
21
+
22
+
23
// Removes 2 zeros after each bit in a 30-bit integer — the inverse of
// expandBits: keeps every third bit and compacts it into the low 10 bits.
static __device__ uint32_t extractBits(uint32_t v)
{
    v = v & 0x49249249;
    v = (v ^ (v >> 2)) & 0x030C30C3u;
    v = (v ^ (v >> 4)) & 0x0300F00Fu;
    v = (v ^ (v >> 8)) & 0x030000FFu;
    v = (v ^ (v >> 16)) & 0x000003FFu;
    return v;
}
33
+
34
+
35
/**
 * Kernel: encode one 3D point per thread into a 30-bit Hilbert index.
 *
 * Follows Skilling's transpose algorithm ("Programming the Hilbert curve",
 * AIP Conf. Proc. 707, 2004): transform the axes in place, Gray-encode,
 * then bit-interleave the transposed axes into a single code.
 * Each coordinate must fit in 10 bits (bit 9 is the top bit handled).
 */
__global__ void hilbert_encode_cuda(
    size_t N,
    const uint32_t* x,
    const uint32_t* y,
    const uint32_t* z,
    uint32_t* codes
) {
    size_t thread_id = cg::this_grid().thread_rank();
    if (thread_id >= N) return;

    uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]};

    // m = highest bit processed (bit 9 -> 10 bits per axis).
    uint32_t m = 1 << 9, q, p, t;

    // Inverse undo excess work (AxesToTranspose step, top bit downward).
    q = m;
    while (q > 1) {
        p = q - 1;
        for (int i = 0; i < 3; i++) {
            if (point[i] & q) {
                point[0] ^= p; // invert the low bits of axis 0
            } else {
                // exchange the low bits of axis i with axis 0
                t = (point[0] ^ point[i]) & p;
                point[0] ^= t;
                point[i] ^= t;
            }
        }
        q >>= 1;
    }

    // Gray encode
    for (int i = 1; i < 3; i++) {
        point[i] ^= point[i - 1];
    }
    t = 0;
    q = m;
    while (q > 1) {
        if (point[2] & q) {
            t ^= q - 1;
        }
        q >>= 1;
    }
    for (int i = 0; i < 3; i++) {
        point[i] ^= t;
    }

    // Interleave the transposed axes into one 30-bit code
    // (axis 0 supplies the most significant bit of each 3-bit group).
    uint32_t xx = expandBits(point[0]);
    uint32_t yy = expandBits(point[1]);
    uint32_t zz = expandBits(point[2]);

    codes[thread_id] = xx * 4 + yy * 2 + zz;
}
88
+
89
+
90
/**
 * Kernel: decode one 30-bit Hilbert index per thread back into 3D coords.
 *
 * Inverse of hilbert_encode_cuda (Skilling's TransposeToAxes): de-interleave
 * the code, Gray-decode, then redo the axis transforms from the low bit up.
 */
__global__ void hilbert_decode_cuda(
    size_t N,
    const uint32_t* codes,
    uint32_t* x,
    uint32_t* y,
    uint32_t* z
) {
    size_t thread_id = cg::this_grid().thread_rank();
    if (thread_id >= N) return;

    // De-interleave the code into transposed per-axis bit patterns.
    uint32_t point[3];
    point[0] = extractBits(codes[thread_id] >> 2);
    point[1] = extractBits(codes[thread_id] >> 1);
    point[2] = extractBits(codes[thread_id]);

    // m = one past the highest handled bit (10 bits per axis).
    uint32_t m = 2 << 9, q, p, t;

    // Gray decode by H ^ (H/2)
    t = point[2] >> 1;
    for (int i = 2; i > 0; i--) {
        point[i] ^= point[i - 1];
    }
    point[0] ^= t;

    // Undo excess work, from the low bit upward (mirror of the encoder).
    q = 2;
    while (q != m) {
        p = q - 1;
        for (int i = 2; i >= 0; i--) {
            if (point[i] & q) {
                point[0] ^= p; // invert the low bits of axis 0
            } else {
                // exchange the low bits of axis i with axis 0
                t = (point[0] ^ point[i]) & p;
                point[0] ^= t;
                point[i] ^= t;
            }
        }
        q <<= 1;
    }

    x[thread_id] = point[0];
    y[thread_id] = point[1];
    z[thread_id] = point[2];
}
extensions/vox2seq/src/hilbert.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ /**
4
+ * Hilbert encode 3D points
5
+ *
6
+ * @param x [N] tensor containing the x coordinates
7
+ * @param y [N] tensor containing the y coordinates
8
+ * @param z [N] tensor containing the z coordinates
9
+ *
10
+ * @return [N] tensor containing the z-order encoded values
11
+ */
12
+ __global__ void hilbert_encode_cuda(
13
+ size_t N,
14
+ const uint32_t* x,
15
+ const uint32_t* y,
16
+ const uint32_t* z,
17
+ uint32_t* codes
18
+ );
19
+
20
+
21
+ /**
22
+ * Hilbert decode 3D points
23
+ *
24
+ * @param codes [N] tensor containing the z-order encoded values
25
+ * @param x [N] tensor containing the x coordinates
26
+ * @param y [N] tensor containing the y coordinates
27
+ * @param z [N] tensor containing the z coordinates
28
+ */
29
+ __global__ void hilbert_decode_cuda(
30
+ size_t N,
31
+ const uint32_t* codes,
32
+ uint32_t* x,
33
+ uint32_t* y,
34
+ uint32_t* z
35
+ );
extensions/vox2seq/src/z_order.cu ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <cuda.h>
2
+ #include <cuda_runtime.h>
3
+ #include <device_launch_parameters.h>
4
+
5
+ #include <cooperative_groups.h>
6
+ #include <cooperative_groups/memcpy_async.h>
7
+ namespace cg = cooperative_groups;
8
+
9
+ #include "z_order.h"
10
+
11
+
12
// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
// Same multiply-and-mask spreading trick as in hilbert.cu; input must fit
// in 10 bits. NOTE(review): duplicated in hilbert.cu — consider sharing.
static __device__ uint32_t expandBits(uint32_t v)
{
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}
21
+
22
+
23
// Removes 2 zeros after each bit in a 30-bit integer (inverse of expandBits).
// NOTE(review): duplicated in hilbert.cu — consider sharing.
static __device__ uint32_t extractBits(uint32_t v)
{
    v = v & 0x49249249;
    v = (v ^ (v >> 2)) & 0x030C30C3u;
    v = (v ^ (v >> 4)) & 0x0300F00Fu;
    v = (v ^ (v >> 8)) & 0x030000FFu;
    v = (v ^ (v >> 16)) & 0x000003FFu;
    return v;
}
33
+
34
+
35
/**
 * Kernel: Morton-encode one 3D point per thread.
 * Interleaves the low 10 bits of each coordinate into a 30-bit code,
 * with x taking the most significant bit of each 3-bit group.
 */
__global__ void z_order_encode_cuda(
    size_t N,
    const uint32_t* x,
    const uint32_t* y,
    const uint32_t* z,
    uint32_t* codes
) {
    size_t thread_id = cg::this_grid().thread_rank();
    if (thread_id >= N) return;

    uint32_t xx = expandBits(x[thread_id]);
    uint32_t yy = expandBits(y[thread_id]);
    uint32_t zz = expandBits(z[thread_id]);

    // xx*4 + yy*2 + zz == (xx << 2) | (yy << 1) | zz for disjoint bit sets.
    codes[thread_id] = xx * 4 + yy * 2 + zz;
}
51
+
52
+
53
/**
 * Kernel: decode one 30-bit Morton code per thread into 3D coordinates.
 * Inverse of z_order_encode_cuda: shift to align each axis's bits, then
 * compact them with extractBits.
 */
__global__ void z_order_decode_cuda(
    size_t N,
    const uint32_t* codes,
    uint32_t* x,
    uint32_t* y,
    uint32_t* z
) {
    size_t thread_id = cg::this_grid().thread_rank();
    if (thread_id >= N) return;

    x[thread_id] = extractBits(codes[thread_id] >> 2);
    y[thread_id] = extractBits(codes[thread_id] >> 1);
    z[thread_id] = extractBits(codes[thread_id]);
}
extensions/vox2seq/src/z_order.h ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ /**
4
+ * Z-order encode 3D points
5
+ *
6
+ * @param x [N] tensor containing the x coordinates
7
+ * @param y [N] tensor containing the y coordinates
8
+ * @param z [N] tensor containing the z coordinates
9
+ *
10
+ * @return [N] tensor containing the z-order encoded values
11
+ */
12
+ __global__ void z_order_encode_cuda(
13
+ size_t N,
14
+ const uint32_t* x,
15
+ const uint32_t* y,
16
+ const uint32_t* z,
17
+ uint32_t* codes
18
+ );
19
+
20
+
21
+ /**
22
+ * Z-order decode 3D points
23
+ *
24
+ * @param codes [N] tensor containing the z-order encoded values
25
+ * @param x [N] tensor containing the x coordinates
26
+ * @param y [N] tensor containing the y coordinates
27
+ * @param z [N] tensor containing the z coordinates
28
+ */
29
+ __global__ void z_order_decode_cuda(
30
+ size_t N,
31
+ const uint32_t* codes,
32
+ uint32_t* x,
33
+ uint32_t* y,
34
+ uint32_t* z
35
+ );
extensions/vox2seq/test.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import vox2seq


if __name__ == "__main__":
    RES = 256
    # All voxel coordinates of a RES^3 grid: shape [RES^3, 3], int32 on GPU.
    # indexing='ij' keeps the pre-deprecation meshgrid behavior explicit.
    coords = torch.meshgrid(
        torch.arange(RES), torch.arange(RES), torch.arange(RES), indexing='ij'
    )
    coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda()

    # CUDA and pure-PyTorch encoders must agree bit-for-bit.
    code_z_cuda = vox2seq.encode(coords, mode='z_order')
    code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order')
    code_h_cuda = vox2seq.encode(coords, mode='hilbert')
    code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert')
    assert torch.equal(code_z_cuda, code_z_pytorch)
    assert torch.equal(code_h_cuda, code_h_pytorch)

    # Decoders must agree as well. NOTE: torch.equal also compares dtypes,
    # so both decode implementations must return the same dtype for this
    # to pass — verify vox2seq.pytorch.decode's output dtype matches.
    code = torch.arange(RES**3).int().cuda()
    coords_z_cuda = vox2seq.decode(code, mode='z_order')
    coords_z_pytorch = vox2seq.pytorch.decode(code, mode='z_order')
    coords_h_cuda = vox2seq.decode(code, mode='hilbert')
    coords_h_pytorch = vox2seq.pytorch.decode(code, mode='hilbert')
    assert torch.equal(coords_z_cuda, coords_z_pytorch)
    assert torch.equal(coords_h_cuda, coords_h_pytorch)

    print("All tests passed.")
extensions/vox2seq/vox2seq/__init__.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import *
3
+ import torch
4
+ from . import _C
5
+ from . import pytorch
6
+
7
+
8
@torch.no_grad()
def encode(coords: torch.Tensor, permute: Sequence[int] = (0, 1, 2), mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Encodes 3D coordinates into a 30-bit code using the CUDA extension.

    Args:
        coords: a tensor of shape [N, 3] containing the 3D coordinates.
            Assumed to be a CUDA int tensor with each value fitting in
            10 bits — TODO confirm against the kernel's contract.
        permute: the permutation of the coordinates (immutable default
            avoids the shared-mutable-default-argument pitfall).
        mode: the encoding mode to use.

    Returns:
        An int32 tensor of shape [N] with one space-filling-curve code
        per input point.

    Raises:
        ValueError: if `mode` is not 'z_order' or 'hilbert'.
    """
    assert coords.shape[-1] == 3 and coords.ndim == 2, "Input coordinates must be of shape [N, 3]"
    x = coords[:, permute[0]].int()
    y = coords[:, permute[1]].int()
    z = coords[:, permute[2]].int()
    if mode == 'z_order':
        return _C.z_order_encode(x, y, z)
    elif mode == 'hilbert':
        return _C.hilbert_encode(x, y, z)
    else:
        raise ValueError(f"Unknown encoding mode: {mode}")
28
+
29
+
30
@torch.no_grad()
def decode(code: torch.Tensor, permute: Sequence[int] = (0, 1, 2), mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Decodes a 30-bit code into 3D coordinates using the CUDA extension.

    Args:
        code: a tensor of shape [N] containing the 30-bit code.
        permute: the permutation of the coordinates (immutable default
            avoids the shared-mutable-default-argument pitfall).
        mode: the decoding mode to use.

    Returns:
        A tensor of shape [N, 3] with the decoded coordinates, in the
        original (pre-permutation) column order.

    Raises:
        ValueError: if `mode` is not 'z_order' or 'hilbert'.
    """
    assert code.ndim == 1, "Input code must be of shape [N]"
    if mode == 'z_order':
        coords = _C.z_order_decode(code)
    elif mode == 'hilbert':
        coords = _C.hilbert_decode(code)
    else:
        raise ValueError(f"Unknown decoding mode: {mode}")
    # Invert the permutation applied by encode(): output column j is the
    # decoded axis i with permute[i] == j.
    permute = list(permute)
    x = coords[permute.index(0)]
    y = coords[permute.index(1)]
    z = coords[permute.index(2)]
    return torch.stack([x, y, z], dim=-1)
extensions/vox2seq/vox2seq/pytorch/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import *
3
+
4
+ from .default import (
5
+ encode,
6
+ decode,
7
+ z_order_encode,
8
+ z_order_decode,
9
+ hilbert_encode,
10
+ hilbert_decode,
11
+ )
12
+
13
+
14
@torch.no_grad()
def encode(coords: torch.Tensor, permute: Sequence[int] = (0, 1, 2), mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Encodes 3D coordinates into a 30-bit code (pure-PyTorch fallback).

    Args:
        coords: a tensor of shape [N, 3] containing the 3D coordinates.
        permute: the permutation of the coordinates (immutable default
            avoids the shared-mutable-default-argument pitfall).
        mode: the encoding mode to use.

    Returns:
        An int32 tensor of shape [N] with one code per input point
        (depth=10 bits per axis, matching the CUDA kernels).

    Raises:
        ValueError: if `mode` is not 'z_order' or 'hilbert'.
    """
    # Index with a list: tuple indexing of a tensor axis is ambiguous.
    permute = list(permute)
    if mode == 'z_order':
        return z_order_encode(coords[:, permute], depth=10).int()
    elif mode == 'hilbert':
        return hilbert_encode(coords[:, permute], depth=10).int()
    else:
        raise ValueError(f"Unknown encoding mode: {mode}")
30
+
31
+
32
@torch.no_grad()
def decode(code: torch.Tensor, permute: Sequence[int] = (0, 1, 2), mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Decodes a 30-bit code into 3D coordinates (pure-PyTorch fallback).

    Args:
        code: a tensor of shape [N] containing the 30-bit code.
        permute: the permutation of the coordinates (immutable default
            avoids the shared-mutable-default-argument pitfall).
        mode: the decoding mode to use.

    Returns:
        A float tensor of shape [N, 3] with the decoded coordinates.
        NOTE(review): the CUDA path (vox2seq.decode) returns integer
        coordinates — this `.float()` makes the two paths compare unequal
        under torch.equal; confirm which dtype is intended.

    Raises:
        ValueError: if `mode` is not 'z_order' or 'hilbert'.
    """
    permute = list(permute)
    if mode == 'z_order':
        return z_order_decode(code, depth=10)[:, permute].float()
    elif mode == 'hilbert':
        return hilbert_decode(code, depth=10)[:, permute].float()
    else:
        raise ValueError(f"Unknown decoding mode: {mode}")
48
+
extensions/vox2seq/vox2seq/pytorch/default.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .z_order import xyz2key as z_order_encode_
3
+ from .z_order import key2xyz as z_order_decode_
4
+ from .hilbert import encode as hilbert_encode_
5
+ from .hilbert import decode as hilbert_decode_
6
+
7
+
8
+ @torch.inference_mode()
9
+ def encode(grid_coord, batch=None, depth=16, order="z"):
10
+ assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
11
+ if order == "z":
12
+ code = z_order_encode(grid_coord, depth=depth)
13
+ elif order == "z-trans":
14
+ code = z_order_encode(grid_coord[:, [1, 0, 2]], depth=depth)
15
+ elif order == "hilbert":
16
+ code = hilbert_encode(grid_coord, depth=depth)
17
+ elif order == "hilbert-trans":
18
+ code = hilbert_encode(grid_coord[:, [1, 0, 2]], depth=depth)
19
+ else:
20
+ raise NotImplementedError
21
+ if batch is not None:
22
+ batch = batch.long()
23
+ code = batch << depth * 3 | code
24
+ return code
25
+
26
+
27
+ @torch.inference_mode()
28
+ def decode(code, depth=16, order="z"):
29
+ assert order in {"z", "hilbert"}
30
+ batch = code >> depth * 3
31
+ code = code & ((1 << depth * 3) - 1)
32
+ if order == "z":
33
+ grid_coord = z_order_decode(code, depth=depth)
34
+ elif order == "hilbert":
35
+ grid_coord = hilbert_decode(code, depth=depth)
36
+ else:
37
+ raise NotImplementedError
38
+ return grid_coord, batch
39
+
40
+
41
def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
    """Morton-encode [N, 3] integer coordinates into [N] int64 codes."""
    x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
    # we block the support to batch, maintain batched code in Point class
    code = z_order_encode_(x, y, z, b=None, depth=depth)
    return code
46
+
47
+
48
def z_order_decode(code: torch.Tensor, depth: int = 16):
    """Decode [N] Morton codes back into [N, 3] integer coordinates.

    Args:
        code: [N] integer tensor of z-order codes.
        depth: bits per axis; default 16 for consistency with
            z_order_encode() and the hilbert wrappers (previously this
            parameter had no default, unlike every sibling function).
    """
    x, y, z, _ = z_order_decode_(code, depth=depth)
    grid_coord = torch.stack([x, y, z], dim=-1)  # (N, 3)
    return grid_coord
52
+
53
+
54
def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
    """Hilbert-encode [N, 3] integer coordinates into [N] int64 codes."""
    return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
56
+
57
+
58
def hilbert_decode(code: torch.Tensor, depth: int = 16):
    """Decode [N] Hilbert codes back into [N, 3] integer coordinates."""
    return hilbert_decode_(code, num_dims=3, num_bits=depth)
extensions/vox2seq/vox2seq/pytorch/hilbert.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hilbert Order
3
+ Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve
4
+
5
+ Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu
6
+ Please cite our work if the code is helpful to you.
7
+ """
8
+
9
+ import torch
10
+
11
+
12
def right_shift(binary, k=1, axis=-1):
    """Right shift an array of binary values.

    Parameters:
    -----------
    binary: An ndarray of binary values.

    k: The number of bits to shift. Default 1.

    axis: The axis along which to shift. Default -1.

    Returns:
    --------
    Returns an ndarray with zero prepended and the ends truncated, along
    whatever axis was specified.

    Note:
    -----
    The previous implementation sliced along `axis` but always padded the
    LAST dimension (F.pad's (k, 0) pads dim -1), so any axis other than -1
    produced wrong shapes/results. The zeros are now prepended along the
    requested axis. Behavior for the default axis=-1 is unchanged.
    """
    # Callers sometimes pass k as a 1-element tensor; normalize to int.
    k = int(k)

    # If we're shifting the whole thing, just return zeros.
    if binary.shape[axis] <= k:
        return torch.zeros_like(binary)

    # Drop the last k entries along `axis`...
    slicing = [slice(None)] * binary.ndim
    slicing[axis] = slice(None, -k)
    truncated = binary[tuple(slicing)]

    # ...and prepend k zeros along the same axis.
    zeros_shape = list(binary.shape)
    zeros_shape[axis] = k
    zeros = torch.zeros(zeros_shape, dtype=binary.dtype, device=binary.device)
    return torch.cat([zeros, truncated], dim=axis)
44
+
45
+
46
def binary2gray(binary, axis=-1):
    """Convert an array of binary values into Gray codes.

    Uses the classic ``X ^ (X >> 1)`` construction.

    Parameters:
    -----------
    binary: An ndarray of binary values.

    axis: The axis along which to compute the gray code. Default=-1.

    Returns:
    --------
    Returns an ndarray of Gray codes.
    """
    # XOR each bit with the bit one position more significant than it.
    return torch.logical_xor(binary, right_shift(binary, axis=axis))
67
+
68
+
69
def gray2binary(gray, axis=-1):
    """Convert an array of Gray codes back into binary values.

    Parameters:
    -----------
    gray: An ndarray of gray codes.

    axis: The axis along which to perform Gray decoding. Default=-1.

    Returns:
    --------
    Returns an ndarray of binary values.
    """

    # Prefix-XOR with a doubling shift: after ceil(log2(bits)) rounds every
    # bit has been XORed with all bits above it, inverting X ^ (X >> 1).
    # NOTE(review): `shift` is a 1-element tensor, so the `> 0` comparison
    # and the shift amount inside right_shift rely on implicit tensor->int
    # coercion — works, but a plain int would be cleaner; confirm before
    # changing.
    shift = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
    while shift > 0:
        gray = torch.logical_xor(gray, right_shift(gray, shift))
        shift = torch.div(shift, 2, rounding_mode="floor")
    return gray
89
+
90
+
91
def encode(locs, num_dims, num_bits):
    """Encode an array of locations in a hypercube into a Hilbert integer.

    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
    locs - An ndarray of locations in a hypercube of num_dims dimensions, in
           which each dimension runs from 0 to 2**num_bits-1. The shape can
           be arbitrary, as long as the last dimension of the same has size
           num_dims.

    num_dims - The dimensionality of the hypercube. Integer.

    num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
    The output is an ndarray of uint64 integers with the same shape as the
    input, excluding the last dimension, which needs to be num_dims.
    """

    # Keep around the original shape for later.
    orig_shape = locs.shape
    # Per-byte bit masks used to unpack/pack bits (LSB-first and MSB-first).
    bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    if orig_shape[-1] != num_dims:
        raise ValueError(
            """
      The shape of locs was surprising in that the last dimension was of size
      %d, but num_dims=%d. These need to be equal.
      """
            % (orig_shape[-1], num_dims)
        )

    if num_dims * num_bits > 63:
        raise ValueError(
            """
      num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
      into a int64. Are you sure you need that many points on your Hilbert
      curve?
      """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Treat the location integers as 64-bit unsigned and then split them up into
    # a sequence of uint8s. Preserve the association by dimension.
    locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)

    # Now turn these into bits and truncate to num_bits.
    gray = (
        locs_uint8.unsqueeze(-1)
        .bitwise_and(bitpack_mask_rev)
        .ne(0)
        .byte()
        .flatten(-2, -1)[..., -num_bits:]
    )

    # Run the decoding process the other way.
    # Iterate forwards through the bits.
    for bit in range(0, num_bits):
        # Iterate forwards through the dimensions.
        for dim in range(0, num_dims):
            # Identify which ones have this bit active.
            mask = gray[:, dim, bit]

            # Where this bit is on, invert the 0 dimension for lower bits.
            gray[:, 0, bit + 1 :] = torch.logical_xor(
                gray[:, 0, bit + 1 :], mask[:, None]
            )

            # Where the bit is off, exchange the lower bits with the 0 dimension.
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(
                gray[:, dim, bit + 1 :], to_flip
            )
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Now flatten out.
    gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))

    # Convert Gray back to binary.
    hh_bin = gray2binary(gray)

    # Pad back out to 64 bits.
    extra_dims = 64 - num_bits * num_dims
    padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)

    # Convert binary values into uint8s.
    # NOTE(review): .squeeze() collapses the batch dimension when N == 1,
    # so a single location yields a 0-d result — confirm callers expect this.
    hh_uint8 = (
        (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
        .sum(2)
        .squeeze()
        .type(torch.uint8)
    )

    # Convert uint8s into uint64s.
    hh_uint64 = hh_uint8.view(torch.int64).squeeze()

    return hh_uint64
199
+
200
+
201
def decode(hilberts, num_dims, num_bits):
    """Decode an array of Hilbert integers into locations in a hypercube.

    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
    hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
               cannot have fewer bits than num_dims * num_bits.

    num_dims - The dimensionality of the hypercube. Integer.

    num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
    The output is an ndarray of unsigned integers with the same shape as hilberts
    but with an additional dimension of size num_dims.
    """

    if num_dims * num_bits > 64:
        # Fix: the message has three %d placeholders but the tuple previously
        # supplied only two values, so raising it crashed with a TypeError.
        raise ValueError(
            """
      num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
      into a uint64. Are you sure you need that many points on your Hilbert
      curve?
      """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Handle the case where we got handed a naked integer.
    hilberts = torch.atleast_1d(hilberts)

    # Keep around the shape for later.
    orig_shape = hilberts.shape
    bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    # Treat each of the hilberts as a sequence of eight uint8.
    # This treats all of the inputs as uint64 and makes things uniform.
    hh_uint8 = (
        hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
    )

    # Turn these lists of uints into lists of bits and then truncate to the size
    # we actually need for using Skilling's procedure.
    hh_bits = (
        hh_uint8.unsqueeze(-1)
        .bitwise_and(bitpack_mask_rev)
        .ne(0)
        .byte()
        .flatten(-2, -1)[:, -num_dims * num_bits :]
    )

    # Take the sequence of bits and Gray-code it.
    gray = binary2gray(hh_bits)

    # There has got to be a better way to do this.
    # I could index them differently, but the eventual packbits likes it this way.
    gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)

    # Iterate backwards through the bits.
    for bit in range(num_bits - 1, -1, -1):
        # Iterate backwards through the dimensions.
        for dim in range(num_dims - 1, -1, -1):
            # Identify which ones have this bit active.
            mask = gray[:, dim, bit]

            # Where this bit is on, invert the 0 dimension for lower bits.
            gray[:, 0, bit + 1 :] = torch.logical_xor(
                gray[:, 0, bit + 1 :], mask[:, None]
            )

            # Where the bit is off, exchange the lower bits with the 0 dimension.
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(
                gray[:, dim, bit + 1 :], to_flip
            )
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Pad back out to 64 bits.
    extra_dims = 64 - num_bits
    padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)

    # Now chop these up into blocks of 8.
    locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))

    # Take those blocks and turn them unto uint8s.
    # NOTE(review): .squeeze() collapses singleton dims when N == 1 — confirm
    # callers never pass a single Hilbert integer through this path.
    locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)

    # Finally, treat these as uint64s.
    flat_locs = locs_uint8.view(torch.int64)

    # Return them in the expected shape.
    return flat_locs.reshape((*orig_shape, num_dims))
extensions/vox2seq/vox2seq/pytorch/z_order.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Octree-based Sparse Convolutional Neural Networks
3
+ # Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Peng-Shuai Wang
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from typing import Optional, Union
10
+
11
+
12
class KeyLUT:
    """Lookup tables mapping 3-D coordinates to Morton (z-order) keys and back.

    The tables are materialized once on CPU at construction time and are
    copied to other devices lazily, on first request.
    """

    def __init__(self):
        cpu = torch.device("cpu")
        byte_range = torch.arange(256, dtype=torch.int64)
        chunk_range = torch.arange(512, dtype=torch.int64)
        zeros = torch.zeros(256, dtype=torch.int64)

        # One 8-bit encode table per axis; a single 9-bit (3 bits/axis)
        # decode table shared by all axes.
        self._encode = {
            cpu: (
                self.xyz2key(byte_range, zeros, zeros, 8),
                self.xyz2key(zeros, byte_range, zeros, 8),
                self.xyz2key(zeros, zeros, byte_range, 8),
            )
        }
        self._decode = {cpu: self.key2xyz(chunk_range, 9)}

    def encode_lut(self, device=torch.device("cpu")):
        """Return the per-axis encode tables, cached on ``device``."""
        if device not in self._encode:
            tables = self._encode[torch.device("cpu")]
            self._encode[device] = tuple(t.to(device) for t in tables)
        return self._encode[device]

    def decode_lut(self, device=torch.device("cpu")):
        """Return the per-axis decode tables, cached on ``device``."""
        if device not in self._decode:
            tables = self._decode[torch.device("cpu")]
            self._decode[device] = tuple(t.to(device) for t in tables)
        return self._decode[device]

    def xyz2key(self, x, y, z, depth):
        """Interleave the low ``depth`` bits of x, y, z into a Morton key.

        Bit ``i`` of x/y/z lands at key bit ``3*i + 2`` / ``3*i + 1`` / ``3*i``.
        """
        key = torch.zeros_like(x)
        for bit in range(depth):
            sel = 1 << bit
            xb = (x & sel) << (2 * bit + 2)
            yb = (y & sel) << (2 * bit + 1)
            zb = (z & sel) << (2 * bit)
            key = key | xb | yb | zb
        return key

    def key2xyz(self, key, depth):
        """De-interleave a Morton key back into x, y, z coordinate bits."""
        x = torch.zeros_like(key)
        y = torch.zeros_like(key)
        z = torch.zeros_like(key)
        for bit in range(depth):
            x = x | ((key & (1 << (3 * bit + 2))) >> (2 * bit + 2))
            y = y | ((key & (1 << (3 * bit + 1))) >> (2 * bit + 1))
            z = z | ((key & (1 << (3 * bit))) >> (2 * bit))
        return x, y, z
61
+
62
+
63
# Shared module-level instance so the lookup tables are built only once.
_key_lut = KeyLUT()
64
+
65
+
66
def xyz2key(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    b: Optional[Union[torch.Tensor, int]] = None,
    depth: int = 16,
):
    r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
    based on pre-computed look up tables. The speed of this function is much
    faster than the method based on for-loop.

    Args:
        x (torch.Tensor): The x coordinate.
        y (torch.Tensor): The y coordinate.
        z (torch.Tensor): The z coordinate.
        b (torch.Tensor or int): The batch index of the coordinates, and should be
            smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
            :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
        depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
    """

    EX, EY, EZ = _key_lut.encode_lut(x.device)
    x, y, z = x.long(), y.long(), z.long()

    # The encode LUTs cover 8 bits per axis; deeper keys are assembled from
    # two 8-bit halves (the low half occupies 3 * 8 = 24 key bits).
    mask = 255 if depth > 8 else (1 << depth) - 1
    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
    if depth > 8:
        mask = (1 << (depth - 8)) - 1
        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
        key = key16 << 24 | key

    if b is not None:
        # BUGFIX: `b` may be a plain int (as documented above); the previous
        # unconditional `b.long()` raised AttributeError for ints. Only cast
        # tensors; a Python int shifts and ORs with a tensor directly.
        if isinstance(b, torch.Tensor):
            b = b.long()
        # Batch index occupies the bits above position 48 (3 * 16 key bits).
        key = b << 48 | key

    return key
102
+
103
+
104
def key2xyz(key: torch.Tensor, depth: int = 16):
    r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
    and the batch index based on pre-computed look up tables.

    Args:
        key (torch.Tensor): The shuffled key.
        depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
    """

    DX, DY, DZ = _key_lut.decode_lut(key.device)

    x = torch.zeros_like(key)
    y = torch.zeros_like(key)
    z = torch.zeros_like(key)

    # Split off the batch index stored above bit 48, keep the coordinate bits.
    b = key >> 48
    key = key & ((1 << 48) - 1)

    # Each iteration decodes 9 key bits (3 bits per axis) via the LUTs.
    num_chunks = (depth + 2) // 3
    for chunk in range(num_chunks):
        idx = (key >> (chunk * 9)) & 511
        shift = chunk * 3
        x = x | (DX[idx] << shift)
        y = y | (DY[idx] << shift)
        z = z | (DZ[idx] << shift)

    return x, y, z, b