| import accelerate |
| import gin |
| from internal import coord |
| from internal import geopoly |
| from internal import image |
| from internal import math |
| from internal import ref_utils |
| from internal import train_utils |
| from internal import render |
| from internal import stepfun |
| from internal import utils |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils._pytree import tree_map |
| from tqdm import tqdm |
| from gridencoder import GridEncoder |
| from torch_scatter import segment_coo |
|
|
# Expose math.safe_exp to gin config files under the 'math' module namespace.
gin.config.external_configurable(math.safe_exp, module='math')
|
|
|
|
def set_kwargs(self, kwargs):
    """Assign every (name, value) pair in `kwargs` onto `self` as attributes.

    Used by the gin-configurable classes below to let constructor kwargs
    override class-level defaults.
    """
    for name, value in kwargs.items():
        setattr(self, name, value)
|
|
|
|
@gin.configurable
class Model(nn.Module):
    """A mip-NeRF360 model containing all MLPs.

    Runs `num_levels` rounds of hierarchical sampling: the first
    `num_levels - 1` levels use proposal MLPs to produce weights that guide
    resampling, and the final level uses the NeRF MLP to produce the rendered
    color. Class attributes below are gin-configurable defaults; any of them
    can be overridden via constructor kwargs (see `set_kwargs`).
    """
    num_prop_samples: int = 64  # Number of samples per proposal level.
    num_nerf_samples: int = 32  # Number of samples for the final NeRF level.
    num_levels: int = 3  # Number of sampling levels (e.g. 3 = 2 props + 1 nerf).
    bg_intensity_range = (1., 1.)  # Range of background colors (min, max).
    anneal_slope: float = 10  # Slope of resampling annealing; higher = faster.
    stop_level_grad: bool = True  # If True, don't backprop across levels.
    use_viewdirs: bool = True  # If True, pass view directions to the MLPs.
    raydist_fn = None  # Ray-distance warp passed to coord.construct_ray_warps.
    single_jitter: bool = True  # Jitter whole rays instead of per-sample.
    dilation_multiplier: float = 0.5  # Relative interval dilation per level.
    dilation_bias: float = 0.0025  # Absolute interval dilation per level.
    num_glo_features: int = 0  # GLO vector length; GLO disabled when 0.
    num_glo_embeddings: int = 1000  # Max number of GLO (per-camera) embeddings.
    learned_exposure_scaling: bool = False  # Learn per-exposure RGB offsets.
    near_anneal_rate = None  # If not None, anneal the near bound over training.
    near_anneal_init: float = 0.95  # Initial normalized near bound in [0, 1].
    single_mlp: bool = False  # Use the NerfMLP for all levels.
    distinct_prop: bool = True  # Use a separate PropMLP per proposal level.
    resample_padding: float = 0.0  # Padding added to weights before resampling.
    opaque_background: bool = False  # If True, make the background opaque.
    power_lambda: float = -1.5  # Lambda passed to coord.construct_ray_warps.
    std_scale: float = 0.5  # Scale on the cast-ray Gaussian stds.
    prop_desired_grid_size = [512, 2048]  # Hash-grid resolution per prop level.

    def __init__(self, config=None, **kwargs):
        """Builds all sub-MLPs and optional embeddings.

        Args:
            config: global Config object; `config.zero_glo` and
                `config.gradient_scaling` / `config.vis_num_rays` are read here
                and in forward().
            **kwargs: overrides for any of the class attributes above.
        """
        super().__init__()
        set_kwargs(self, kwargs)
        self.config = config

        # Final-level MLP that predicts color.
        self.nerf_mlp = NerfMLP(num_glo_features=self.num_glo_features,
                                num_glo_embeddings=self.num_glo_embeddings)
        if self.single_mlp:
            # Share one MLP across proposal and NeRF levels.
            self.prop_mlp = self.nerf_mlp
        elif not self.distinct_prop:
            # One shared proposal MLP for all proposal levels.
            self.prop_mlp = PropMLP()
        else:
            # One proposal MLP per level, each with its own grid resolution.
            # (`grid_disired_resolution` spelling matches the MLP attribute.)
            for i in range(self.num_levels - 1):
                self.register_module(f'prop_mlp_{i}', PropMLP(grid_disired_resolution=self.prop_desired_grid_size[i]))
        if self.num_glo_features > 0 and not config.zero_glo:
            # Per-camera latent vectors ("GLO" embeddings).
            self.glo_vecs = nn.Embedding(self.num_glo_embeddings, self.num_glo_features)

        if self.learned_exposure_scaling:
            # Learned per-exposure-index RGB scaling offsets, initialized to
            # zero so that scaling starts out as the identity.
            max_num_exposures = self.num_glo_embeddings
            self.exposure_scaling_offsets = nn.Embedding(max_num_exposures, 3)
            torch.nn.init.zeros_(self.exposure_scaling_offsets.weight)

    def forward(
            self,
            rand,
            batch,
            train_frac,
            compute_extras,
            zero_glo=True,
    ):
        """The mip-NeRF Model.

        Args:
            rand: random number generator (or None for deterministic output).
            batch: util.Rays, a pytree of ray origins, directions, and viewdirs.
            train_frac: float in [0, 1], what fraction of training is complete.
            compute_extras: bool, if True, compute extra quantities besides color.
            zero_glo: bool, if True, when using GLO pass in vector of zeros.

        Returns:
            ret: list, [*(rgb, distance, acc)]
        """
        device = batch['origins'].device
        if self.num_glo_features > 0:
            if not zero_glo:
                # Look up each ray's GLO vector by its camera index.
                cam_idx = batch['cam_idx'][..., 0]
                glo_vec = self.glo_vecs(cam_idx.long())
            else:
                glo_vec = torch.zeros(batch['origins'].shape[:-1] + (self.num_glo_features,), device=device)
        else:
            glo_vec = None

        # s_to_t maps normalized s-distances in [0, 1] back to metric t-distances.
        _, s_to_t = coord.construct_ray_warps(self.raydist_fn, batch['near'], batch['far'], self.power_lambda)

        # Optionally anneal the (normalized) near bound in from near_anneal_init
        # toward 0 over the first `near_anneal_rate` fraction of training.
        if self.near_anneal_rate is None:
            init_s_near = 0.
        else:
            init_s_near = np.clip(1 - train_frac / self.near_anneal_rate, 0,
                                  self.near_anneal_init)
        init_s_far = 1.
        # Initial sampling interval covers the whole [near, far] range with
        # uniform weight.
        sdist = torch.cat([
            torch.full_like(batch['near'], init_s_near),
            torch.full_like(batch['far'], init_s_far)
        ], dim=-1)
        weights = torch.ones_like(batch['near'])
        prod_num_samples = 1

        ray_history = []
        renderings = []
        for i_level in range(self.num_levels):
            is_prop = i_level < (self.num_levels - 1)
            num_samples = self.num_prop_samples if is_prop else self.num_nerf_samples

            # Dilation shrinks as the cumulative number of samples grows, so
            # later levels dilate intervals less.
            dilation = self.dilation_bias + self.dilation_multiplier * (
                init_s_far - init_s_near) / prod_num_samples

            prod_num_samples *= num_samples

            # Dilate the previous level's histogram before resampling from it
            # (skipped entirely when both dilation terms are zero).
            use_dilation = self.dilation_bias > 0 or self.dilation_multiplier > 0
            if i_level > 0 and use_dilation:
                sdist, weights = stepfun.max_dilate_weights(
                    sdist,
                    weights,
                    dilation,
                    domain=(init_s_near, init_s_far),
                    renormalize=True)
                sdist = sdist[..., 1:-1]
                weights = weights[..., 1:-1]

            # Anneal the resampling sharpness over training: anneal in (0, 1]
            # scales the log-weights below, 1 meaning no annealing.
            if self.anneal_slope > 0:
                # Schlick's bias function, s > 0 biases toward 1.
                bias = lambda x, s: (s * x) / ((s - 1) * x + 1)
                anneal = bias(train_frac, self.anneal_slope)
            else:
                anneal = 1.

            # Logits for resampling; degenerate (non-increasing) intervals get
            # -inf so they are never sampled.
            logits_resample = torch.where(
                sdist[..., 1:] > sdist[..., :-1],
                anneal * torch.log(weights + self.resample_padding),
                torch.full_like(sdist[..., :-1], -torch.inf))

            # Draw this level's sample intervals from the (annealed) histogram.
            sdist = stepfun.sample_intervals(
                rand,
                sdist,
                logits_resample,
                num_samples,
                single_jitter=self.single_jitter,
                domain=(init_s_near, init_s_far))

            # Optimization will usually go nonlinear if you propagate gradients
            # through sampling, so optionally detach here.
            if self.stop_level_grad:
                sdist = sdist.detach()

            # Convert normalized distances to metric distances.
            tdist = s_to_t(sdist)

            # Cast our rays, by turning our distance intervals into Gaussians
            # (means + stds).
            means, stds, ts = render.cast_rays(
                tdist,
                batch['origins'],
                batch['directions'],
                batch['cam_dirs'],
                batch['radii'],
                rand,
                std_scale=self.std_scale)

            # Pick this level's MLP and evaluate it on the cast samples.
            mlp = (self.get_submodule(
                f'prop_mlp_{i_level}') if self.distinct_prop else self.prop_mlp) if is_prop else self.nerf_mlp
            ray_results = mlp(
                rand,
                means, stds,
                viewdirs=batch['viewdirs'] if self.use_viewdirs else None,
                imageplane=batch.get('imageplane'),
                glo_vec=None if is_prop else glo_vec,
                exposure=batch.get('exposure_values'),
            )
            if self.config.gradient_scaling:
                # Scale gradients of rgb/density by sample distance
                # (see train_utils.GradientScaler).
                ray_results['rgb'], ray_results['density'] = train_utils.GradientScaler.apply(
                    ray_results['rgb'], ray_results['density'], ts.mean(dim=-1))

            # Convert densities along each ray into volumetric-rendering weights.
            weights = render.compute_alpha_weights(
                ray_results['density'],
                tdist,
                batch['directions'],
                opaque_background=self.opaque_background,
            )[0]

            # Define or sample the background color for each ray.
            if self.bg_intensity_range[0] == self.bg_intensity_range[1]:
                # If the min and max of the range are equal, just take it.
                bg_rgbs = self.bg_intensity_range[0]
            elif rand is None:
                # If rendering is deterministic, use the midpoint of the range.
                bg_rgbs = (self.bg_intensity_range[0] + self.bg_intensity_range[1]) / 2
            else:
                # Sample RGB values from the range for each ray.
                minval = self.bg_intensity_range[0]
                maxval = self.bg_intensity_range[1]
                bg_rgbs = torch.rand(weights.shape[:-1] + (3,), device=device) * (maxval - minval) + minval

            # RawNeRF exposure logic: scale color by the exposure value, and
            # optionally by a learned per-exposure-index offset.
            if batch.get('exposure_idx') is not None:
                ray_results['rgb'] *= batch['exposure_values'][..., None, :]
                if self.learned_exposure_scaling:
                    exposure_idx = batch['exposure_idx'][..., 0]
                    # Force scaling offset to be zero when exposure_idx is 0
                    # (keeps the first exposure fixed as a reference).
                    mask = exposure_idx > 0
                    scaling = 1 + mask[..., None] * self.exposure_scaling_offsets(exposure_idx.long())
                    ray_results['rgb'] *= scaling[..., None, :]

            # Composite this level's samples into a rendering.
            rendering = render.volumetric_rendering(
                ray_results['rgb'],
                weights,
                tdist,
                bg_rgbs,
                batch['far'],
                compute_extras,
                extras={
                    k: v
                    for k, v in ray_results.items()
                    if k.startswith('normals') or k in ['roughness']
                })

            if compute_extras:
                # Keep a subset of rays' per-sample quantities for visualization
                # (first vis_num_rays rays after flattening).
                n = self.config.vis_num_rays
                rendering['ray_sdist'] = sdist.reshape([-1, sdist.shape[-1]])[:n, :]
                rendering['ray_weights'] = (
                    weights.reshape([-1, weights.shape[-1]])[:n, :])
                rgb = ray_results['rgb']
                rendering['ray_rgbs'] = (rgb.reshape((-1,) + rgb.shape[-2:]))[:n, :, :]

            if self.training:
                # Hash-grid weight-decay loss: mean squared embedding value per
                # hash-table level, averaged over levels (via segment_coo over
                # the encoder's level index).
                idx = mlp.encoder.idx
                param = mlp.encoder.embeddings
                loss_hash_decay = segment_coo(param ** 2,
                                              idx,
                                              torch.zeros(idx.max() + 1, param.shape[-1], device=param.device),
                                              reduce='mean'
                                              ).mean()
                ray_results['loss_hash_decay'] = loss_hash_decay

            renderings.append(rendering)
            # Clone so later in-place edits don't corrupt the history.
            ray_results['sdist'] = sdist.clone()
            ray_results['weights'] = weights.clone()
            ray_history.append(ray_results)

        if compute_extras:
            # Because the proposal network doesn't produce meaningful colors,
            # replace each proposal level's visualized RGB with the final
            # composited color, broadcast to the proposal's sample count.
            weights = [r['ray_weights'] for r in renderings]
            rgbs = [r['ray_rgbs'] for r in renderings]
            final_rgb = torch.sum(rgbs[-1] * weights[-1][..., None], dim=-2)
            avg_rgbs = [
                torch.broadcast_to(final_rgb[:, None, :], r.shape) for r in rgbs[:-1]
            ]
            for i in range(len(avg_rgbs)):
                renderings[i]['ray_rgbs'] = avg_rgbs[i]

        return renderings, ray_history
