Docs + runtime warning for LFS example assets
Browse files- extensions/vox2seq/benchmark.py +45 -0
- extensions/vox2seq/setup.py +34 -0
- extensions/vox2seq/src/api.cu +92 -0
- extensions/vox2seq/src/api.h +76 -0
- extensions/vox2seq/src/ext.cpp +10 -0
- extensions/vox2seq/src/hilbert.cu +133 -0
- extensions/vox2seq/src/hilbert.h +35 -0
- extensions/vox2seq/src/z_order.cu +66 -0
- extensions/vox2seq/src/z_order.h +35 -0
- extensions/vox2seq/test.py +25 -0
- extensions/vox2seq/vox2seq/__init__.py +50 -0
- extensions/vox2seq/vox2seq/pytorch/__init__.py +48 -0
- extensions/vox2seq/vox2seq/pytorch/default.py +59 -0
- extensions/vox2seq/vox2seq/pytorch/hilbert.py +303 -0
- extensions/vox2seq/vox2seq/pytorch/z_order.py +126 -0
extensions/vox2seq/benchmark.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
import torch
import vox2seq


def _avg_seconds(fn, coords, mode, iters=100):
    """Return the average wall-clock seconds per call of fn(coords, mode=mode).

    Synchronizes before starting and after the loop so queued CUDA work is
    billed to the correct timer.
    """
    torch.cuda.synchronize()  # don't bill previously queued work to this timer
    start = time.time()
    for _ in range(iters):
        fn(coords, mode=mode)
    torch.cuda.synchronize()  # wait for the launched kernels before stopping
    return (time.time() - start) / iters


if __name__ == "__main__":
    stats = {
        'z_order_cuda': [],
        'z_order_pytorch': [],
        'hilbert_cuda': [],
        'hilbert_pytorch': [],
    }
    RES = [16, 32, 64, 128, 256]
    for res in RES:
        # Dense grid of all res^3 integer voxel coordinates, shape [res^3, 3].
        coords = torch.meshgrid(torch.arange(res), torch.arange(res), torch.arange(res))
        coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda()

        # Warm-up: the first call may pay one-time costs (context init, caches)
        # that should not be included in the measurement.
        vox2seq.encode(coords, mode='z_order')
        vox2seq.pytorch.encode(coords, mode='z_order')
        vox2seq.encode(coords, mode='hilbert')
        vox2seq.pytorch.encode(coords, mode='hilbert')

        stats['z_order_cuda'].append(_avg_seconds(vox2seq.encode, coords, 'z_order'))
        stats['z_order_pytorch'].append(_avg_seconds(vox2seq.pytorch.encode, coords, 'z_order'))
        stats['hilbert_cuda'].append(_avg_seconds(vox2seq.encode, coords, 'hilbert'))
        stats['hilbert_pytorch'].append(_avg_seconds(vox2seq.pytorch.encode, coords, 'hilbert'))

    print(f"{'Resolution':<12}{'Z-Order (CUDA)':<24}{'Z-Order (PyTorch)':<24}{'Hilbert (CUDA)':<24}{'Hilbert (PyTorch)':<24}")
    for res, z_order_cuda, z_order_pytorch, hilbert_cuda, hilbert_pytorch in zip(RES, stats['z_order_cuda'], stats['z_order_pytorch'], stats['hilbert_cuda'], stats['hilbert_pytorch']):
        print(f"{res:<12}{z_order_cuda:<24.6f}{z_order_pytorch:<24.6f}{hilbert_cuda:<24.6f}{hilbert_pytorch:<24.6f}")
|
| 45 |
+
|
extensions/vox2seq/setup.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#

from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

