| import os, sys |
| import math |
| import json |
| import glm |
| from pathlib import Path |
|
|
| import random |
| import numpy as np |
| from PIL import Image |
| import webdataset as wds |
| import pytorch_lightning as pl |
| import sys |
| from src.utils import obj, render_utils |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset |
| from torch.utils.data.distributed import DistributedSampler |
| import random |
| import itertools |
| from src.utils.train_util import instantiate_from_config |
| from src.utils.camera_util import ( |
| FOV_to_intrinsics, |
| center_looking_at_camera_pose, |
| get_circular_camera_poses, |
| ) |
| os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1" |
| import re |
|
|
def spherical_camera_pose(azimuths: np.ndarray, elevations: np.ndarray, radius=2.5):
    """Build poses for cameras placed on a sphere around the origin.

    Args:
        azimuths: Azimuth angle(s) in degrees.
        elevations: Elevation angle(s) in degrees; the z coordinate is
            ``radius * sin(elevation)``.
        radius: Distance of every camera from the origin.

    Returns:
        The result of ``center_looking_at_camera_pose`` applied to the
        float32 tensor of camera positions (presumably camera-to-world
        matrices, judging by how callers name it -- confirm in camera_util).
    """
    az_rad = np.deg2rad(azimuths)
    el_rad = np.deg2rad(elevations)

    # Length of the position vector projected onto the xy-plane.
    planar = radius * np.cos(el_rad)
    positions = np.stack(
        [planar * np.cos(az_rad), planar * np.sin(az_rad), radius * np.sin(el_rad)],
        axis=-1,
    )

    return center_looking_at_camera_pose(torch.from_numpy(positions).float())
|
|
def find_matching_files(base_path, idx):
    """Return the path of the PNG belonging to view ``idx`` in ``base_path``.

    A file matches when its name is ``<idx as 3 digits>_<digits>.png``,
    e.g. ``003_12.png`` for ``idx=3``.  When several files match, the
    lexicographically smallest name is returned so that the result does not
    depend on the arbitrary order of ``os.listdir``.

    Args:
        base_path: Directory to search (not recursive).
        idx: Integer view index.

    Returns:
        ``os.path.join(base_path, <matching file name>)``.

    Raises:
        FileNotFoundError: If ``base_path`` is not a directory or contains no
            matching file.
    """
    formatted_idx = '%03d' % idx
    pattern = re.compile(r'^%s_\d+\.png$' % formatted_idx)

    if not os.path.isdir(base_path):
        raise FileNotFoundError('directory does not exist: %s' % base_path)

    matching_files = sorted(
        filename for filename in os.listdir(base_path) if pattern.match(filename)
    )
    if not matching_files:
        raise FileNotFoundError(
            'no file named %s_<n>.png found in %s' % (formatted_idx, base_path)
        )

    return os.path.join(base_path, matching_files[0])
|
|
def load_mipmap(env_path):
    """Load the pre-computed lighting tensors stored in ``env_path``.

    Reads ``diffuse.pth`` and ``specular_0.pth`` ... ``specular_5.pth`` with
    ``torch.load``, mapping everything to the CPU.

    Returns:
        A two-element list ``[specular, diffuse]`` where ``specular`` is the
        list of the six loaded specular objects, in index order.
    """
    cpu = torch.device('cpu')

    def _read(filename):
        return torch.load(os.path.join(env_path, filename), map_location=cpu)

    diffuse = _read("diffuse.pth")
    specular = [_read(f"specular_{level}.pth") for level in range(6)]
    return [specular, diffuse]
|
|
def convert_to_white_bg(image, write_bg=True):
    """Drop the alpha channel of an (H, W, 4) image by compositing.

    Args:
        image: Array/tensor indexed as ``image[:, :, channel]`` whose channels
            from index 3 onward are used as alpha.
        write_bg: When true, blend the colour over a white (1.0) background;
            otherwise only multiply the colour by alpha.

    Returns:
        The first three channels after compositing.
    """
    rgb, alpha = image[:, :, :3], image[:, :, 3:]
    premultiplied = rgb * alpha
    if not write_bg:
        return premultiplied
    return premultiplied + 1. * (1 - alpha)