|
|
|
|
class MLP(nn.Module):
    """A PosEnc MLP.

    Density comes from a multires hash-grid encoder (GridEncoder) followed by
    a small MLP; color comes from a second "view-dependent" MLP conditioned on
    encoded view/reflection directions and an optional bottleneck/GLO vector.
    Class attributes are gin-configurable defaults, overridable via kwargs.
    """
    bottleneck_width: int = 256  # Width of the bottleneck vector.
    net_depth_viewdirs: int = 2  # Depth of the second (view-dependent) MLP.
    net_width_viewdirs: int = 256  # Width of the second MLP.
    skip_layer_dir: int = 0  # Add a skip connection to 2nd MLP every N layers.
    num_rgb_channels: int = 3  # Number of RGB channels.
    deg_view: int = 4  # Degree of encoding for viewdirs or refdirs.
    use_reflections: bool = False  # If True, use refdirs instead of viewdirs.
    use_directional_enc: bool = False  # If True, use IDE to encode directions.
    # If False and if use_directional_enc is True, use zero roughness in IDE.
    enable_pred_roughness: bool = False
    roughness_bias: float = -1.  # Shift added to raw roughness pre-activation.
    use_diffuse_color: bool = False  # If True, predict diffuse & specular colors.
    use_specular_tint: bool = False  # If True, predict tint.
    use_n_dot_v: bool = False  # If True, feed dot(normal, viewdir) into 2nd MLP.
    bottleneck_noise: float = 0.0  # Std. deviation of noise added to bottleneck.
    density_bias: float = -1.  # Shift added to raw densities pre-activation.
    density_noise: float = 0.  # Standard deviation of noise added to raw density.
    rgb_premultiplier: float = 1.  # Premultiplier on RGB before activation.
    rgb_bias: float = 0.  # The shift added to raw colors pre-activation.
    rgb_padding: float = 0.001  # Padding added to the RGB outputs.
    enable_pred_normals: bool = False  # If True, predict normals from an MLP head.
    disable_density_normals: bool = False  # If True, don't compute grad normals.
    disable_rgb: bool = False  # If True, don't output RGB (proposal MLPs).
    warp_fn = 'contract'  # Coordinate warp applied before grid encoding.
    num_glo_features: int = 0  # GLO vector length; GLO disabled when 0.
    num_glo_embeddings: int = 1000  # Upper bound on the number of cameras.
    scale_featurization: bool = False  # Append per-level scale features.
    grid_num_levels: int = 10  # Recomputed in __init__ from desired resolution.
    grid_level_interval: int = 2  # Per-level resolution growth factor.
    grid_level_dim: int = 4  # Feature dims per grid level.
    grid_base_resolution: int = 16  # Coarsest grid resolution.
    grid_disired_resolution: int = 8192  # Finest grid resolution (sic: "disired").
    grid_log2_hashmap_size: int = 21  # log2 of the hash table size.
    net_width_glo: int = 128  # Width of the GLO conditioning MLP.
    net_depth_glo: int = 2  # Depth of the GLO conditioning MLP.

    def __init__(self, **kwargs):
        super().__init__()
        set_kwargs(self, kwargs)
        # Make sure that normals are computed if reflection directions are used.
        if self.use_reflections and not (self.enable_pred_normals or
                                         not self.disable_density_normals):
            raise ValueError('Normals must be computed for reflection directions.')

        # Precompute and store (the transpose of) the basis being used.
        if self.use_directional_enc:
            self.dir_enc_fn = ref_utils.generate_ide_fn(self.deg_view)
            dim_dir_enc = self.dir_enc_fn(torch.zeros(1, 3), torch.zeros(1, 1)).shape[-1]
        else:

            def dir_enc_fn(direction, _):
                # Plain positional encoding of the direction (roughness unused).
                return coord.pos_enc(
                    direction, min_deg=0, max_deg=self.deg_view, append_identity=True)

            self.dir_enc_fn = dir_enc_fn
            dim_dir_enc = self.dir_enc_fn(torch.zeros(1, 3), None).shape[-1]
        # Number of grid levels needed to go from base to desired resolution
        # with a growth factor of grid_level_interval.
        self.grid_num_levels = int(
            np.log(self.grid_disired_resolution / self.grid_base_resolution) / np.log(self.grid_level_interval)) + 1
        self.encoder = GridEncoder(input_dim=3,
                                   num_levels=self.grid_num_levels,
                                   level_dim=self.grid_level_dim,
                                   base_resolution=self.grid_base_resolution,
                                   desired_resolution=self.grid_disired_resolution,
                                   log2_hashmap_size=self.grid_log2_hashmap_size,
                                   gridtype='hash',
                                   align_corners=False)
        last_dim = self.encoder.output_dim
        if self.scale_featurization:
            last_dim += self.encoder.num_levels
        # Density head: grid features -> 64 -> (1 or bottleneck_width).
        self.density_layer = nn.Sequential(nn.Linear(last_dim, 64),
                                           nn.ReLU(),
                                           nn.Linear(64,
                                                     1 if self.disable_rgb else self.bottleneck_width))
        last_dim = 1 if self.disable_rgb and not self.enable_pred_normals else self.bottleneck_width
        if self.enable_pred_normals:
            self.normal_layer = nn.Linear(last_dim, 3)

        if not self.disable_rgb:
            if self.use_diffuse_color:
                self.diffuse_layer = nn.Linear(last_dim, self.num_rgb_channels)

            if self.use_specular_tint:
                self.specular_layer = nn.Linear(last_dim, 3)

            if self.enable_pred_roughness:
                self.roughness_layer = nn.Linear(last_dim, 1)

            # Output of the first part of MLP.
            if self.bottleneck_width > 0:
                last_dim_rgb = self.bottleneck_width
            else:
                last_dim_rgb = 0

            last_dim_rgb += dim_dir_enc

            if self.use_n_dot_v:
                last_dim_rgb += 1

            if self.num_glo_features > 0:
                # Small MLP mapping the GLO vector to a (scale, shift) pair for
                # the bottleneck (hence bottleneck_width * 2 outputs).
                last_dim_glo = self.num_glo_features
                for i in range(self.net_depth_glo - 1):
                    self.register_module(f"lin_glo_{i}", nn.Linear(last_dim_glo, self.net_width_glo))
                    last_dim_glo = self.net_width_glo
                self.register_module(f"lin_glo_{self.net_depth_glo - 1}",
                                     nn.Linear(last_dim_glo, self.bottleneck_width * 2))

            # View-dependent MLP with a skip connection at skip_layer_dir.
            input_dim_rgb = last_dim_rgb
            for i in range(self.net_depth_viewdirs):
                lin = nn.Linear(last_dim_rgb, self.net_width_viewdirs)
                torch.nn.init.kaiming_uniform_(lin.weight)
                self.register_module(f"lin_second_stage_{i}", lin)
                last_dim_rgb = self.net_width_viewdirs
                if i == self.skip_layer_dir:
                    last_dim_rgb += input_dim_rgb
            self.rgb_layer = nn.Linear(last_dim_rgb, self.num_rgb_channels)

    def predict_density(self, means, stds, rand=False, no_warp=False):
        """Helper function to output density.

        Args:
            means: sample Gaussian means; assumed [..., n, 3] — the last axis
                indexes xyz and dim -2 indexes samples (TODO confirm vs caller).
            stds: sample Gaussian stds, matching means' leading shape.
            rand: if truthy, add density_noise to the raw density.
            no_warp: if True, skip the warp_fn coordinate warp.

        Returns:
            (raw_density, x, mean_position) where x is the full density-head
            output and mean_position is the per-ray mean of warped means.
        """
        # Encode input positions.
        if self.warp_fn is not None and not no_warp:
            means, stds = coord.track_linearize(self.warp_fn, means, stds)
            # contract [-2, 2] to [-1, 1] (the encoder's expected bound).
            bound = 2
            means = means / bound
            stds = stds / bound
        features = self.encoder(means, bound=1).unflatten(-1, (self.encoder.num_levels, -1))
        # Down-weight fine-grid features for large stds (multisample blur),
        # via an erf of the std-to-cell-size ratio.
        weights = torch.erf(1 / torch.sqrt(8 * stds[..., None] ** 2 * self.encoder.grid_sizes ** 2))
        features = (features * weights[..., None]).mean(dim=-3).flatten(-2, -1)
        if self.scale_featurization:
            with torch.no_grad():
                # Per-level mean squared embedding norm (not differentiated).
                vl2mean = segment_coo((self.encoder.embeddings ** 2).sum(-1),
                                      self.encoder.idx,
                                      torch.zeros(self.grid_num_levels, device=weights.device),
                                      self.grid_num_levels,
                                      reduce='mean'
                                      )
            featurized_w = (2 * weights.mean(dim=-2) - 1) * (self.encoder.init_std ** 2 + vl2mean).sqrt()
            features = torch.cat([features, featurized_w], dim=-1)
        x = self.density_layer(features)
        raw_density = x[..., 0]  # Hardcoded to a single channel.
        # Add noise to regularize the density predictions if needed.
        if rand and (self.density_noise > 0):
            raw_density += self.density_noise * torch.randn_like(raw_density)
        return raw_density, x, means.mean(dim=-2)

    def forward(self,
                rand,
                means, stds,
                viewdirs=None,
                imageplane=None,
                glo_vec=None,
                exposure=None,
                no_warp=False):
        """Evaluate the MLP.

        Args:
          rand: if random .
          means: [..., n, 3], coordinate means.
          stds: [..., n], coordinate stds.
          viewdirs: [..., 3], if not None, this variable will
            be part of the input to the second part of the MLP concatenated with the
            output vector of the first part of the MLP. If None, only the first part
            of the MLP will be used with input x. In the original paper, this
            variable is the view direction.
          imageplane:[batch, 2], xy image plane coordinates
            for each ray in the batch. Useful for image plane operations such as a
            learned vignette mapping.
          glo_vec: [..., num_glo_features], The GLO vector for each ray.
          exposure: [..., 1], exposure value (shutter_speed * ISO) for each ray.
          no_warp: if True, skip the coordinate warp in predict_density.

        Returns:
          rgb: [..., num_rgb_channels].
          density: [...].
          normals: [..., 3], or None.
          normals_pred: [..., 3], or None.
          roughness: [..., 1], or None.
        """
        if self.disable_density_normals:
            raw_density, x, means_contract = self.predict_density(means, stds, rand=rand, no_warp=no_warp)
            raw_grad_density = None
            normals = None
        else:
            # Compute analytic normals as the negative normalized density
            # gradient w.r.t. position.
            # NOTE(review): requires_grad_ mutates the caller's `means` tensor
            # in place — confirm callers don't rely on it staying a leaf.
            with torch.enable_grad():
                means.requires_grad_(True)
                raw_density, x, means_contract = self.predict_density(means, stds, rand=rand, no_warp=no_warp)
                d_output = torch.ones_like(raw_density, requires_grad=False, device=raw_density.device)
                raw_grad_density = torch.autograd.grad(
                    outputs=raw_density,
                    inputs=means,
                    grad_outputs=d_output,
                    create_graph=True,
                    retain_graph=True,
                    only_inputs=True)[0]
            raw_grad_density = raw_grad_density.mean(-2)
            # Compute normal vectors as negative normalized density gradient.
            # We normalize the gradient of raw (pre-activation) density because
            # it's the same as post-activation density, but is more numerically
            # stable when the activation function has a steep or flat gradient.
            normals = -ref_utils.l2_normalize(raw_grad_density)

        if self.enable_pred_normals:
            grad_pred = self.normal_layer(x)

            # Normalize negative predicted gradients to get predicted normals.
            normals_pred = -ref_utils.l2_normalize(grad_pred)
            normals_to_use = normals_pred
        else:
            grad_pred = None
            normals_pred = None
            normals_to_use = normals

        # Apply bias and activation function.
        density = F.softplus(raw_density + self.density_bias)

        roughness = None
        if self.disable_rgb:
            rgb = torch.zeros(density.shape + (3,), device=density.device)
        else:
            if viewdirs is not None:
                # Predict diffuse color.
                if self.use_diffuse_color:
                    raw_rgb_diffuse = self.diffuse_layer(x)

                if self.use_specular_tint:
                    tint = torch.sigmoid(self.specular_layer(x))

                if self.enable_pred_roughness:
                    raw_roughness = self.roughness_layer(x)
                    roughness = (F.softplus(raw_roughness + self.roughness_bias))

                # Output of the first part of MLP.
                if self.bottleneck_width > 0:
                    bottleneck = x
                    # Add bottleneck noise.
                    # NOTE(review): `bottleneck += ...` is an in-place add on a
                    # tensor aliasing `x` — only active when bottleneck_noise > 0
                    # (default 0); confirm this aliasing is intended.
                    if rand and (self.bottleneck_noise > 0):
                        bottleneck += self.bottleneck_noise * torch.randn_like(bottleneck)

                    # Incorporate exposure in the style of HDR-NeRF, where we assume the
                    # bottleneck is scaled/shifted by a function of the GLO vector.
                    if glo_vec is not None:
                        for i in range(self.net_depth_glo):
                            glo_vec = self.get_submodule(f"lin_glo_{i}")(glo_vec)
                            if i != self.net_depth_glo - 1:
                                glo_vec = F.relu(glo_vec)
                        glo_vec = torch.broadcast_to(glo_vec[..., None, :],
                                                     bottleneck.shape[:-1] + glo_vec.shape[-1:])
                        scale, shift = glo_vec.chunk(2, dim=-1)
                        bottleneck = bottleneck * torch.exp(scale) + shift

                    x = [bottleneck]
                else:
                    x = []

                # Encode view (or reflection) directions.
                if self.use_reflections:
                    # Compute reflection directions. Note that we flip viewdirs before
                    # reflecting, because they point from the camera to the point,
                    # whereas ref_utils.reflect() assumes they point toward the camera.
                    # Returned refdirs then point from the point to the environment.
                    refdirs = ref_utils.reflect(-viewdirs[..., None, :], normals_to_use)
                    # Encode reflection directions.
                    dir_enc = self.dir_enc_fn(refdirs, roughness)
                else:
                    # Encode view directions.
                    dir_enc = self.dir_enc_fn(viewdirs, roughness)
                    dir_enc = torch.broadcast_to(
                        dir_enc[..., None, :],
                        bottleneck.shape[:-1] + (dir_enc.shape[-1],))

                # Append view (or reflection) direction encoding to bottleneck vector.
                x.append(dir_enc)

                # Append dot product between normal vectors and view directions.
                if self.use_n_dot_v:
                    dotprod = torch.sum(
                        normals_to_use * viewdirs[..., None, :], dim=-1, keepdim=True)
                    x.append(dotprod)

                # Concatenate bottleneck, directional encoding, and n.v.
                x = torch.cat(x, dim=-1)
                # Output of the second part of MLP (with skip connection).
                inputs = x
                for i in range(self.net_depth_viewdirs):
                    x = self.get_submodule(f"lin_second_stage_{i}")(x)
                    x = F.relu(x)
                    if i == self.skip_layer_dir:
                        x = torch.cat([x, inputs], dim=-1)

            # If using diffuse/specular colors, then `rgb` is treated as linear
            # specular color. Otherwise it's treated as the color itself.
            rgb = torch.sigmoid(self.rgb_premultiplier *
                                self.rgb_layer(x) +
                                self.rgb_bias)

            if self.use_diffuse_color:
                # Initialize linear diffuse color around 0.25, so that the combined
                # linear color is initialized around 0.5.
                diffuse_linear = torch.sigmoid(raw_rgb_diffuse - np.log(3.0))
                if self.use_specular_tint:
                    specular_linear = tint * rgb
                else:
                    specular_linear = 0.5 * rgb

                # Combine specular and diffuse components and tone map to sRGB.
                rgb = torch.clip(image.linear_to_srgb(specular_linear + diffuse_linear), 0.0, 1.0)

            # Apply padding, mapping color to [-rgb_padding, 1+rgb_padding].
            rgb = rgb * (1 + 2 * self.rgb_padding) - self.rgb_padding

        return dict(
            coord=means_contract,
            density=density,
            rgb=rgb,
            raw_grad_density=raw_grad_density,
            grad_pred=grad_pred,
            normals=normals,
            normals_pred=normals_pred,
            roughness=roughness,
        )