# NOTE(review): source paths below are relative to the invocation directory;
# run setup.py from this folder. (A previous revision computed the script's
# directory but discarded the result — that dead statement was removed.)
setup(
    name="vox2seq",
    packages=['vox2seq', 'vox2seq.pytorch'],
    ext_modules=[
        CUDAExtension(
            name="vox2seq._C",
            sources=[
                "src/api.cu",
                "src/z_order.cu",
                "src/hilbert.cu",
                "src/ext.cpp",
            ],
        )
    ],
    cmdclass={
        # BuildExtension supplies the nvcc-aware build steps for CUDAExtension.
        'build_ext': BuildExtension
    }
)
|
extensions/vox2seq/src/api.cu
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include "api.h"
|
| 3 |
+
#include "z_order.h"
|
| 4 |
+
#include "hilbert.h"
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
torch::Tensor
z_order_encode(
    const torch::Tensor& x,
    const torch::Tensor& y,
    const torch::Tensor& z
) {
    // One Morton code per input point; same dtype/device/shape as x.
    torch::Tensor codes = torch::empty_like(x);

    // One thread per point, rounded up to whole blocks of BLOCK_SIZE threads.
    const auto num_points = x.size(0);
    const auto num_blocks = (num_points + BLOCK_SIZE - 1) / BLOCK_SIZE;

    // int32 tensor storage is reinterpreted as unsigned for the bit tricks.
    z_order_encode_cuda<<<num_blocks, BLOCK_SIZE>>>(
        num_points,
        reinterpret_cast<uint32_t*>(x.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(y.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(z.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(codes.data_ptr<int>())
    );

    return codes;
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
| 30 |
+
z_order_decode(
|
| 31 |
+
const torch::Tensor& codes
|
| 32 |
+
) {
|
| 33 |
+
// Allocate output tensors
|
| 34 |
+
torch::Tensor x = torch::empty_like(codes);
|
| 35 |
+
torch::Tensor y = torch::empty_like(codes);
|
| 36 |
+
torch::Tensor z = torch::empty_like(codes);
|
| 37 |
+
|
| 38 |
+
// Call CUDA kernel
|
| 39 |
+
z_order_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
|
| 40 |
+
codes.size(0),
|
| 41 |
+
reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr<int>()),
|
| 42 |
+
reinterpret_cast<uint32_t*>(x.data_ptr<int>()),
|
| 43 |
+
reinterpret_cast<uint32_t*>(y.data_ptr<int>()),
|
| 44 |
+
reinterpret_cast<uint32_t*>(z.data_ptr<int>())
|
| 45 |
+
);
|
| 46 |
+
|
| 47 |
+
return std::make_tuple(x, y, z);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
torch::Tensor
hilbert_encode(
    const torch::Tensor& x,
    const torch::Tensor& y,
    const torch::Tensor& z
) {
    // One Hilbert code per input point; same dtype/device/shape as x.
    torch::Tensor codes = torch::empty_like(x);

    // One thread per point, rounded up to whole blocks of BLOCK_SIZE threads.
    const auto num_points = x.size(0);
    const auto num_blocks = (num_points + BLOCK_SIZE - 1) / BLOCK_SIZE;

    // int32 tensor storage is reinterpreted as unsigned for the bit tricks.
    hilbert_encode_cuda<<<num_blocks, BLOCK_SIZE>>>(
        num_points,
        reinterpret_cast<uint32_t*>(x.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(y.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(z.contiguous().data_ptr<int>()),
        reinterpret_cast<uint32_t*>(codes.data_ptr<int>())
    );

    return codes;
}
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
|
| 74 |
+
hilbert_decode(
|
| 75 |
+
const torch::Tensor& codes
|
| 76 |
+
) {
|
| 77 |
+
// Allocate output tensors
|
| 78 |
+
torch::Tensor x = torch::empty_like(codes);
|
| 79 |
+
torch::Tensor y = torch::empty_like(codes);
|
| 80 |
+
torch::Tensor z = torch::empty_like(codes);
|
| 81 |
+
|
| 82 |
+
// Call CUDA kernel
|
| 83 |
+
hilbert_decode_cuda<<<(codes.size(0) + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE>>>(
|
| 84 |
+
codes.size(0),
|
| 85 |
+
reinterpret_cast<uint32_t*>(codes.contiguous().data_ptr<int>()),
|
| 86 |
+
reinterpret_cast<uint32_t*>(x.data_ptr<int>()),
|
| 87 |
+
reinterpret_cast<uint32_t*>(y.data_ptr<int>()),
|
| 88 |
+
reinterpret_cast<uint32_t*>(z.data_ptr<int>())
|
| 89 |
+
);
|
| 90 |
+
|
| 91 |
+
return std::make_tuple(x, y, z);
|
| 92 |
+
}
|
extensions/vox2seq/src/api.h
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/*
 * Serialize a voxel grid
 *
 * Copyright (C) 2024, Jianfeng XIANG <belljig@outlook.com>
 * All rights reserved.
 *
 * Licensed under The MIT License [see LICENSE for details]
 *
 * Written by Jianfeng XIANG
 */

#pragma once
#include <torch/extension.h>


// Threads per block used by all kernel-launch wrappers in api.cu.
#define BLOCK_SIZE 256


/**
 * Z-order encode 3D points
 *
 * @param x [N] int32 tensor containing the x coordinates
 * @param y [N] int32 tensor containing the y coordinates
 * @param z [N] int32 tensor containing the z coordinates
 *
 * @return [N] tensor containing the z-order encoded values
 */
torch::Tensor
z_order_encode(
    const torch::Tensor& x,
    const torch::Tensor& y,
    const torch::Tensor& z
);


/**
 * Z-order decode 3D points
 *
 * @param codes [N] tensor containing the z-order encoded values
 *
 * @return 3 tensors [N] containing the x, y, z coordinates
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
z_order_decode(
    const torch::Tensor& codes
);


/**
 * Hilbert encode 3D points
 *
 * @param x [N] int32 tensor containing the x coordinates
 * @param y [N] int32 tensor containing the y coordinates
 * @param z [N] int32 tensor containing the z coordinates
 *
 * @return [N] tensor containing the Hilbert encoded values
 */
torch::Tensor
hilbert_encode(
    const torch::Tensor& x,
    const torch::Tensor& y,
    const torch::Tensor& z
);


/**
 * Hilbert decode 3D points
 *
 * @param codes [N] tensor containing the Hilbert encoded values
 *
 * @return 3 tensors [N] containing the x, y, z coordinates
 */
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
hilbert_decode(
    const torch::Tensor& codes
);
|
extensions/vox2seq/src/ext.cpp
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/extension.h>
|
| 2 |
+
#include "api.h"
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
// Python bindings: exposes the CUDA encode/decode entry points declared in
// api.h as the `vox2seq._C` extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("z_order_encode", &z_order_encode);
    m.def("z_order_decode", &z_order_decode);
    m.def("hilbert_encode", &hilbert_encode);
    m.def("hilbert_decode", &hilbert_decode);
}
|
extensions/vox2seq/src/hilbert.cu
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <cuda.h>
|
| 2 |
+
#include <cuda_runtime.h>
|
| 3 |
+
#include <device_launch_parameters.h>
|
| 4 |
+
|
| 5 |
+
#include <cooperative_groups.h>
|
| 6 |
+
#include <cooperative_groups/memcpy_async.h>
|
| 7 |
+
namespace cg = cooperative_groups;
|
| 8 |
+
|
| 9 |
+
#include "hilbert.h"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
// Classic Morton-code bit spreading via multiply-and-mask steps; only the
// low 10 bits of v contribute to the result.
static __device__ uint32_t expandBits(uint32_t v)
{
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
// Removes 2 zeros after each bit in a 30-bit integer — the inverse of
// expandBits. Masks to every third bit, then folds the bits back together.
static __device__ uint32_t extractBits(uint32_t v)
{
    v = v & 0x49249249;
    v = (v ^ (v >> 2)) & 0x030C30C3u;
    v = (v ^ (v >> 4)) & 0x0300F00Fu;
    v = (v ^ (v >> 8)) & 0x030000FFu;
    v = (v ^ (v >> 16)) & 0x000003FFu;
    return v;
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
// Kernel: convert one (x, y, z) point per thread into its Hilbert-curve
// index. Follows Skilling's transpose-based algorithm with 10 bits per axis
// (m starts at 1 << 9), then interleaves the transformed axes into a single
// 30-bit code.
__global__ void hilbert_encode_cuda(
    size_t N,              // number of points
    const uint32_t* x,     // [N] x coordinates (only the low 10 bits are used)
    const uint32_t* y,     // [N] y coordinates
    const uint32_t* z,     // [N] z coordinates
    uint32_t* codes        // [N] output Hilbert codes
) {
    size_t thread_id = cg::this_grid().thread_rank();
    if (thread_id >= N) return;

    // Work on a local copy of the point in "transpose" layout.
    uint32_t point[3] = {x[thread_id], y[thread_id], z[thread_id]};

    // m: highest axis bit (bit 9); q walks down the bit positions.
    uint32_t m = 1 << 9, q, p, t;

    // Inverse undo excess work
    q = m;
    while (q > 1) {
        p = q - 1;
        for (int i = 0; i < 3; i++) {
            if (point[i] & q) {
                point[0] ^= p; // invert
            } else {
                // Exchange the low bits of point[0] and point[i].
                t = (point[0] ^ point[i]) & p;
                point[0] ^= t;
                point[i] ^= t;
            }
        }
        q >>= 1;
    }

    // Gray encode
    for (int i = 1; i < 3; i++) {
        point[i] ^= point[i - 1];
    }
    t = 0;
    q = m;
    while (q > 1) {
        if (point[2] & q) {
            t ^= q - 1;
        }
        q >>= 1;
    }
    for (int i = 0; i < 3; i++) {
        point[i] ^= t;
    }

    // Convert to 3D Hilbert code: spread each axis to every third bit and
    // interleave (x in the high position of each 3-bit group).
    uint32_t xx = expandBits(point[0]);
    uint32_t yy = expandBits(point[1]);
    uint32_t zz = expandBits(point[2]);

    codes[thread_id] = xx * 4 + yy * 2 + zz;
}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
// Kernel: convert one Hilbert code per thread back into (x, y, z).
// Deinterleaves the 30-bit code into the transposed representation, then
// runs the inverse of the encode transform (Gray decode + undo excess work).
__global__ void hilbert_decode_cuda(
    size_t N,              // number of codes
    const uint32_t* codes, // [N] Hilbert codes
    uint32_t* x,           // [N] output x coordinates
    uint32_t* y,           // [N] output y coordinates
    uint32_t* z            // [N] output z coordinates
) {
    size_t thread_id = cg::this_grid().thread_rank();
    if (thread_id >= N) return;

    // Split the interleaved code back into the three transposed axes.
    uint32_t point[3];
    point[0] = extractBits(codes[thread_id] >> 2);
    point[1] = extractBits(codes[thread_id] >> 1);
    point[2] = extractBits(codes[thread_id]);

    // m: one past the highest axis bit (10 bits per axis).
    uint32_t m = 2 << 9, q, p, t;

    // Gray decode by H ^ (H/2)
    t = point[2] >> 1;
    for (int i = 2; i > 0; i--) {
        point[i] ^= point[i - 1];
    }
    point[0] ^= t;

    // Undo excess work
    q = 2;
    while (q != m) {
        p = q - 1;
        for (int i = 2; i >= 0; i--) {
            if (point[i] & q) {
                point[0] ^= p;
            } else {
                // Exchange the low bits of point[0] and point[i].
                t = (point[0] ^ point[i]) & p;
                point[0] ^= t;
                point[i] ^= t;
            }
        }
        q <<= 1;
    }

    x[thread_id] = point[0];
    y[thread_id] = point[1];
    z[thread_id] = point[2];
}
|
extensions/vox2seq/src/hilbert.h
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
/**
 * Hilbert encode 3D points (CUDA kernel)
 *
 * @param N number of points
 * @param x [N] tensor containing the x coordinates
 * @param y [N] tensor containing the y coordinates
 * @param z [N] tensor containing the z coordinates
 * @param codes [N] output tensor receiving the Hilbert encoded values
 */
__global__ void hilbert_encode_cuda(
    size_t N,
    const uint32_t* x,
    const uint32_t* y,
    const uint32_t* z,
    uint32_t* codes
);


/**
 * Hilbert decode 3D points (CUDA kernel)
 *
 * @param N number of codes
 * @param codes [N] tensor containing the Hilbert encoded values
 * @param x [N] output tensor receiving the x coordinates
 * @param y [N] output tensor receiving the y coordinates
 * @param z [N] output tensor receiving the z coordinates
 */
__global__ void hilbert_decode_cuda(
    size_t N,
    const uint32_t* codes,
    uint32_t* x,
    uint32_t* y,
    uint32_t* z
);
|
extensions/vox2seq/src/z_order.cu
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <cuda.h>
|
| 2 |
+
#include <cuda_runtime.h>
|
| 3 |
+
#include <device_launch_parameters.h>
|
| 4 |
+
|
| 5 |
+
#include <cooperative_groups.h>
|
| 6 |
+
#include <cooperative_groups/memcpy_async.h>
|
| 7 |
+
namespace cg = cooperative_groups;
|
| 8 |
+
|
| 9 |
+
#include "z_order.h"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
// Expands a 10-bit integer into 30 bits by inserting 2 zeros after each bit.
// Classic Morton-code bit spreading via multiply-and-mask steps; only the
// low 10 bits of v contribute to the result.
static __device__ uint32_t expandBits(uint32_t v)
{
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
// Removes 2 zeros after each bit in a 30-bit integer — the inverse of
// expandBits. Masks to every third bit, then folds the bits back together.
static __device__ uint32_t extractBits(uint32_t v)
{
    v = v & 0x49249249;
    v = (v ^ (v >> 2)) & 0x030C30C3u;
    v = (v ^ (v >> 4)) & 0x0300F00Fu;
    v = (v ^ (v >> 8)) & 0x030000FFu;
    v = (v ^ (v >> 16)) & 0x000003FFu;
    return v;
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
// Kernel: one thread per point; interleave the low 10 bits of each axis
// into a 30-bit Morton (Z-order) code (x in the high position of each
// 3-bit group).
__global__ void z_order_encode_cuda(
    size_t N,
    const uint32_t* x,
    const uint32_t* y,
    const uint32_t* z,
    uint32_t* codes
) {
    const size_t idx = cg::this_grid().thread_rank();
    if (idx >= N) return;

    // Spread each coordinate so its bits occupy every third position, then
    // shift x and y into their slots (equivalent to xx*4 + yy*2 + zz since
    // the spread values never share set bits).
    const uint32_t spread_x = expandBits(x[idx]);
    const uint32_t spread_y = expandBits(y[idx]);
    const uint32_t spread_z = expandBits(z[idx]);

    codes[idx] = (spread_x << 2) | (spread_y << 1) | spread_z;
}
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
// Kernel: one thread per code; deinterleave a 30-bit Morton code back into
// its three 10-bit axis coordinates.
__global__ void z_order_decode_cuda(
    size_t N,
    const uint32_t* codes,
    uint32_t* x,
    uint32_t* y,
    uint32_t* z
) {
    const size_t idx = cg::this_grid().thread_rank();
    if (idx >= N) return;

    const uint32_t code = codes[idx];
    // Within each 3-bit group x sits two positions above z and y one above,
    // so shifting before extraction selects the right axis.
    x[idx] = extractBits(code >> 2);
    y[idx] = extractBits(code >> 1);
    z[idx] = extractBits(code);
}
|
extensions/vox2seq/src/z_order.h
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
/**
 * Z-order encode 3D points (CUDA kernel)
 *
 * @param N number of points
 * @param x [N] tensor containing the x coordinates
 * @param y [N] tensor containing the y coordinates
 * @param z [N] tensor containing the z coordinates
 * @param codes [N] output tensor receiving the z-order encoded values
 */
__global__ void z_order_encode_cuda(
    size_t N,
    const uint32_t* x,
    const uint32_t* y,
    const uint32_t* z,
    uint32_t* codes
);


/**
 * Z-order decode 3D points (CUDA kernel)
 *
 * @param N number of codes
 * @param codes [N] tensor containing the z-order encoded values
 * @param x [N] output tensor receiving the x coordinates
 * @param y [N] output tensor receiving the y coordinates
 * @param z [N] output tensor receiving the z coordinates
 */
__global__ void z_order_decode_cuda(
    size_t N,
    const uint32_t* codes,
    uint32_t* x,
    uint32_t* y,
    uint32_t* z
);
|
extensions/vox2seq/test.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
import vox2seq


if __name__ == "__main__":
    # Cross-check the CUDA extension against the pure-PyTorch reference
    # implementation on a dense RES^3 voxel grid.
    RES = 256
    coords = torch.meshgrid(torch.arange(RES), torch.arange(RES), torch.arange(RES))
    coords = torch.stack(coords, dim=-1).reshape(-1, 3).int().cuda()
    # Encode the same coordinates with both backends and both curve types.
    code_z_cuda = vox2seq.encode(coords, mode='z_order')
    code_z_pytorch = vox2seq.pytorch.encode(coords, mode='z_order')
    code_h_cuda = vox2seq.encode(coords, mode='hilbert')
    code_h_pytorch = vox2seq.pytorch.encode(coords, mode='hilbert')
    assert torch.equal(code_z_cuda, code_z_pytorch)
    assert torch.equal(code_h_cuda, code_h_pytorch)

    # Round-trip: decode every 30-bit code for this resolution with both
    # backends and compare.
    # NOTE(review): vox2seq.decode stacks integer tensors while
    # vox2seq.pytorch.decode casts its result to float; torch.equal is
    # dtype-sensitive, so these asserts depend on the dtypes matching —
    # verify on a CUDA machine.
    code = torch.arange(RES**3).int().cuda()
    coords_z_cuda = vox2seq.decode(code, mode='z_order')
    coords_z_pytorch = vox2seq.pytorch.decode(code, mode='z_order')
    coords_h_cuda = vox2seq.decode(code, mode='hilbert')
    coords_h_pytorch = vox2seq.pytorch.decode(code, mode='hilbert')
    assert torch.equal(coords_z_cuda, coords_z_pytorch)
    assert torch.equal(coords_h_cuda, coords_h_pytorch)

    print("All tests passed.")
|
| 25 |
+
|
extensions/vox2seq/vox2seq/__init__.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from typing import *
|
| 3 |
+
import torch
|
| 4 |
+
from . import _C
|
| 5 |
+
from . import pytorch
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@torch.no_grad()
def encode(coords: torch.Tensor, permute: Sequence[int] = (0, 1, 2), mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Encodes 3D coordinates into a 30-bit space-filling-curve code (CUDA).

    Args:
        coords: a tensor of shape [N, 3] containing the 3D coordinates.
        permute: the permutation of the coordinates; axis ``permute[i]`` of
            ``coords`` is fed to the encoder as dimension i. The default is
            now an immutable tuple (a mutable list default is an anti-pattern);
            lists are still accepted.
        mode: the encoding mode to use ('z_order' or 'hilbert').

    Returns:
        A tensor of shape [N] containing the encoded values.

    Raises:
        ValueError: if ``mode`` is not a known encoding mode.
    """
    assert coords.shape[-1] == 3 and coords.ndim == 2, "Input coordinates must be of shape [N, 3]"
    # The CUDA kernels operate on int32 buffers (reinterpreted as uint32).
    x = coords[:, permute[0]].int()
    y = coords[:, permute[1]].int()
    z = coords[:, permute[2]].int()
    if mode == 'z_order':
        return _C.z_order_encode(x, y, z)
    elif mode == 'hilbert':
        return _C.hilbert_encode(x, y, z)
    else:
        raise ValueError(f"Unknown encoding mode: {mode}")
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@torch.no_grad()
def decode(code: torch.Tensor, permute: Sequence[int] = (0, 1, 2), mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Decodes a 30-bit code into 3D coordinates (CUDA).

    Args:
        code: a tensor of shape [N] containing the 30-bit code.
        permute: the permutation used at encode time; it is inverted here so
            the decoded axes land back in their original positions. The
            default is now an immutable tuple; lists are still accepted.
        mode: the decoding mode to use ('z_order' or 'hilbert').

    Returns:
        A tensor of shape [N, 3] containing the decoded coordinates.

    Raises:
        ValueError: if ``mode`` is not a known decoding mode.
    """
    assert code.ndim == 1, "Input code must be of shape [N]"
    if mode == 'z_order':
        coords = _C.z_order_decode(code)
    elif mode == 'hilbert':
        coords = _C.hilbert_decode(code)
    else:
        raise ValueError(f"Unknown decoding mode: {mode}")
    # coords[i] holds the axis that was placed at encoder dimension i (i.e.
    # original axis permute[i]); invert the permutation to restore x, y, z.
    permute = list(permute)
    x = coords[permute.index(0)]
    y = coords[permute.index(1)]
    z = coords[permute.index(2)]
    return torch.stack([x, y, z], dim=-1)
|
extensions/vox2seq/vox2seq/pytorch/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import *
|
| 3 |
+
|
| 4 |
+
from .default import (
|
| 5 |
+
encode,
|
| 6 |
+
decode,
|
| 7 |
+
z_order_encode,
|
| 8 |
+
z_order_decode,
|
| 9 |
+
hilbert_encode,
|
| 10 |
+
hilbert_decode,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@torch.no_grad()
def encode(coords: torch.Tensor, permute: Sequence[int] = (0, 1, 2), mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Encodes 3D coordinates into a 30-bit space-filling-curve code (pure PyTorch).

    Args:
        coords: a tensor of shape [N, 3] containing the 3D coordinates.
        permute: the permutation of the coordinates; axis ``permute[i]`` of
            ``coords`` is fed to the encoder as dimension i. The default is
            now an immutable tuple (a mutable list default is an anti-pattern);
            lists are still accepted.
        mode: the encoding mode to use ('z_order' or 'hilbert').

    Returns:
        An int32 tensor of shape [N] containing the encoded values.

    Raises:
        ValueError: if ``mode`` is not a known encoding mode.
    """
    # Index with a list so advanced indexing selects columns (a bare tuple
    # would be interpreted as a multi-dimensional index).
    cols = list(permute)
    if mode == 'z_order':
        return z_order_encode(coords[:, cols], depth=10).int()
    elif mode == 'hilbert':
        return hilbert_encode(coords[:, cols], depth=10).int()
    else:
        raise ValueError(f"Unknown encoding mode: {mode}")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@torch.no_grad()
def decode(code: torch.Tensor, permute: Sequence[int] = (0, 1, 2), mode: Literal['z_order', 'hilbert'] = 'z_order') -> torch.Tensor:
    """
    Decodes a 30-bit code into 3D coordinates (pure PyTorch).

    Args:
        code: a tensor of shape [N] containing the 30-bit code.
        permute: the permutation used at encode time; it is inverted here so
            the decoded axes land back in their original positions. The
            default is now an immutable tuple; lists are still accepted.
        mode: the decoding mode to use ('z_order' or 'hilbert').

    Returns:
        A float tensor of shape [N, 3] containing the decoded coordinates.

    Raises:
        ValueError: if ``mode`` is not a known decoding mode.
    """
    # Decoded column i holds original axis permute[i], so apply the INVERSE
    # permutation — matching the CUDA wrapper in vox2seq.__init__, which uses
    # permute.index(...). The previous ``[:, permute]`` applied the forward
    # permutation and disagreed with encode for non-identity permutations
    # (identical for the default (0, 1, 2)).
    permute = list(permute)
    inverse = [permute.index(axis) for axis in range(3)]
    if mode == 'z_order':
        return z_order_decode(code, depth=10)[:, inverse].float()
    elif mode == 'hilbert':
        return hilbert_decode(code, depth=10)[:, inverse].float()
    else:
        raise ValueError(f"Unknown decoding mode: {mode}")
|
| 48 |
+
|
extensions/vox2seq/vox2seq/pytorch/default.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from .z_order import xyz2key as z_order_encode_
|
| 3 |
+
from .z_order import key2xyz as z_order_decode_
|
| 4 |
+
from .hilbert import encode as hilbert_encode_
|
| 5 |
+
from .hilbert import decode as hilbert_decode_
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@torch.inference_mode()
def encode(grid_coord, batch=None, depth=16, order="z"):
    """Serialize voxel coordinates along a space-filling curve.

    ``order`` selects the curve ("z" or "hilbert"); the "-trans" variants
    swap the x and y axes before encoding. When ``batch`` is given, the
    batch index is packed into the bits above the 3 * depth code bits.
    """
    assert order in {"z", "z-trans", "hilbert", "hilbert-trans"}
    # Dispatch table: order name -> (encoder, swap x/y first?).
    dispatch = {
        "z": (z_order_encode, False),
        "z-trans": (z_order_encode, True),
        "hilbert": (hilbert_encode, False),
        "hilbert-trans": (hilbert_encode, True),
    }
    if order not in dispatch:
        raise NotImplementedError
    encoder, swap_xy = dispatch[order]
    points = grid_coord[:, [1, 0, 2]] if swap_xy else grid_coord
    code = encoder(points, depth=depth)
    if batch is not None:
        # Batch id occupies the high bits; the curve code keeps the low
        # 3 * depth bits.
        batch = batch.long()
        code = batch << depth * 3 | code
    return code
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@torch.inference_mode()
def decode(code, depth=16, order="z"):
    """Invert ``encode``: split packed codes into grid coordinates and batch.

    Args:
        code: [N] packed codes (batch index in the high bits, curve code in
            the low 3 * depth bits).
        depth: bits per axis used at encode time.
        order: curve type used at encode time ("z" or "hilbert"); the
            "-trans" variants are not handled here.

    Returns:
        (grid_coord, batch): decoded [N, 3] coordinates and [N] batch indices.
    """
    assert order in {"z", "hilbert"}
    # High bits hold the batch index; mask them off to recover the raw code.
    batch = code >> depth * 3
    code = code & ((1 << depth * 3) - 1)
    if order == "z":
        grid_coord = z_order_decode(code, depth=depth)
    elif order == "hilbert":
        grid_coord = hilbert_decode(code, depth=depth)
    else:
        raise NotImplementedError
    return grid_coord, batch
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def z_order_encode(grid_coord: torch.Tensor, depth: int = 16):
    """Encode [N, 3] integer grid coordinates into Z-order (Morton) codes.

    Coordinates are promoted to int64 before encoding; the code dtype is
    determined by the underlying z_order helper.
    """
    x, y, z = grid_coord[:, 0].long(), grid_coord[:, 1].long(), grid_coord[:, 2].long()
    # we block the support to batch, maintain batched code in Point class
    code = z_order_encode_(x, y, z, b=None, depth=depth)
    return code
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def z_order_decode(code: torch.Tensor, depth):
    """Decode Z-order codes into [N, 3] grid coordinates (batch part ignored)."""
    x, y, z, _ = z_order_decode_(code, depth=depth)
    grid_coord = torch.stack([x, y, z], dim=-1)  # (N, 3)
    return grid_coord
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def hilbert_encode(grid_coord: torch.Tensor, depth: int = 16):
    """Encode [N, 3] grid coordinates into Hilbert codes (depth bits per axis)."""
    return hilbert_encode_(grid_coord, num_dims=3, num_bits=depth)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def hilbert_decode(code: torch.Tensor, depth: int = 16):
    """Decode Hilbert codes back into [N, 3] grid coordinates."""
    return hilbert_decode_(code, num_dims=3, num_bits=depth)
|
extensions/vox2seq/vox2seq/pytorch/hilbert.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hilbert Order
|
| 3 |
+
Modified from https://github.com/PrincetonLIPS/numpy-hilbert-curve
|
| 4 |
+
|
| 5 |
+
Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com), Kaixin Xu
|
| 6 |
+
Please cite our work if the code is helpful to you.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def right_shift(binary, k=1, axis=-1):
    """Right shift an array of binary values.

    Parameters:
    -----------
    binary: An ndarray of binary values.

    k: The number of bits to shift. Default 1.

    axis: The axis along which to shift. Default -1.

    Returns:
    --------
    Returns an ndarray with zero prepended and the ends truncated, along
    whatever axis was specified."""

    # If we're shifting the whole thing, just return zeros.
    if binary.shape[axis] <= k:
        return torch.zeros_like(binary)

    # Build the shifted result directly: zeros fill the first k slots along
    # `axis` and the original values (minus the last k) occupy the rest.
    # NOTE: the previous implementation used torch.nn.functional.pad with a
    # (k, 0) pad tuple, which ALWAYS pads the last dimension — for any
    # axis != -1 that sliced one axis but padded another, producing a wrong
    # shape. Index-assignment below respects `axis` correctly, and still
    # accepts a one-element integer tensor for k (as passed by gray2binary)
    # via __index__.
    shifted = torch.zeros_like(binary)
    src = [slice(None)] * binary.ndim
    dst = [slice(None)] * binary.ndim
    src[axis] = slice(None, -k)
    dst[axis] = slice(k, None)
    shifted[tuple(dst)] = binary[tuple(src)]

    return shifted
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def binary2gray(binary, axis=-1):
    """Convert an array of binary values into Gray codes.

    This uses the classic X ^ (X >> 1) trick to compute the Gray code.

    Parameters:
    -----------
    binary: An ndarray of binary values.

    axis: The axis along which to compute the gray code. Default=-1.

    Returns:
    --------
    Returns an ndarray of Gray codes.
    """
    # The Gray code of a bit string is simply the string xor-ed with a
    # one-position right shift of itself.
    return torch.logical_xor(binary, right_shift(binary, axis=axis))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def gray2binary(gray, axis=-1):
    """Convert an array of Gray codes back into binary values.

    Parameters:
    -----------
    gray: An ndarray of gray codes.

    axis: The axis along which to perform Gray decoding. Default=-1.

    Returns:
    --------
    Returns an ndarray of binary values.
    """

    # Decoding undoes the X ^ (X >> 1) encoding with a logarithmic number of
    # xor-with-shift passes: start at the largest power-of-two stride that
    # covers the bit width and halve it each round.
    stride = 2 ** (torch.Tensor([gray.shape[axis]]).log2().ceil().int() - 1)
    while stride > 0:
        gray = torch.logical_xor(gray, right_shift(gray, stride))
        stride = torch.div(stride, 2, rounding_mode="floor")
    return gray
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def encode(locs, num_dims, num_bits):
    """Encode an array of locations in a hypercube into a Hilbert integer.

    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
    locs - An ndarray of locations in a hypercube of num_dims dimensions, in
           which each dimension runs from 0 to 2**num_bits-1. The shape can
           be arbitrary, as long as the last dimension of the same has size
           num_dims.

    num_dims - The dimensionality of the hypercube. Integer.

    num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
    The output is an ndarray of uint64 integers with the same shape as the
    input, excluding the last dimension, which needs to be num_dims.
    """

    # Keep around the original shape for later.
    orig_shape = locs.shape
    # Per-bit masks (1, 2, 4, ..., 128) used to unpack uint8 bytes into
    # individual bits, and the reversed order for MSB-first unpacking.
    bitpack_mask = 1 << torch.arange(0, 8).to(locs.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    if orig_shape[-1] != num_dims:
        raise ValueError(
            """
      The shape of locs was surprising in that the last dimension was of size
      %d, but num_dims=%d. These need to be equal.
      """
            % (orig_shape[-1], num_dims)
        )

    if num_dims * num_bits > 63:
        raise ValueError(
            """
      num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
      into a int64. Are you sure you need that many points on your Hilbert
      curve?
      """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Treat the location integers as 64-bit unsigned and then split them up into
    # a sequence of uint8s. Preserve the association by dimension.
    # The flip(-1) puts the most-significant byte first.
    locs_uint8 = locs.long().view(torch.uint8).reshape((-1, num_dims, 8)).flip(-1)

    # Now turn these into bits and truncate to num_bits.
    # Result shape: (batch, num_dims, num_bits), most-significant bit first.
    gray = (
        locs_uint8.unsqueeze(-1)
        .bitwise_and(bitpack_mask_rev)
        .ne(0)
        .byte()
        .flatten(-2, -1)[..., -num_bits:]
    )

    # Run the decoding process the other way (Skilling's transform applied
    # in the forward direction turns coordinates into a Gray-coded index).
    # Iterate forwards through the bits.
    for bit in range(0, num_bits):
        # Iterate forwards through the dimensions.
        for dim in range(0, num_dims):
            # Identify which ones have this bit active.
            mask = gray[:, dim, bit]

            # Where this bit is on, invert the 0 dimension for lower bits.
            gray[:, 0, bit + 1 :] = torch.logical_xor(
                gray[:, 0, bit + 1 :], mask[:, None]
            )

            # Where the bit is off, exchange the lower bits with the 0 dimension.
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]).repeat(1, gray.shape[2] - bit - 1),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(
                gray[:, dim, bit + 1 :], to_flip
            )
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Now flatten out: interleave so bits of all dims at the same level are
    # adjacent, giving one num_bits*num_dims bit string per location.
    gray = gray.swapaxes(1, 2).reshape((-1, num_bits * num_dims))

    # Convert Gray back to binary.
    hh_bin = gray2binary(gray)

    # Pad back out to 64 bits.
    extra_dims = 64 - num_bits * num_dims
    padded = torch.nn.functional.pad(hh_bin, (extra_dims, 0), "constant", 0)

    # Convert binary values into uint8s.
    # NOTE(review): .squeeze() drops ALL size-1 dims, so a single input
    # location collapses the batch dimension — callers appear to rely on this.
    hh_uint8 = (
        (padded.flip(-1).reshape((-1, 8, 8)) * bitpack_mask)
        .sum(2)
        .squeeze()
        .type(torch.uint8)
    )

    # Convert uint8s into uint64s (reinterpreted as int64 bit patterns).
    hh_uint64 = hh_uint8.view(torch.int64).squeeze()

    return hh_uint64
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def decode(hilberts, num_dims, num_bits):
    """Decode an array of Hilbert integers into locations in a hypercube.

    This is a vectorized-ish version of the Hilbert curve implementation by John
    Skilling as described in:

    Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference
    Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics.

    Params:
    -------
    hilberts - An ndarray of Hilbert integers. Must be an integer dtype and
               cannot have fewer bits than num_dims * num_bits.

    num_dims - The dimensionality of the hypercube. Integer.

    num_bits - The number of bits for each dimension. Integer.

    Returns:
    --------
    The output is an ndarray of unsigned integers with the same shape as hilberts
    but with an additional dimension of size num_dims.
    """

    # BUG FIX: the message below contains three %d placeholders but only two
    # arguments were previously supplied, so hitting this branch raised a
    # TypeError ("not enough arguments for format string") instead of the
    # intended ValueError. Supply all three values, matching encode().
    if num_dims * num_bits > 64:
        raise ValueError(
            """
      num_dims=%d and num_bits=%d for %d bits total, which can't be encoded
      into a uint64. Are you sure you need that many points on your Hilbert
      curve?
      """
            % (num_dims, num_bits, num_dims * num_bits)
        )

    # Handle the case where we got handed a naked integer.
    hilberts = torch.atleast_1d(hilberts)

    # Keep around the shape for later.
    orig_shape = hilberts.shape
    # Per-bit masks (1, 2, ..., 128) for unpacking bytes into bits.
    bitpack_mask = 2 ** torch.arange(0, 8).to(hilberts.device)
    bitpack_mask_rev = bitpack_mask.flip(-1)

    # Treat each of the hilberts as a sequence of eight uint8.
    # This treats all of the inputs as uint64 and makes things uniform.
    hh_uint8 = (
        hilberts.ravel().type(torch.int64).view(torch.uint8).reshape((-1, 8)).flip(-1)
    )

    # Turn these lists of uints into lists of bits and then truncate to the size
    # we actually need for using Skilling's procedure.
    hh_bits = (
        hh_uint8.unsqueeze(-1)
        .bitwise_and(bitpack_mask_rev)
        .ne(0)
        .byte()
        .flatten(-2, -1)[:, -num_dims * num_bits :]
    )

    # Take the sequence of bits and Gray-code it.
    gray = binary2gray(hh_bits)

    # There has got to be a better way to do this.
    # I could index them differently, but the eventual packbits likes it this way.
    gray = gray.reshape((-1, num_bits, num_dims)).swapaxes(1, 2)

    # Iterate backwards through the bits.
    for bit in range(num_bits - 1, -1, -1):
        # Iterate backwards through the dimensions.
        for dim in range(num_dims - 1, -1, -1):
            # Identify which ones have this bit active.
            mask = gray[:, dim, bit]

            # Where this bit is on, invert the 0 dimension for lower bits.
            gray[:, 0, bit + 1 :] = torch.logical_xor(
                gray[:, 0, bit + 1 :], mask[:, None]
            )

            # Where the bit is off, exchange the lower bits with the 0 dimension.
            to_flip = torch.logical_and(
                torch.logical_not(mask[:, None]),
                torch.logical_xor(gray[:, 0, bit + 1 :], gray[:, dim, bit + 1 :]),
            )
            gray[:, dim, bit + 1 :] = torch.logical_xor(
                gray[:, dim, bit + 1 :], to_flip
            )
            gray[:, 0, bit + 1 :] = torch.logical_xor(gray[:, 0, bit + 1 :], to_flip)

    # Pad back out to 64 bits.
    extra_dims = 64 - num_bits
    padded = torch.nn.functional.pad(gray, (extra_dims, 0), "constant", 0)

    # Now chop these up into blocks of 8.
    locs_chopped = padded.flip(-1).reshape((-1, num_dims, 8, 8))

    # Take those blocks and turn them into uint8s.
    locs_uint8 = (locs_chopped * bitpack_mask).sum(3).squeeze().type(torch.uint8)

    # Finally, treat these as uint64s (reinterpreted int64 bit patterns).
    flat_locs = locs_uint8.view(torch.int64)

    # Return them in the expected shape.
    return flat_locs.reshape((*orig_shape, num_dims))
|
extensions/vox2seq/vox2seq/pytorch/z_order.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# Octree-based Sparse Convolutional Neural Networks
|
| 3 |
+
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
|
| 4 |
+
# Licensed under The MIT License [see LICENSE for details]
|
| 5 |
+
# Written by Peng-Shuai Wang
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from typing import Optional, Union
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class KeyLUT:
    """Per-device cache of lookup tables for Morton (z-order) keys.

    The tables are built once on CPU at construction time and mirrored to
    other devices lazily, on first request.
    """

    def __init__(self):
        cpu = torch.device("cpu")
        octets = torch.arange(256, dtype=torch.int64)
        nine_bit = torch.arange(512, dtype=torch.int64)
        blank = torch.zeros(256, dtype=torch.int64)

        # Encode tables: one per axis, mapping an 8-bit coordinate to its
        # bit-interleaved contribution within a 24-bit Morton key.
        self._encode = {
            cpu: (
                self.xyz2key(octets, blank, blank, 8),
                self.xyz2key(blank, octets, blank, 8),
                self.xyz2key(blank, blank, octets, 8),
            )
        }
        # Decode tables: map every 9-bit key chunk back to (x, y, z) bits.
        self._decode = {cpu: self.key2xyz(nine_bit, 9)}

    def encode_lut(self, device=torch.device("cpu")):
        # Copy the CPU tables to a new device the first time it is seen.
        if device not in self._encode:
            cpu = torch.device("cpu")
            self._encode[device] = tuple(t.to(device) for t in self._encode[cpu])
        return self._encode[device]

    def decode_lut(self, device=torch.device("cpu")):
        # Same lazy mirroring as encode_lut, but for the decode tables.
        if device not in self._decode:
            cpu = torch.device("cpu")
            self._decode[device] = tuple(t.to(device) for t in self._decode[cpu])
        return self._decode[device]

    def xyz2key(self, x, y, z, depth):
        """Interleave `depth` bits of x, y, z into one Morton key per element
        (x occupies the highest bit of each 3-bit group)."""
        key = torch.zeros_like(x)
        for bit in range(depth):
            probe = 1 << bit
            key = (
                key
                | ((x & probe) << (2 * bit + 2))
                | ((y & probe) << (2 * bit + 1))
                | ((z & probe) << (2 * bit + 0))
            )
        return key

    def key2xyz(self, key, depth):
        """De-interleave a Morton key back into its x, y, z coordinate bits."""
        x = torch.zeros_like(key)
        y = torch.zeros_like(key)
        z = torch.zeros_like(key)
        for bit in range(depth):
            x = x | ((key & (1 << (3 * bit + 2))) >> (2 * bit + 2))
            y = y | ((key & (1 << (3 * bit + 1))) >> (2 * bit + 1))
            z = z | ((key & (1 << (3 * bit + 0))) >> (2 * bit + 0))
        return x, y, z
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Module-level singleton so every caller shares one set of cached lookup
# tables (built on CPU once; mirrored to other devices on demand).
_key_lut = KeyLUT()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def xyz2key(
    x: torch.Tensor,
    y: torch.Tensor,
    z: torch.Tensor,
    b: Optional[Union[torch.Tensor, int]] = None,
    depth: int = 16,
):
    r"""Encodes :attr:`x`, :attr:`y`, :attr:`z` coordinates to the shuffled keys
    based on pre-computed look up tables. The speed of this function is much
    faster than the method based on for-loop.

    Args:
      x (torch.Tensor): The x coordinate.
      y (torch.Tensor): The y coordinate.
      z (torch.Tensor): The z coordinate.
      b (torch.Tensor or int): The batch index of the coordinates, and should be
          smaller than 32768. If :attr:`b` is :obj:`torch.Tensor`, the size of
          :attr:`b` must be the same as :attr:`x`, :attr:`y`, and :attr:`z`.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
    """

    EX, EY, EZ = _key_lut.encode_lut(x.device)
    x, y, z = x.long(), y.long(), z.long()

    # The low 8 bits of each coordinate go through the LUTs directly.
    mask = 255 if depth > 8 else (1 << depth) - 1
    key = EX[x & mask] | EY[y & mask] | EZ[z & mask]
    if depth > 8:
        # Bits 9..depth are encoded separately; each 8-bit group of a
        # coordinate interleaves into 24 key bits, hence the << 24.
        mask = (1 << (depth - 8)) - 1
        key16 = EX[(x >> 8) & mask] | EY[(y >> 8) & mask] | EZ[(z >> 8) & mask]
        key = key16 << 24 | key

    if b is not None:
        # BUG FIX: the docstring allows b to be a plain int, but the original
        # code unconditionally called b.long(), raising AttributeError for
        # ints. Only tensors need the dtype cast; a Python int works as-is
        # in the shift/or expression below.
        if isinstance(b, torch.Tensor):
            b = b.long()
        key = b << 48 | key

    return key
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def key2xyz(key: torch.Tensor, depth: int = 16):
    r"""Decodes the shuffled key to :attr:`x`, :attr:`y`, :attr:`z` coordinates
    and the batch index based on pre-computed look up tables.

    Args:
      key (torch.Tensor): The shuffled key.
      depth (int): The depth of the shuffled key, and must be smaller than 17 (< 17).
    """

    DX, DY, DZ = _key_lut.decode_lut(key.device)

    # Split off the batch index stored in the top bits, keeping only the
    # low 48 key bits for coordinate decoding.
    b = key >> 48
    key = key & ((1 << 48) - 1)

    x = torch.zeros_like(key)
    y = torch.zeros_like(key)
    z = torch.zeros_like(key)

    # Each LUT round decodes one 9-bit key chunk (3 bits per axis).
    rounds = (depth + 2) // 3
    for r in range(rounds):
        chunk = (key >> (r * 9)) & 511
        x = x | (DX[chunk] << (r * 3))
        y = y | (DY[chunk] << (r * 3))
        z = z | (DZ[chunk] << (r * 3))

    return x, y, z, b
|