| |
def load_obj(path, return_attributes=False, scale_factor=1.0):
    """Load a mesh through ``obj.load_obj`` with this module's fixed options.

    ``clear_ks=True`` and ``mtl_override=None`` are always passed;
    ``return_attributes`` and ``scale_factor`` are forwarded unchanged.
    """
    loader_options = dict(
        clear_ks=True,
        mtl_override=None,
        return_attributes=return_attributes,
        scale_factor=scale_factor,
    )
    return obj.load_obj(path, **loader_options)
|
|
def custom_collate_fn(batch):
    """Identity collate: hand the list of samples back without merging it."""
    return batch
|
|
|
|
def collate_fn_wrapper(batch):
    """Collate function given to the data loaders.

    Delegates to the module-level ``custom_collate_fn``.
    """
    collated = custom_collate_fn(batch)
    return collated
|
|
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module whose datasets are built from config objects.

    The ``train`` / ``validation`` / ``test`` arguments are stored as given
    and turned into datasets with ``instantiate_from_config`` in ``setup``.
    """

    def __init__(
        self,
        batch_size=8,
        num_workers=4,
        train=None,
        validation=None,
        test=None,
        **kwargs,
    ):
        """Remember loader settings and the configs of the provided splits.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        # Only the splits that were actually configured are registered.
        candidate_splits = (('train', train), ('validation', validation), ('test', test))
        self.dataset_configs = {
            split: config for split, config in candidate_splits if config is not None
        }

    def setup(self, stage):
        """Instantiate every configured dataset; only the 'fit' stage is supported.

        Raises:
            NotImplementedError: For any stage other than ``'fit'``.
        """
        if stage not in ['fit']:
            raise NotImplementedError
        self.datasets = {
            split: instantiate_from_config(config)
            for split, config in self.dataset_configs.items()
        }

    def custom_collate_fn(self, batch):
        """Merge a list of sample dicts into one dict.

        Values under 'input_env' and 'target_env' are kept as Python lists;
        every other key is combined with ``torch.stack`` along a new dim 0.
        Not referenced by the loaders below, which use ``collate_fn_wrapper``.
        """
        list_keys = ('input_env', 'target_env')
        collated = {}
        for key in batch[0].keys():
            values = [sample[key] for sample in batch]
            collated[key] = values if key in list_keys else torch.stack(values, dim=0)
        return collated

    def convert_to_white_bg(self, image):
        """Composite the first three channels over white using channel 3+ as alpha."""
        rgb, alpha = image[:, :, :3], image[:, :, 3:]
        return rgb * alpha + 1. * (1 - alpha)

    def load_obj(self, path):
        """Load a mesh via ``obj.load_obj`` with ``clear_ks=True, mtl_override=None``."""
        loader_options = dict(clear_ks=True, mtl_override=None)
        return obj.load_obj(path, **loader_options)

    def train_dataloader(self):
        """Loader over the 'train' dataset using a ``DistributedSampler``."""
        train_set = self.datasets['train']
        return wds.WebLoader(
            train_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            sampler=DistributedSampler(train_set),
            collate_fn=collate_fn_wrapper,
        )

    def val_dataloader(self):
        """Loader over the 'validation' dataset; batch size is fixed to 1."""
        val_set = self.datasets['validation']
        return wds.WebLoader(
            val_set,
            batch_size=1,
            num_workers=self.num_workers,
            shuffle=False,
            sampler=DistributedSampler(val_set),
            collate_fn=collate_fn_wrapper,
        )

    def test_dataloader(self):
        """Loader over the 'test' dataset (no sampler, default collation)."""
        test_set = self.datasets['test']
        return wds.WebLoader(
            test_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )
|
|
|
|
class ObjaverseData(Dataset):
    """Training dataset pairing a stored object with randomly sampled views.

    Every item loads one entry of ``root_dir`` with ``torch.load`` and samples
    ``input_view_num + target_view_num`` random cameras.  Each view also gets
    an environment (loaded with ``load_mipmap``) and a pair of values drawn
    from {0.0, 0.1, ..., 1.0} (presumably metallic/roughness, judging by the
    ``random_mr`` flag -- confirm against the consumer of 'material_list').
    """

    def __init__(self,
        root_dir='Objaverse_highQuality',
        light_dir='env_mipmap',
        input_view_num=6,
        target_view_num=4,
        total_view_n=18,
        distance=3.5,
        fov=50,
        camera_random=False,
        validation=False,
    ):
        """Index the object files and the available environments.

        Args:
            root_dir: Directory whose entries are the per-object files.
            light_dir: Directory with one sub-directory per environment;
                empty sub-directories and plain files are skipped.
            input_view_num: Either an integer or a non-empty sequence of
                integers; for a sequence one entry is drawn per sample.
            target_view_num: Number of additional views sampled per item.
            total_view_n: Stored on the instance; not used by this class.
            distance: Base camera radius.
            fov: Field of view in degrees at the base radius.
            camera_random: When true, the radius is perturbed per sample and
                the field of view is adapted to it (see ``calculate_fov``).
            validation: Accepted for interface compatibility; unused.
        """
        self.root_dir = Path(root_dir)
        self.light_dir = light_dir

        # Sorted so the mapping from a random index to an environment does not
        # depend on the arbitrary order of os.listdir.
        self.all_env_name = sorted(
            name
            for name in os.listdir(self.light_dir)
            if os.path.isdir(os.path.join(self.light_dir, name))
            and os.listdir(os.path.join(self.light_dir, name))
        )

        self.input_view_num = input_view_num
        self.target_view_num = target_view_num
        self.total_view_n = total_view_n
        self.fov = fov
        self.camera_random = camera_random

        self.train_res = [512, 512]
        self.cam_near_far = [0.1, 1000.0]
        self.fov_rad = np.deg2rad(fov)
        self.fov_deg = fov
        self.spp = 1
        self.cam_radius = distance
        self.layers = 1

        numbers = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
        self.combinations = list(itertools.product(numbers, repeat=2))

        # Sorted so that a given index refers to the same object in every
        # process and on every run.
        self.paths = sorted(os.listdir(self.root_dir))

        print('total training object num:', len(self.paths))

        self.depth_scale = 6.0

        total_objects = len(self.paths)
        print('============= length of dataset %d =============' % total_objects)

    def __len__(self):
        """Number of object files found in ``root_dir``."""
        return len(self.paths)

    def load_obj(self, path):
        """Load a mesh via ``obj.load_obj`` with ``clear_ks=True, mtl_override=None``."""
        return obj.load_obj(path, clear_ks=True, mtl_override=None)

    def sample_spherical(self, phi, theta, cam_radius):
        """Convert spherical angles in degrees to a point ``(x, y, z)``.

        ``theta`` is measured from the +Y axis (``y = r * cos(theta)``) and
        ``phi`` rotates around Y starting at +Z.
        """
        theta = np.deg2rad(theta)
        phi = np.deg2rad(phi)

        z = cam_radius * np.cos(phi) * np.sin(theta)
        x = cam_radius * np.sin(phi) * np.sin(theta)
        y = cam_radius * np.cos(theta)

        return x, y, z

    def _random_scene(self, cam_radius, fov_rad):
        """Sample one random camera on the sphere of radius ``cam_radius``.

        Returns:
            ``(mv, mvp, campos, mv_embedding, resolution, spp)``; the first
            four carry a new leading dimension of size 1.
        """
        iter_res = self.train_res
        proj_mtx = render_utils.perspective(
            fov_rad, iter_res[1] / iter_res[0], self.cam_near_far[0], self.cam_near_far[1]
        )

        azimuths = random.uniform(0, 360)
        elevations = random.uniform(30, 150)  # polar angle measured from +Y
        # The embedding helper expects an elevation, i.e. 90 degrees minus the polar angle.
        mv_embedding = spherical_camera_pose(azimuths, 90 - elevations, cam_radius)
        x, y, z = self.sample_spherical(azimuths, elevations, cam_radius)
        eye = glm.vec3(x, y, z)
        at = glm.vec3(0.0, 0.0, 0.0)
        up = glm.vec3(0.0, 1.0, 0.0)
        view_matrix = glm.lookAt(eye, at, up)
        mv = torch.from_numpy(np.array(view_matrix))
        mvp = proj_mtx @ (mv)
        campos = torch.linalg.inv(mv)[:3, 3]
        return mv[None, ...], mvp[None, ...], campos[None, ...], mv_embedding[None, ...], iter_res, self.spp

    def load_im(self, path, color):
        """Load an RGBA image and composite it over ``color``.

        Returns:
            ``(image, alpha)`` as float tensors in channel-first layout.
        """
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(path) as pil_img:
            image = np.asarray(pil_img, dtype=np.float32) / 255.

        alpha = image[:, :, 3:]
        image = image[:, :, :3] * alpha + color * (1 - alpha)

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def load_albedo(self, path, color, mask):
        """Load an image and fill everything outside ``mask`` with white.

        The ``color`` argument is ignored: the background is always 1.0.
        """
        with Image.open(path) as pil_img:
            image = np.asarray(pil_img, dtype=np.float32) / 255.
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()

        white = torch.ones_like(image)
        image = image * mask + white * (1 - mask)
        return image

    def convert_to_white_bg(self, image):
        """Composite the first three channels over white using channel 3+ as alpha."""
        alpha = image[:, :, 3:]
        return image[:, :, :3] * alpha + 1. * (1 - alpha)

    def calculate_fov(self, initial_distance, initial_fov, new_distance):
        """Return the FOV (degrees) that covers the same extent from ``new_distance``.

        The extent visible at ``initial_distance`` with ``initial_fov`` degrees
        is kept constant while the camera moves to ``new_distance``.
        """
        initial_fov_rad = math.radians(initial_fov)
        height = 2 * initial_distance * math.tan(initial_fov_rad / 2)
        new_fov_rad = 2 * math.atan(height / (2 * new_distance))
        return math.degrees(new_fov_rad)

    def _sample_input_view_num(self):
        """Return the number of input views to use for one sample.

        An integer ``input_view_num`` (the constructor default) is returned
        as is; a sequence yields one randomly chosen entry.

        Raises:
            ValueError: If ``input_view_num`` is an empty sequence.
        """
        candidates = self.input_view_num
        # Scalars have no length; calling len() on them used to raise TypeError.
        if not hasattr(candidates, '__len__'):
            return candidates
        if len(candidates) == 0:
            raise ValueError('input_view_num must not be an empty sequence')
        return random.choice(candidates)

    def __getitem__(self, index):
        """Build one training sample.

        Returns:
            A dict with the loaded object ('mesh_attributes'), its path, the
            view counts, and per-view lists ('pose_list', 'camera_pos',
            'c2w_list', 'env_list', 'material_list', 'camera_embedding_list')
            plus 'fov_deg' and the camera radius under the key 'raduis'.
        """
        obj_path = os.path.join(self.root_dir, self.paths[index])
        mesh_attributes = torch.load(obj_path, map_location=torch.device('cpu'))

        pose_list = []
        env_list = []
        material_list = []
        camera_pos = []
        c2w_list = []
        camera_embedding_list = []

        # The order of the random draws below is kept stable on purpose so that
        # seeded runs stay comparable.
        random_env = random.random() > 0.5  # new environment for every view?
        random_mr = random.random() > 0.5   # new value pair for every view?
        selected_env = random.randint(0, len(self.all_env_name) - 1)
        materials = random.choice(self.combinations)

        if self.camera_random:
            random_perturbation = random.uniform(-1.5, 1.5)
            cam_radius = self.cam_radius + random_perturbation
            fov_deg = self.calculate_fov(
                initial_distance=self.cam_radius,
                initial_fov=self.fov_deg,
                new_distance=cam_radius,
            )
            fov_rad = np.deg2rad(fov_deg)
        else:
            cam_radius = self.cam_radius
            fov_rad = self.fov_rad
            fov_deg = self.fov_deg

        input_view_num = self._sample_input_view_num()

        for _ in range(input_view_num + self.target_view_num):
            mv, mvp, campos, mv_embedding, iter_res, iter_spp = self._random_scene(cam_radius, fov_rad)
            if random_env:
                selected_env = random.randint(0, len(self.all_env_name) - 1)
            env_path = os.path.join(self.light_dir, self.all_env_name[selected_env])
            env = load_mipmap(env_path)
            if random_mr:
                materials = random.choice(self.combinations)
            pose_list.append(mvp)
            camera_pos.append(campos)
            # NOTE(review): this stores the lookAt view matrix returned by
            # _random_scene; confirm the convention expected for 'c2w_list'.
            c2w_list.append(mv)
            env_list.append(env)
            material_list.append(materials)
            camera_embedding_list.append(mv_embedding)

        data = {
            'mesh_attributes': mesh_attributes,
            'input_view_num': input_view_num,
            'target_view_num': self.target_view_num,
            'obj_path': obj_path,
            'pose_list': pose_list,
            'camera_pos': camera_pos,
            'c2w_list': c2w_list,
            'env_list': env_list,
            'material_list': material_list,
            'camera_embedding_list': camera_embedding_list,
            'fov_deg': fov_deg,
            # Key spelling kept as is; consumers look it up by this name.
            'raduis': cam_radius,
        }

        return data