|
|
|
|
@gin.configurable
class NerfMLP(MLP):
    """The final-level MLP; identical to MLP but separately gin-configurable."""
    pass
|
|
|
|
@gin.configurable
class PropMLP(MLP):
    """A proposal-level MLP; identical to MLP but separately gin-configurable."""
    pass
|
|
|
|
@torch.no_grad()
def render_image(model,
                 accelerator: accelerate.Accelerator,
                 batch,
                 rand,
                 train_frac,
                 config,
                 verbose=True,
                 return_weights=False):
    """Render all the pixels of an image (in test mode).

    Args:
      model: the Model to render with (temporarily put in eval mode).
      accelerator: used for DDP; chunks are split across processes and the
        per-process results are gathered back.
      batch: a `Rays` pytree, the rays to be rendered (leading dims H, W).
      rand: if random.
      train_frac: float in [0, 1], what fraction of training is complete.
      config: A Config class; render_chunk_size and vis_num_rays are read here.
      verbose: if True, show a progress bar on the main process.
      return_weights: if True, also return the final level's weights and coords.

    Returns:
      rendering: dict of rendered outputs; non-'ray_' entries are reshaped to
      (height, width, ...), 'ray_' entries are subsampled to vis_num_rays rays.
    """
    model.eval()

    height, width = batch['origins'].shape[:2]
    num_rays = height * width
    # Flatten the image into a batch of rays.
    batch = {k: v.reshape((num_rays, -1)) for k, v in batch.items() if v is not None}

    global_rank = accelerator.process_index
    chunks = []
    idx0s = tqdm(range(0, num_rays, config.render_chunk_size),
                 desc="Rendering chunk", leave=False,
                 disable=not (accelerator.is_main_process and verbose))

    for i_chunk, idx0 in enumerate(idx0s):
        chunk_batch = tree_map(lambda r: r[idx0:idx0 + config.render_chunk_size], batch)
        actual_chunk_size = chunk_batch['origins'].shape[0]
        rays_remaining = actual_chunk_size % accelerator.num_processes
        if rays_remaining != 0:
            # Pad the chunk with zero rays so it divides evenly across
            # processes; the padding is stripped again after gather.
            padding = accelerator.num_processes - rays_remaining
            chunk_batch = tree_map(lambda v: torch.cat([v, torch.zeros_like(v[-padding:])], dim=0), chunk_batch)
        else:
            padding = 0
        # After padding the number of chunk_rays is always divisible by host_count.
        rays_per_host = chunk_batch['origins'].shape[0] // accelerator.num_processes
        start, stop = global_rank * rays_per_host, (global_rank + 1) * rays_per_host
        chunk_batch = tree_map(lambda r: r[start:stop], chunk_batch)

        with accelerator.autocast():
            chunk_renderings, ray_history = model(rand,
                                                  chunk_batch,
                                                  train_frac=train_frac,
                                                  compute_extras=True,
                                                  zero_glo=True)

        # Gather each tensor from all processes and drop the zero-ray padding.
        gather = lambda v: accelerator.gather(v.contiguous())[:-padding] \
            if padding > 0 else accelerator.gather(v.contiguous())
        # Unshard the renderings.
        chunk_renderings = tree_map(gather, chunk_renderings)

        # Gather the final pass for 2D buffers and all passes for ray bundles.
        chunk_rendering = chunk_renderings[-1]
        for k in chunk_renderings[0]:
            if k.startswith('ray_'):
                chunk_rendering[k] = [r[k] for r in chunk_renderings]

        if return_weights:
            chunk_rendering['weights'] = gather(ray_history[-1]['weights'])
            chunk_rendering['coord'] = gather(ray_history[-1]['coord'])
        chunks.append(chunk_rendering)

    # Concatenate all chunks within each leaf of a single pytree.
    rendering = {}
    for k in chunks[0].keys():
        if isinstance(chunks[0][k], list):
            rendering[k] = []
            for i in range(len(chunks[0][k])):
                rendering[k].append(torch.cat([item[k][i] for item in chunks]))
        else:
            rendering[k] = torch.cat([item[k] for item in chunks])

    for k, z in rendering.items():
        if not k.startswith('ray_'):
            # Reshape 2D buffers into original image shape.
            rendering[k] = z.reshape((height, width) + z.shape[1:])

    # After all of the ray bundles have been concatenated together, extract a
    # new random bundle (deterministic if rand is False) from the concatenation
    # that is the same size as one of the individual bundles.
    keys = [k for k in rendering if k.startswith('ray_')]
    if keys:
        num_rays = rendering[keys[0]][0].shape[0]
        ray_idx = torch.randperm(num_rays)
        ray_idx = ray_idx[:config.vis_num_rays]
        for k in keys:
            rendering[k] = [r[ray_idx] for r in rendering[k]]
    model.train()
    return rendering
|
|