| import glob |
| import logging |
| import os |
| import sys |
| import time |
|
|
| import cv2 |
| import numpy as np |
| from absl import app |
| import gin |
| from internal import configs |
| from internal import datasets |
| from internal import models |
| from internal import utils |
| from internal import coord |
| from internal import checkpoints |
| from internal import configs |
| import torch |
| import accelerate |
| from tqdm import tqdm |
| from torch.utils._pytree import tree_map |
| import torch.nn.functional as F |
| from skimage import measure |
| import trimesh |
| import pymeshlab as pml |
| from torch import Tensor |
|
|
# Register the shared absl command-line flags (gin configs etc.) used by every
# script in this project; must run at import time, before app.run(main) parses argv.
configs.define_common_flags()
|
|
|
|
class TSDF:
    """Distributed truncated signed distance function (TSDF) volume.

    Each accelerator process owns a contiguous slice of the flattened
    resolution^3 voxel grid. Depth/color images are fused into the local
    slice via `integrate_tsdf`, and the full volume is reassembled with
    `accelerator.gather` inside `export_mesh`.
    """

    def __init__(self, config: configs.Config, accelerator: accelerate.Accelerator):
        """Builds the per-process slice of the voxel grid.

        Args:
            config: experiment config; reads `tsdf_radius` and
                `tsdf_resolution` here, and `truncation_margin` /
                `tsdf_max_radius` later.
            accelerator: provides the device and the process sharding info.
        """
        self.config = config
        self.device = accelerator.device
        self.accelerator = accelerator
        # Grid spans the cube [-tsdf_radius, tsdf_radius]^3 in contracted space.
        self.origin = torch.tensor([-config.tsdf_radius] * 3, dtype=torch.float32, device=self.device)
        # Spacing chosen so that `resolution` samples exactly cover the cube
        # edge (resolution - 1 intervals).
        self.voxel_size = 2 * config.tsdf_radius / (config.tsdf_resolution - 1)
        self.resolution = config.tsdf_resolution

        # Flattened (3, resolution^3) integer grid of voxel indices.
        dim = torch.arange(self.resolution)
        grid = torch.stack(torch.meshgrid(dim, dim, dim, indexing="ij"), dim=0).reshape(3, -1)
        # Shard the flat grid across processes in contiguous chunks.
        # `period` is round(total / num_processes); the last process may get a
        # smaller (or empty) slice when the total is not divisible.
        # NOTE(review): export_mesh gathers the shards and reshapes to
        # resolution^3 — this only lines up if the shard sizes tile the grid
        # exactly; confirm the resolution/process-count combinations used.
        period = int(grid.shape[1] / accelerator.num_processes + 0.5)
        grid = grid[:, period * accelerator.process_index: period * (accelerator.process_index + 1)]
        # (3, N) voxel centers in contracted coordinates for this process.
        self.voxel_coords = self.origin.view(3, 1) + grid.to(self.device) * self.voxel_size

        # Number of voxels owned by this process.
        N = self.voxel_coords.shape[1]

        # Undo the scene contraction to get world-space positions, shape (3, N).
        voxel_world_coords = coord.inv_contract(self.voxel_coords.permute(1, 0)).permute(1, 0).view(3, -1)

        # Append a row of ones -> homogeneous coordinates (4, N) for the
        # camera-matrix multiplies in integrate_tsdf.
        voxel_world_coords = torch.cat(
            [voxel_world_coords, torch.ones(1, voxel_world_coords.shape[1], device=self.device)], dim=0
        )
        voxel_world_coords = voxel_world_coords.unsqueeze(0)
        # NOTE(review): expand(-1, ...) keeps the singleton batch dim, so this
        # is effectively a no-op view; the real batch expansion happens in
        # integrate_tsdf.
        self.voxel_world_coords = voxel_world_coords.expand(-1, *voxel_world_coords.shape[1:])

        # Fusion state per voxel: TSDF value (init +1 = "far in front of any
        # surface"), accumulated weight, and running mean color.
        self.values = torch.ones(N, dtype=torch.float32,
                                 device=self.device)
        self.weights = torch.zeros(N, dtype=torch.float32,
                                   device=self.device)
        self.colors = torch.zeros(N, 3, dtype=torch.float32,
                                  device=self.device)

    @property
    def truncation(self):
        """Returns the truncation distance (in contracted-space units)."""
        # Truncation is expressed as a multiple of the voxel size.
        truncation = self.voxel_size * self.config.truncation_margin
        return truncation

    def export_mesh(self, path):
        """Extracts a mesh using marching cubes and writes it to `path`.

        Collective: every process must call this (it participates in
        `gather`), but only the main process runs marching cubes and writes
        the file.
        """
        tsdf_values = self.values.clamp(-1, 1)
        # Blank out voxels whose *world-space* position lies beyond
        # tsdf_max_radius, so the far periphery produces no surface.
        mask = self.voxel_world_coords[:, :3].permute(0, 2, 1).norm(p=2, dim=-1) > self.config.tsdf_max_radius
        tsdf_values[mask.reshape(self.values.shape)] = 1.

        # Reassemble the full volume from all process shards.
        # NOTE(review): correctness relies on gather concatenating shards in
        # process order and on the shard sizes summing to resolution^3 — see
        # the sharding note in __init__.
        tsdf_values_np = self.accelerator.gather(tsdf_values).cpu().reshape((self.resolution, self.resolution, self.resolution)).numpy()
        color_values_np = self.accelerator.gather(self.colors).cpu().reshape((self.resolution, self.resolution, self.resolution, 3)).numpy()

        if self.accelerator.is_main_process:
            # Surface at the TSDF zero crossing.
            vertices, faces, normals, _ = measure.marching_cubes(
                tsdf_values_np,
                level=0,
                allow_degenerate=False,
            )

            # Nearest-voxel color lookup for each vertex (vertices are in
            # fractional voxel-index coordinates here).
            vertices_indices = np.round(vertices).astype(int)
            colors = color_values_np[vertices_indices[:, 0], vertices_indices[:, 1], vertices_indices[:, 2]]

            # Voxel indices -> contracted coordinates -> world coordinates.
            vertices = self.origin.cpu().numpy() + vertices * self.voxel_size
            vertices = coord.inv_contract_np(vertices)
            trimesh.Trimesh(vertices=vertices,
                            faces=faces,
                            normals=normals,
                            vertex_colors=colors,
                            ).export(path)

    @torch.no_grad()
    def integrate_tsdf(
        self,
        c2w,
        K,
        depth_images,
        color_images=None,
    ):
        """Integrates a batch of depth images into the TSDF.

        Args:
            c2w: The camera extrinsics, (batch, 4, 4) camera-to-world.
            K: The camera intrinsics, (3, 3) — shared by the whole batch.
            depth_images: The depth images to integrate, (batch, 1, H, W).
            color_images: The color images to integrate, (batch, 3, H, W),
                or None to skip color fusion.
        """
        batch_size = c2w.shape[0]
        # (N,) — flat shape of this process's voxel slice.
        shape = self.voxel_coords.shape[1:]

        # (width, height) order, matching the (x, y) pixel coords below.
        image_size = torch.tensor(
            [depth_images.shape[-1], depth_images.shape[-2]], device=self.device
        )

        # Broadcast the homogeneous voxel positions to the camera batch.
        voxel_world_coords = self.voxel_world_coords.expand(batch_size,
                                                            *self.voxel_world_coords.shape[1:])

        # World -> camera: multiply by the inverse extrinsics.
        voxel_cam_coords = torch.bmm(torch.inverse(c2w), voxel_world_coords)

        # Flip z and y — presumably converting from an OpenGL-style camera
        # frame (y up, z backward) to image convention (y down, z forward,
        # so depth is positive in front of the camera) — TODO confirm against
        # the dataset's camera convention.
        voxel_cam_coords[:, 2, :] = -voxel_cam_coords[:, 2, :]
        voxel_cam_coords[:, 1, :] = -voxel_cam_coords[:, 1, :]

        # Depth of each voxel along the camera's viewing axis.
        voxel_cam_coords_z = voxel_cam_coords[:, 2:3, :]
        voxel_depth = voxel_cam_coords_z

        # Perspective projection: normalize by z, then apply intrinsics.
        voxel_cam_points = torch.bmm(K[None].expand(batch_size, -1, -1),
                                     voxel_cam_coords[:, 0:3, :] / voxel_cam_coords_z)
        voxel_pixel_coords = voxel_cam_points[:, :2, :]

        # Build a grid_sample grid of shape (batch, 1, N, 2) in [-1, 1].
        grid = voxel_pixel_coords.permute(0, 2, 1)
        # NOTE(review): with align_corners=False the exact pixel-center
        # mapping is 2*(p + 0.5)/size - 1; this omits the half-pixel offset.
        # Likely tolerable with mode="nearest" but worth confirming.
        grid = 2.0 * grid / image_size.view(1, 1, 2) - 1.0
        grid = grid[:, None]

        # Sample the rendered depth at each voxel's projection; voxels that
        # project outside the image get 0 (padding_mode="zeros") and are
        # rejected by the sampled_depth > 0 test below.
        sampled_depth = F.grid_sample(
            input=depth_images, grid=grid, mode="nearest", padding_mode="zeros", align_corners=False
        )
        sampled_depth = sampled_depth.squeeze(2)

        sampled_colors = None
        if color_images is not None:
            sampled_colors = F.grid_sample(
                input=color_images, grid=grid, mode="nearest", padding_mode="zeros", align_corners=False
            )
            sampled_colors = sampled_colors.squeeze(2)

        # Signed distance from the voxel to the observed surface along the
        # ray: positive in front of the surface, negative behind it.
        dist = sampled_depth - voxel_depth

        truncation = self.truncation

        # Clamp to [-1, 1] in truncation units.
        tsdf_values = torch.clamp(dist / truncation, min=-1.0, max=1.0)
        # Keep voxels that are in front of the camera, hit a valid depth
        # sample, and are not further than one truncation band *behind* the
        # surface (avoid carving occluded space).
        valid_points = (voxel_depth > 0) & (sampled_depth > 0) & (dist > -self.truncation)

        # Running weighted average per view (standard TSDF fusion with a
        # constant per-view weight of 1).
        for i in range(batch_size):
            valid_points_i = valid_points[i]
            valid_points_i_shape = valid_points_i.view(*shape)

            old_tsdf_values_i = self.values[valid_points_i_shape]
            old_weights_i = self.weights[valid_points_i_shape]

            new_tsdf_values_i = tsdf_values[i][valid_points_i]
            new_weights_i = 1.0

            total_weights = old_weights_i + new_weights_i

            # Weighted-average update of the TSDF values...
            self.values[valid_points_i_shape] = (old_tsdf_values_i * old_weights_i +
                                                 new_tsdf_values_i * new_weights_i) / total_weights
            # ...and accumulate the weights.
            self.weights[valid_points_i_shape] = total_weights

            if sampled_colors is not None:
                # Same weighted average for colors; select the valid columns
                # then transpose to (num_valid, 3).
                old_colors_i = self.colors[valid_points_i_shape]
                new_colors_i = sampled_colors[i][:, valid_points_i.squeeze(0)].permute(1, 0)
                self.colors[valid_points_i_shape] = (old_colors_i * old_weights_i[:, None] +
                                                     new_colors_i * new_weights_i) / total_weights[:, None]