|
|
class ValidationData(Dataset):
    """Validation dataset reading pre-rendered views and their material maps.

    Every entry of ``root_dir`` is an object directory containing one or more
    ``<float>_<float>`` sub-folders, each with ``rgb``, ``albedo``,
    ``metallic`` and ``roughness`` image folders.  One such sub-folder is
    picked at random per item.
    """

    def __init__(self,
        root_dir='objaverse/',
        input_view_num=6,
        input_image_size=320,
        fov=30,
        light_dir='env_mipmap',
    ):
        """Index the objects and pre-compute the fixed camera set-ups.

        Args:
            root_dir: Directory with one sub-directory per object.
            input_view_num: Number of views loaded per item (view indices
                ``0 .. input_view_num - 1``).
            input_image_size: Side length the rgb/albedo images are resized to.
            fov: Field of view passed to ``FOV_to_intrinsics``.
            light_dir: Directory holding the environment folders read by
                ``load_mipmap``.
        """
        self.root_dir = Path(root_dir)
        self.input_view_num = input_view_num
        self.input_image_size = input_image_size
        self.fov = fov
        self.light_dir = light_dir

        # Sorted so that an index always refers to the same object.
        self.paths = sorted(os.listdir(self.root_dir))

        print('============= length of dataset %d =============' % len(self.paths))

        # Six fixed input cameras around the object.
        cam_distance = 4.0
        azimuths = np.array([30, 90, 150, 210, 270, 330])
        elevations = np.array([20, -10, 20, -10, 20, -10])
        c2ws = spherical_camera_pose(azimuths, elevations, cam_distance)
        self.c2ws = c2ws.float()
        self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(len(azimuths), 1, 1).float()

        render_c2ws = get_circular_camera_poses(M=8, radius=cam_distance, elevation=20.0)
        render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
        self.render_c2ws = render_c2ws.float()
        self.render_Ks = render_Ks.float()

    def __len__(self):
        """Number of object directories found in ``root_dir``."""
        return len(self.paths)

    def _load_rgba(self, path, color, size):
        """Load an image, resize it to ``size`` x ``size`` and composite over ``color``.

        Images with four channels are blended over ``color`` using their
        alpha channel; any other image is returned unchanged with an all-ones
        alpha.

        Returns:
            ``(image, alpha)`` as float tensors in channel-first layout.
        """
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(path) as pil_img:
            resized = pil_img.resize((size, size), resample=Image.BICUBIC)
            image = np.asarray(resized, dtype=np.float32) / 255.

        if image.shape[-1] == 4:
            alpha = image[:, :, 3:]
            image = image[:, :, :3] * alpha + color * (1 - alpha)
        else:
            alpha = np.ones_like(image[:, :, :1])

        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
        alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
        return image, alpha

    def load_im(self, path, color):
        """Load a rendering at ``input_image_size`` with ``color`` as background."""
        return self._load_rgba(path, color, self.input_image_size)

    def load_mat(self, path, color, size=384):
        """Load a material map at ``size`` x ``size`` with ``color`` as background."""
        return self._load_rgba(path, color, size)

    def load_albedo(self, path, color, mask):
        """Load an albedo map and fill everything outside ``mask`` with white.

        The ``color`` argument is ignored: the background is always 1.0.
        """
        with Image.open(path) as pil_img:
            resized = pil_img.resize(
                (self.input_image_size, self.input_image_size), resample=Image.BICUBIC
            )
            image = np.asarray(resized, dtype=np.float32) / 255.
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()

        white = torch.ones_like(image)
        image = image * mask + white * (1 - mask)
        return image

    def _material_subfolders(self, object_dir):
        """Return the sorted names of the ``<float>_<float>`` sub-folders of ``object_dir``.

        Folders whose name contains 'specular' or 'diffuse' are skipped.
        """
        found = []
        for name in sorted(os.listdir(object_dir)):
            if '_' not in name or 'specular' in name or 'diffuse' in name:
                continue
            if not os.path.isdir(os.path.join(object_dir, name)):
                continue
            first, second = name.split('_')[:2]
            try:
                float(first)
                float(second)
            except ValueError:
                continue
            found.append(name)
        return found

    @staticmethod
    def _sibling_map_path(rgb_path, map_name):
        """Return the file with the same name in the ``map_name`` folder next to ``rgb``.

        Only the last directory component is swapped, so an 'rgb' substring
        elsewhere in the path is left untouched.
        """
        view_dir, filename = os.path.split(rgb_path)
        return os.path.join(os.path.dirname(view_dir), map_name, filename)

    def __getitem__(self, index):
        """Load all input views of one object for a randomly chosen material folder.

        Raises:
            FileNotFoundError: If the object has no ``<float>_<float>`` folder.
        """
        input_image_path = os.path.join(self.root_dir, self.paths[index])

        # Background color, default: white.
        bkg_color = [1.0, 1.0, 1.0]

        exist_comb_list = self._material_subfolders(input_image_path)
        if not exist_comb_list:
            raise FileNotFoundError(
                'no <float>_<float> material folder found in %s' % input_image_path
            )
        selected_one_comb = random.choice(exist_comb_list)
        rgb_dir = os.path.join(input_image_path, selected_one_comb, 'rgb')

        image_list = []
        albedo_list = []
        alpha_list = []
        specular_list = []
        diffuse_list = []
        metallic_list = []
        roughness_list = []

        for idx in range(self.input_view_num):
            img_path = find_matching_files(rgb_dir, idx)
            albedo_path = self._sibling_map_path(img_path, 'albedo')
            metallic_path = self._sibling_map_path(img_path, 'metallic')
            roughness_path = self._sibling_map_path(img_path, 'roughness')

            image, alpha = self.load_im(img_path, bkg_color)
            albedo = self.load_albedo(albedo_path, bkg_color, alpha)
            metallic, _ = self.load_mat(metallic_path, bkg_color)
            roughness, _ = self.load_mat(roughness_path, bkg_color)

            # File names look like '<view>_<light>.png'; the environment folder
            # is named after the light number plus one.
            light_num = os.path.basename(img_path).split('_')[1].split('.')[0]
            light_path = os.path.join(self.light_dir, str(int(light_num) + 1))
            specular, diffuse = load_mipmap(light_path)

            image_list.append(image)
            alpha_list.append(alpha)
            albedo_list.append(albedo)
            metallic_list.append(metallic)
            roughness_list.append(roughness)
            specular_list.append(specular)
            diffuse_list.append(diffuse)

        images = torch.stack(image_list, dim=0).float()
        alphas = torch.stack(alpha_list, dim=0).float()
        albedo = torch.stack(albedo_list, dim=0).float()
        metallic = torch.stack(metallic_list, dim=0).float()
        roughness = torch.stack(roughness_list, dim=0).float()

        data = {
            'input_images': images,
            'input_alphas': alphas,
            'input_c2ws': self.c2ws,
            'input_Ks': self.Ks,

            'input_albedos': albedo[:self.input_view_num],
            'input_metallics': metallic[:self.input_view_num],
            'input_roughness': roughness[:self.input_view_num],

            'specular': specular_list[:self.input_view_num],
            'diffuse': diffuse_list[:self.input_view_num],

            'render_c2ws': self.render_c2ws,
            'render_Ks': self.render_Ks,
        }
        return data
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: build the training dataset from its default paths and
    # fetch one sample.  ObjaverseData defines no ``new`` method, so the sample
    # is fetched by indexing.  ``input_view_num`` is given as a list, the form
    # that ``__getitem__`` draws from with ``random.choice``.
    dataset = ObjaverseData(input_view_num=[6])
    sample = dataset[1]
    print(list(sample.keys()))
|
|