# Pixal3D / trellis2/datasets/sparse_voxel_pbr.py
# Uploaded by Yang2001 with huggingface_hub (commit 8d595ff, verified)
import os
import io
from typing import Union
import numpy as np
import pickle
import torch
from PIL import Image
import o_voxel
import utils3d
from .components import StandardDatasetBase
from ..modules import sparse as sp
from ..renderers import VoxelRenderer
from ..representations import Voxel
from ..representations.mesh import MeshWithPbrMaterial, TextureFilterMode, TextureWrapMode, AlphaMode, PbrMaterial, Texture
from ..utils.data_utils import load_balanced_group_indices
def is_power_of_two(n: int) -> bool:
    """Return True when ``n`` is a positive integral power of two (1, 2, 4, ...)."""
    # Non-positive values can never be powers of two.
    if n <= 0:
        return False
    # A power of two has exactly one set bit, so clearing its lowest set bit
    # (n & (n - 1)) leaves nothing behind.
    return (n & (n - 1)) == 0
def nearest_power_of_two(n: int) -> int:
    """Return the power of two closest to ``n``.

    Powers of two are returned unchanged. When ``n`` lies exactly halfway
    between its two neighbouring powers of two, the larger one is chosen.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be >= 1")
    if is_power_of_two(n):
        return n
    bits = n.bit_length()
    below = 1 << (bits - 1)
    above = 1 << bits
    # Only a strictly smaller distance picks the lower power; ties round up.
    return below if (n - below) < (above - n) else above
class SparseVoxelPbrVisMixin:
    """Mixin that renders sparse PBR voxel samples as 2x2 multi-view image grids.

    The host class must provide ``self.layout`` (attribute name -> channel
    slice of the sparse features) and ``self.resolution`` (voxel grid size).
    """

    @torch.no_grad()
    def visualize_sample(self, x: Union[sp.SparseTensor, dict]):
        """Render each sample of ``x`` from four random views, once per attribute.

        Args:
            x: a ``sp.SparseTensor``, or a dict holding one under ``'x'``. Column 0
                of its coords is dropped; the remaining columns are used as
                voxel indices.

        Returns:
            dict mapping every attribute in ``self.layout`` to a CUDA tensor of
            shape ``(batch, 3, 1024, 1024)``: a 2x2 grid of 512x512 views.
        """
        x = x if isinstance(x, sp.SparseTensor) else x['x']
        view_size = 512
        grid_rows, grid_cols = 2, 2
        renderer = VoxelRenderer()
        renderer.rendering_options.resolution = view_size
        renderer.rendering_options.ssaa = 4

        # Build cameras: four yaws 90 degrees apart sharing one random offset,
        # each with its own random pitch, placed at radius 2 looking at the origin.
        yaws = [0, np.pi / 2, np.pi, 3 * np.pi / 2]
        yaw_offset = np.random.uniform(-np.pi / 4, np.pi / 4)
        yaws = [y + yaw_offset for y in yaws]
        pitches = [np.random.uniform(-np.pi / 4, np.pi / 4) for _ in range(4)]
        fov = torch.deg2rad(torch.tensor(30)).cuda()
        target = torch.tensor([0, 0, 0]).float().cuda()
        up = torch.tensor([0, 0, 1]).float().cuda()
        exts = []
        ints = []
        for yaw, pitch in zip(yaws, pitches):
            orig = torch.tensor([
                np.sin(yaw) * np.cos(pitch),
                np.cos(yaw) * np.cos(pitch),
                np.sin(pitch),
            ]).float().cuda() * 2
            exts.append(utils3d.torch.extrinsics_look_at(orig, target, up))
            ints.append(utils3d.torch.intrinsics_from_fov_xy(fov, fov))

        # Build each representation and render every attribute from every view.
        images = {k: [] for k in self.layout}
        x = x.cuda()
        for i in range(x.shape[0]):
            sample = x[i]
            rep = Voxel(
                origin=[-0.5, -0.5, -0.5],
                voxel_size=1 / self.resolution,
                coords=sample.coords[:, 1:].contiguous(),
                attrs=None,
                layout={
                    'color': slice(0, 3),
                }
            )
            for k, channels in self.layout.items():
                # expand(-1, 3) broadcasts 1-channel attributes to grey RGB and
                # leaves 3-channel ones unchanged. It does not depend on the view,
                # so it is computed once per attribute rather than once per view.
                attr = sample.feats[:, channels].expand(-1, 3)
                image = torch.zeros(3, view_size * grid_rows, view_size * grid_cols).cuda()
                for j, (ext, intr) in enumerate(zip(exts, ints)):
                    row, col = divmod(j, grid_cols)
                    res = renderer.render(rep, ext, intr, colors_overwrite=attr)
                    image[:, view_size * row:view_size * (row + 1), view_size * col:view_size * (col + 1)] = res['color']
                images[k].append(image)
        return {k: torch.stack(v) for k, v in images.items()}
class SparseVoxelPbrDataset(SparseVoxelPbrVisMixin, StandardDatasetBase):
    """
    Sparse Voxel PBR dataset.

    Each instance is a sparse voxel grid whose features are the selected PBR
    attributes, concatenated in ``attrs`` order and rescaled from [0, 255] to
    [-1, 1]. With ``with_mesh`` the textured source mesh is loaded as well.

    Args:
        roots: path(s) to the dataset, forwarded to ``StandardDatasetBase``
        resolution (int): resolution of the voxel grid
        max_active_voxels (int): drop instances with more PBR voxels than this
            (``None`` disables the filter)
        max_num_faces (int): drop instances whose mesh has more faces than this
            (``None`` disables the filter)
        min_aesthetic_score (float): minimum aesthetic score of the instances to be
            included in the dataset (``None`` disables the filter)
        attrs (sequence of str): PBR attributes to load, in channel order; each
            must be one of 'base_color', 'metallic', 'roughness', 'emissive', 'alpha'
        with_mesh (bool): also load the textured mesh of each instance
    """

    def __init__(
        self,
        roots,
        resolution: int = 1024,
        max_active_voxels: int = 1000000,
        max_num_faces: int = None,
        min_aesthetic_score: float = 5.0,
        attrs: Union[list[str], tuple[str, ...]] = ('base_color', 'metallic', 'roughness', 'emissive', 'alpha'),
        with_mesh: bool = True,
    ):
        # These are set before StandardDatasetBase.__init__ because
        # filter_metadata reads them (presumably called from the base init).
        self.resolution = resolution
        self.min_aesthetic_score = min_aesthetic_score
        self.max_active_voxels = max_active_voxels
        self.max_num_faces = max_num_faces
        self.with_mesh = with_mesh
        # read_pbr_voxel maps stored [0, 255] values into this range.
        self.value_range = (-1, 1)
        # Number of feature channels stored for each PBR attribute.
        self.channels = {
            'base_color': 3,
            'metallic': 1,
            'roughness': 1,
            'emissive': 3,
            'alpha': 1,
        }
        # Map each requested attribute to its slice of the concatenated features.
        self.layout = {}
        start = 0
        for attr in attrs:
            if attr not in self.channels:
                raise KeyError(f'Unknown PBR attribute {attr!r}; expected one of {list(self.channels)}')
            self.layout[attr] = slice(start, start + self.channels[attr])
            start += self.channels[attr]
        super().__init__(roots)
        # PBR voxel count of every entry in self.instances.
        self.loads = [self.metadata.loc[sha256, 'num_pbr_voxels'] for _, sha256, _ in self.instances]

    def __str__(self):
        lines = [
            super().__str__(),
            f' - Resolution: {self.resolution}',
            f' - Attributes: {list(self.layout.keys())}',
        ]
        return '\n'.join(lines)

    def filter_metadata(self, metadata, dataset_name=None):
        """Keep only the instances this dataset can use.

        Assumes ``metadata`` is a pandas DataFrame with 'pbr_voxelized',
        'aesthetic_score', 'num_pbr_voxels' and (if ``max_num_faces`` is set)
        'num_faces' columns.

        Returns:
            (filtered metadata, dict mapping each filter description to the
            number of instances remaining after that filter)
        """
        stats = {}
        # `== True` builds an element-wise boolean mask over the column.
        metadata = metadata[metadata['pbr_voxelized'] == True]  # noqa: E712
        stats['PBR Voxelized'] = len(metadata)
        if self.min_aesthetic_score is not None:
            metadata = metadata[metadata['aesthetic_score'] >= self.min_aesthetic_score]
            stats[f'Aesthetic score >= {self.min_aesthetic_score}'] = len(metadata)
        if self.max_active_voxels is not None:
            metadata = metadata[metadata['num_pbr_voxels'] <= self.max_active_voxels]
            stats[f'Active voxels <= {self.max_active_voxels}'] = len(metadata)
        if self.max_num_faces is not None:
            metadata = metadata[metadata['num_faces'] <= self.max_num_faces]
            stats[f'Faces <= {self.max_num_faces}'] = len(metadata)
        return metadata, stats

    @staticmethod
    def _texture_from_dump(pack) -> Texture:
        """Decode one texture entry of a PBR dump into a ``Texture``.

        ``pack['image']`` holds encoded image bytes; ``pack['interpolation']``
        and ``pack['extension']`` select the filter and wrap modes. Images that
        are not square with a power-of-two side are Lanczos-resized to a square
        whose side is the power of two nearest to their larger dimension.
        """
        image = Image.open(io.BytesIO(pack['image']))
        # np.asarray on a palette ('P') image yields palette indices, and on a
        # bilevel ('1') image yields booleans, rather than intensities. Expand
        # them to real pixel values before resizing and normalizing.
        if image.mode == 'P':
            image = image.convert('RGBA' if 'transparency' in image.info else 'RGB')
        elif image.mode == '1':
            image = image.convert('L')
        if image.width != image.height or not is_power_of_two(image.width):
            size = nearest_power_of_two(max(image.width, image.height))
            image = image.resize((size, size), Image.LANCZOS)
        pixels = np.asarray(image)
        # 16-bit images decode to uint16 arrays; normalize them by their full range.
        max_value = 65535.0 if pixels.dtype.kind == 'u' and pixels.dtype.itemsize == 2 else 255.0
        texture = torch.tensor(pixels / max_value, dtype=torch.float32).reshape(image.height, image.width, -1)
        filter_mode = {
            'Linear': TextureFilterMode.LINEAR,
            'Closest': TextureFilterMode.CLOSEST,
            'Cubic': TextureFilterMode.LINEAR,
            'Smart': TextureFilterMode.LINEAR,
        }[pack['interpolation']]
        wrap_mode = {
            'REPEAT': TextureWrapMode.REPEAT,
            'EXTEND': TextureWrapMode.CLAMP_TO_EDGE,
            'CLIP': TextureWrapMode.CLAMP_TO_EDGE,
            'MIRROR': TextureWrapMode.MIRRORED_REPEAT,
        }[pack['extension']]
        return Texture(texture, filter_mode=filter_mode, wrap_mode=wrap_mode)

    @classmethod
    def _optional_texture_from_dump(cls, pack):
        """Decode ``pack`` with ``_texture_from_dump``, or return None when it is None."""
        return cls._texture_from_dump(pack) if pack is not None else None

    @classmethod
    def _material_from_dump(cls, mat) -> PbrMaterial:
        """Convert one material entry of a PBR dump into a ``PbrMaterial``."""
        alpha_mode = mat['alphaMode']
        # Materials that carry an alpha texture but are declared OPAQUE are
        # treated as BLEND (presumably so the alpha texture is not ignored).
        if mat['alphaTexture'] is not None and alpha_mode == 'OPAQUE':
            alpha_mode = 'BLEND'
        return PbrMaterial(
            base_color_texture=cls._optional_texture_from_dump(mat['baseColorTexture']),
            base_color_factor=mat['baseColorFactor'],
            metallic_texture=cls._optional_texture_from_dump(mat['metallicTexture']),
            metallic_factor=mat['metallicFactor'],
            roughness_texture=cls._optional_texture_from_dump(mat['roughnessTexture']),
            roughness_factor=mat['roughnessFactor'],
            alpha_texture=cls._optional_texture_from_dump(mat['alphaTexture']),
            alpha_factor=mat['alphaFactor'],
            alpha_mode={
                'OPAQUE': AlphaMode.OPAQUE,
                'MASK': AlphaMode.MASK,
                'BLEND': AlphaMode.BLEND,
            }[alpha_mode],
            alpha_cutoff=mat['alphaCutoff'],
        )

    def read_mesh_with_texture(self, root, instance):
        """Load ``<root>/<instance>.pickle`` and build one merged, normalized textured mesh.

        Objects without vertices or faces are skipped. Faces whose material id
        is -1 get an appended grey default material. Vertices are centered on
        their bounding box and uniformly scaled so they lie inside [-0.5, 0.5].

        Returns:
            ``{'mesh': [MeshWithPbrMaterial]}``: a one-element list, which
            ``collate_fn`` concatenates across a batch.

        Raises:
            ValueError: if the dump has no object with both vertices and faces.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only load dumps
        # produced by a trusted preprocessing pipeline.
        with open(os.path.join(root, f'{instance}.pickle'), 'rb') as f:
            dump = pickle.load(f)

        # Process materials, then append the default material.
        materials = [self._material_from_dump(mat) for mat in dump['materials']]
        materials.append(PbrMaterial(
            base_color_factor=[0.8, 0.8, 0.8],
            alpha_factor=1.0,
            metallic_factor=0.0,
            roughness_factor=0.5,
            alpha_mode=AlphaMode.OPAQUE,
            alpha_cutoff=0.5,
        ))
        default_material_id = len(materials) - 1

        # Merge all non-empty objects into one mesh. Face indices are offset by
        # the number of vertices already emitted.
        vertices, faces, material_ids, uv_coords = [], [], [], []
        start = 0
        for obj in dump['objects']:
            if obj['vertices'].size == 0 or obj['faces'].size == 0:
                continue
            vertices.append(obj['vertices'])
            faces.append(obj['faces'] + start)
            mat_ids = obj['mat_ids']
            mat_ids[mat_ids == -1] = default_material_id
            material_ids.append(mat_ids)
            # Objects without UVs get zero UVs shaped (num_faces, 3, 2).
            uv_coords.append(obj['uvs'] if obj['uvs'] is not None else np.zeros((obj['faces'].shape[0], 3, 2), dtype=np.float32))
            start += len(obj['vertices'])
        if not vertices:
            raise ValueError(f'Mesh dump for instance {instance!r} contains no non-empty objects')
        vertices = torch.from_numpy(np.concatenate(vertices, axis=0)).float()
        faces = torch.from_numpy(np.concatenate(faces, axis=0)).long()
        material_ids = torch.from_numpy(np.concatenate(material_ids, axis=0)).long()
        uv_coords = torch.from_numpy(np.concatenate(uv_coords, axis=0)).float()

        # Center the bounding box at the origin and scale its longest side to
        # just under 1, so every vertex falls inside [-0.5, 0.5].
        vertices_min = vertices.min(dim=0)[0]
        vertices_max = vertices.max(dim=0)[0]
        center = (vertices_min + vertices_max) / 2
        scale = 0.99999 / (vertices_max - vertices_min).max()
        vertices = (vertices - center) * scale
        assert torch.all(vertices >= -0.5) and torch.all(vertices <= 0.5), 'vertices out of range'
        return {'mesh': [MeshWithPbrMaterial(
            vertices=vertices,
            faces=faces,
            material_ids=material_ids,
            uv_coords=uv_coords,
            materials=materials,
        )]}

    def read_pbr_voxel(self, root, instance):
        """Load ``<root>/<instance>.vxz`` as a single-sample sparse tensor.

        Returns:
            ``{'x': sp.SparseTensor}``. The features are the ``self.layout``
            attributes concatenated on the last dim and mapped from [0, 255]
            to [-1, 1]. The coords get a leading column of zeros (presumably
            the batch index).
        """
        coords, attr = o_voxel.io.read_vxz(os.path.join(root, f'{instance}.vxz'), num_threads=4)
        # Assumes stored attribute values are 8-bit, in [0, 255]. TODO: confirm against o_voxel.
        feats = torch.concat([attr[k] for k in self.layout], dim=-1) / 255.0 * 2 - 1
        x = sp.SparseTensor(
            feats.float(),
            torch.cat([torch.zeros_like(coords[:, 0:1]), coords], dim=-1),
        )
        return {'x': x}

    def get_instance(self, root, instance):
        """Load one instance: PBR voxels, plus the textured mesh when ``with_mesh``.

        ``root`` maps 'pbr_voxel' (and 'pbr_dump' when meshes are loaded) to
        the directories holding the corresponding files.
        """
        if self.with_mesh:
            mesh = self.read_mesh_with_texture(root['pbr_dump'], instance)
            return {**mesh, **self.read_pbr_voxel(root['pbr_voxel'], instance)}
        return self.read_pbr_voxel(root['pbr_voxel'], instance)

    @staticmethod
    def collate_fn(batch, split_size=None):
        """Collate instance dicts into batch dicts.

        Per key:
        - dense tensors are stacked;
        - sparse tensors are concatenated with ``sp.sparse_cat``;
        - lists are flattened into one list;
        - anything else is gathered into a list.

        Args:
            batch: list of instance dicts sharing the keys of ``batch[0]``.
            split_size: if given, split the batch into groups balanced on the
                voxel count of ``'x'`` via ``load_balanced_group_indices``.

        Returns:
            One packed dict when ``split_size`` is None, otherwise a list of
            packed dicts, one per group.
        """
        if split_size is None:
            group_idx = [list(range(len(batch)))]
        else:
            group_idx = load_balanced_group_indices([b['x'].feats.shape[0] for b in batch], split_size)
        packs = []
        for group in group_idx:
            sub_batch = [batch[i] for i in group]
            pack = {}
            for k in sub_batch[0]:
                values = [b[k] for b in sub_batch]
                first = values[0]
                if isinstance(first, torch.Tensor):
                    pack[k] = torch.stack(values)
                elif isinstance(first, sp.SparseTensor):
                    pack[k] = sp.sparse_cat(values, dim=0)
                elif isinstance(first, list):
                    # One-pass flatten; sum(lists, []) re-copies the accumulator per item.
                    pack[k] = [item for value in values for item in value]
                else:
                    pack[k] = values
            packs.append(pack)
        if split_size is None:
            return packs[0]
        return packs