|
|
|
def main(unused_argv):
    """Renders the training views, fuses them into a TSDF, and exports a mesh.

    Pipeline: restore a trained model checkpoint -> render RGB + depth for
    every training image (cached to disk, so reruns skip finished frames) ->
    read the renders back and integrate them into a TSDF volume ->
    marching-cubes export to exp/<name>/mesh/tsdf_mesh.ply.
    """
    config = configs.load_config()
    config.compute_visibility = True

    # Experiment directory layout: exp/<name>/{checkpoints,mesh,...}.
    config.exp_path = os.path.join("exp", config.exp_name)
    config.mesh_path = os.path.join("exp", config.exp_name, "mesh")
    config.checkpoint_dir = os.path.join(config.exp_path, 'checkpoints')
    os.makedirs(config.mesh_path, exist_ok=True)

    # Multi-GPU/process setup; all subsequent work is sharded by this.
    accelerator = accelerate.Accelerator()
    device = accelerator.device

    # Log to stdout and to a per-experiment file.
    logging.basicConfig(
        format="%(asctime)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
        handlers=[logging.StreamHandler(sys.stdout),
                  logging.FileHandler(os.path.join(config.exp_path, 'log_extract.txt'))],
        level=logging.INFO,
    )
    sys.excepthook = utils.handle_exception
    logger = accelerate.logging.get_logger(__name__)
    logger.info(config)
    logger.info(accelerator.state, main_process_only=False)

    config.world_size = accelerator.num_processes
    config.global_rank = accelerator.process_index
    # device_specific=True gives each process a distinct derived seed.
    accelerate.utils.set_seed(config.seed, device_specific=True)

    # Restore the trained model; `step` tags the output directory below.
    model = models.Model(config=config)
    model = accelerator.prepare(model)
    step = checkpoints.restore_checkpoint(config.checkpoint_dir, accelerator, logger)
    model.eval()
    # NOTE(review): `module` appears unused in this function — confirm before
    # removing.
    module = accelerator.unwrap_model(model)

    # Iterate the *training* views in their original order (shuffle=False),
    # one image per batch.
    dataset = datasets.load_dataset('train', config.data_dir, config)
    dataloader = torch.utils.data.DataLoader(np.arange(len(dataset)),
                                             shuffle=False,
                                             batch_size=1,
                                             collate_fn=dataset.collate_fn,
                                             )
    dataiter = iter(dataloader)
    if config.rawnerf_mode:
        postprocess_fn = dataset.metadata['postprocess_fn']
    else:
        # Identity postprocess outside rawnerf mode.
        postprocess_fn = lambda z: z

    out_name = f'train_preds_step_{step}'
    out_dir = os.path.join(config.mesh_path, out_name)
    utils.makedirs(out_dir)
    logger.info("Render trainset in {}".format(out_dir))

    path_fn = lambda x: os.path.join(out_dir, x)

    # Zero-pad frame indices so lexicographic file order == frame order
    # (the glob+sort below depends on this).
    zpad = max(3, len(str(dataset.size - 1)))
    idx_to_str = lambda idx: str(idx).zfill(zpad)

    # Stage 1: render RGB + depth for each training view, caching to disk.
    for idx in range(dataset.size):
        idx_str = idx_to_str(idx)
        curr_file = path_fn(f'color_{idx_str}.png')
        if utils.file_exists(curr_file):
            # Resume support: skip frames already rendered by a previous run.
            logger.info(f'Image {idx + 1}/{dataset.size} already exists, skipping')
            continue
        batch = next(dataiter)
        batch = tree_map(lambda x: x.to(accelerator.device) if x is not None else None, batch)
        logger.info(f'Evaluating image {idx + 1}/{dataset.size}')
        eval_start_time = time.time()
        rendering = models.render_image(model, accelerator,
                                        batch, False, 1, config)

        logger.info(f'Rendered in {(time.time() - eval_start_time):0.3f}s')

        # Only the main process writes files; others just contributed compute.
        if accelerator.is_main_process:
            rendering['rgb'] = postprocess_fn(rendering['rgb'])
            rendering = tree_map(lambda x: x.detach().cpu().numpy() if x is not None else None, rendering)
            utils.save_img_u8(rendering['rgb'], path_fn(f'color_{idx_str}.png'))
            utils.save_img_f32(rendering['distance_mean'],
                               path_fn(f'distance_mean_{idx_str}.tiff'))
            utils.save_img_f32(rendering['distance_median'],
                               path_fn(f'distance_median_{idx_str}.tiff'))

    # Stage 2: fuse the cached renders into a TSDF volume.
    tsdf = TSDF(config, accelerator)

    c2w = torch.from_numpy(dataset.camtoworlds[:, :3, :4]).float().to(device)

    # Pad the (N, 3, 4) extrinsics to full (N, 4, 4) homogeneous matrices.
    c2w = torch.cat([c2w, torch.zeros(c2w.shape[0], 1, 4, device=device)], dim=1)
    c2w[:, 3, 3] = 1
    # Intrinsics: invert the stored pixel-to-camera matrices.
    K = torch.from_numpy(dataset.pixtocams).float().to(device).inverse()

    logger.info('Reading images')
    # Sorted globs pair color/depth frames by index (see zero-padding above).
    # The median depth is used for fusion (more robust than the mean).
    rgb_files = sorted(glob.glob(path_fn('color_*.png')))
    depth_files = sorted(glob.glob(path_fn('distance_median_*.tiff')))
    assert len(rgb_files) == len(depth_files)
    color_images = []
    depth_images = []
    for rgb_file, depth_file in zip(tqdm(rgb_files, disable=not accelerator.is_main_process), depth_files):
        color_images.append(utils.load_img(rgb_file) / 255)
        depth_images.append(utils.load_img(depth_file)[..., None])

    # (N, C, H, W) layout expected by TSDF.integrate_tsdf / grid_sample.
    color_images = torch.tensor(np.array(color_images), device=device).permute(0, 3, 1, 2)
    depth_images = torch.tensor(np.array(depth_images), device=device).permute(0, 3, 1, 2)

    # Integrate one view at a time (memory-bound: each step projects the
    # whole voxel slice).
    batch_size = 1
    logger.info("Integrating the TSDF")
    for i in tqdm(range(0, len(c2w), batch_size), disable=not accelerator.is_main_process):
        tsdf.integrate_tsdf(
            c2w[i: i + batch_size],
            K,
            depth_images[i: i + batch_size],
            color_images=color_images[i: i + batch_size],
        )

    # Stage 3: marching cubes + export (collective — all processes must call).
    logger.info("Saving TSDF Mesh")
    tsdf.export_mesh(os.path.join(config.mesh_path, "tsdf_mesh.ply"))
    accelerator.wait_for_everyone()
    logger.info('Finish extracting mesh using TSDF.')
|
|
|
|
if __name__ == '__main__':
    # Run under the 'bake' gin scope so scope-qualified gin bindings for mesh
    # extraction take effect; app.run parses the absl flags registered above.
    with gin.config_scope('bake'):
        app.run(main)
|
|