| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| """ |
| The renderer is a module that takes in rays, decides where to sample along each |
| ray, and computes pixel colors using the volume rendering equation. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .ray_marcher import MipRayMarcher2 |
|
|
| |
def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
    """
    Apply an MxM matrix to N row vectors of length M.

    Each row ``v`` of ``vectors4`` (NxM) is mapped to ``matrix @ v``; this is
    computed for all rows at once as ``vectors4 @ matrix.T``. Returns NxM.
    """
    return vectors4 @ matrix.T
|
|
| |
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
    """
    Scale every vector along the last dimension to unit L2 length.

    No epsilon is applied, so a zero-length vector yields NaNs.
    """
    lengths = vectors.norm(dim=-1, keepdim=True)
    return vectors / lengths
|
|
| |
def torch_dot(x: torch.Tensor, y: torch.Tensor):
    """
    Dot product of two tensors, taken over the last dimension
    (element-wise product followed by a sum, so inputs may broadcast).
    """
    return torch.sum(x * y, dim=-1)
|
|
| |
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
    """
    Author: Petr Kellnhofer
    Intersects rays with an axis-aligned cube centred at the origin whose
    side length is ``box_side_length`` (i.e. spanning
    [-box_side_length/2, box_side_length/2] on every axis), using the slab
    method.

    Inputs are detached, so no gradient flows through the returned distances.

    Returns (tmin, tmax): entry and exit distances along each ray, each of
    shape (*rays_o.shape[:-1], 1). Rays that miss the box get tmin = -1 and
    tmax = -2.
    https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
    """
    out_shape = rays_o.shape
    origins = rays_o.detach().reshape(-1, 3)
    directions = rays_d.detach().reshape(-1, 3)

    # Row 0 holds the minimum corner of the box, row 1 the maximum corner.
    half_side = box_side_length / 2
    corners = torch.tensor([[-1 * half_side] * 3, [1 * half_side] * 3],
                           dtype=origins.dtype, device=origins.device)
    hit = torch.ones(origins.shape[:-1], dtype=bool, device=origins.device)

    # Precompute the reciprocal direction; its sign tells which slab face is nearer.
    inv_dir = 1 / directions
    negative = (inv_dir < 0).long()

    def slab(axis):
        # Distances at which each ray crosses the near / far face of one axis slab.
        near = (corners.index_select(0, negative[..., axis])[..., axis] - origins[..., axis]) * inv_dir[..., axis]
        far = (corners.index_select(0, 1 - negative[..., axis])[..., axis] - origins[..., axis]) * inv_dir[..., axis]
        return near, far

    # Start from the x slab, then intersect the running interval with y and z in turn.
    t_near, t_far = slab(0)
    for axis in (1, 2):
        axis_near, axis_far = slab(axis)
        # Disjoint intervals mean the ray misses the box.
        hit[torch.logical_or(t_near > axis_far, axis_near > t_far)] = False
        t_near = torch.max(t_near, axis_near)
        t_far = torch.min(t_far, axis_far)

    # Sentinel values for rays that never enter the box.
    missed = torch.logical_not(hit)
    t_near[missed] = -1
    t_far[missed] = -2

    return t_near.reshape(*out_shape[:-1], 1), t_far.reshape(*out_shape[:-1], 1)
|
|
| |
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
    """
    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.

    With ``num == 1`` the result is just ``start`` (matching numpy.linspace)
    instead of the NaNs a 0/0 step fraction would produce.
    """
    # Fractions in [0, 1]; clamp the divisor so that num == 1 gives 0 rather than 0 / 0.
    steps = torch.arange(num, dtype=torch.float32, device=start.device) / max(num - 1, 1)

    # Append one singleton dimension per dimension of `start` so that `steps`
    # broadcasts against start[None] (newer PyTorch broadcasts left-to-right,
    # older versions right-to-left, so the shape is made explicit).
    steps = steps.reshape(num, *([1] * start.ndim))

    # `start` is the base value and `stop - start` the total span; each fraction
    # selects a point along that span.
    out = start[None] + steps * (stop - start)[None]

    return out
|
|
| |
def generate_planes():
    """
    Builds the axis matrices of the three feature planes.

    Each plane is defined by the three vectors that form its "axes". Should
    work with arbitrary number of planes and planes of arbitrary orientation.
    Returns a float32 tensor of shape (3, 3, 3).

    Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
    """
    # When fed to project_onto_planes these matrices select the coordinate
    # pairs (x, y), (x, z) and (z, y) respectively.
    first_plane = [[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1]]
    second_plane = [[1, 0, 0],
                    [0, 0, 1],
                    [0, 1, 0]]
    third_plane = [[0, 0, 1],
                   [0, 1, 0],
                   [1, 0, 0]]
    return torch.tensor([first_plane, second_plane, third_plane], dtype=torch.float32)
|
|
| |
def project_onto_planes(planes, coordinates):
    """
    Projects 3D points onto a batch of 2D planes and returns the 2D
    in-plane coordinates.

    planes:      plane axes of shape (n_planes, 3, 3)
    coordinates: points of shape (N, M, 3)
    returns:     projections of shape (N*n_planes, M, 2)
    """
    N, M, C = coordinates.shape
    n_planes, _, _ = planes.shape

    # Repeat every batch of points once per plane: (N, M, 3) -> (N*n_planes, M, 3).
    tiled_points = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1)
    tiled_points = tiled_points.reshape(N * n_planes, M, 3)

    # Repeat the inverted plane bases once per batch item: (n_planes, 3, 3) -> (N*n_planes, 3, 3).
    inverse_axes = torch.linalg.inv(planes)
    tiled_inverse = inverse_axes.unsqueeze(0).expand(N, -1, -1, -1).reshape(N * n_planes, 3, 3)

    tiled_points = tiled_points.to(tiled_inverse.device)

    # Express the points in each plane's basis and keep the two in-plane components.
    in_plane_basis = torch.bmm(tiled_points, tiled_inverse)
    return in_plane_basis[..., :2]
|
|
| |
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
    """
    Samples per-point features from a set of 2D feature planes.

    plane_axes:     plane axis matrices of shape (n_planes, 3, 3)
    plane_features: feature maps of shape (N, n_planes, C, H, W)
    coordinates:    3D points of shape (N, M, 3)
    mode:           interpolation mode forwarded to grid_sample
    padding_mode:   only 'zeros' is supported
    box_warp:       side length of the cube the coordinates live in; they are
                    scaled by 2/box_warp so that the cube maps onto
                    grid_sample's [-1, 1] range. When None, the coordinates
                    are used as given (assumed to be already normalised).

    Returns features of shape (N, n_planes, M, C), computed in float32.
    """
    assert padding_mode == 'zeros'
    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    # reshape (rather than view) so non-contiguous feature tensors are accepted too.
    plane_features = plane_features.reshape(N * n_planes, C, H, W)

    # The declared default box_warp=None used to fail on `2 / None`; skip the rescale instead.
    if box_warp is not None:
        coordinates = (2 / box_warp) * coordinates

    # (N*n_planes, M, 2) -> (N*n_planes, 1, M, 2): a 1 x M sampling grid per plane.
    projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)

    # grid_sample yields (N*n_planes, C, 1, M); move channels last and split the planes back out.
    output_features = torch.nn.functional.grid_sample(
        plane_features.float(), projected_coordinates.float(),
        mode=mode, padding_mode=padding_mode, align_corners=False,
    ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
    return output_features
|
|
| |
def sample_from_3dgrid(grid, coordinates):
    """
    Interpolates features from a dense 3D grid at the given points.

    Expects coordinates in shape (batch_size, num_points_per_batch, 3)
    Expects grid in shape (1, channels, H, W, D)
    (Also works if grid has batch size)
    Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
    """
    batch_size, n_coords, n_dims = coordinates.shape

    # Share the grid across the batch and lay the points out as a 1 x 1 x n sampling volume.
    batched_grid = grid.expand(batch_size, -1, -1, -1, -1)
    query_points = coordinates.reshape(batch_size, 1, 1, -1, n_dims)
    sampled = torch.nn.functional.grid_sample(batched_grid, query_points,
                                              mode='bilinear', padding_mode='zeros',
                                              align_corners=False)

    # (N, C, 1, 1, n) -> (N, n, C)
    n_batch, n_channels, dim_a, dim_b, dim_c = sampled.shape
    return sampled.permute(0, 4, 3, 2, 1).reshape(n_batch, dim_a * dim_b * dim_c, n_channels)
|
|
class ImportanceRenderer(torch.nn.Module):
    """
    Tri-plane volume renderer with stratified coarse sampling and optional
    importance (fine) sampling.

    Modified original version to filter out-of-box samples as TensoRF does.

    Reference:
    TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
    """

    def __init__(self):
        super().__init__()
        self.activation_factory = self._build_activation_factory()
        self.ray_marcher = MipRayMarcher2(self.activation_factory)
        # Kept as a plain attribute (not a buffer); run_model moves it to the
        # planes' device on every call.
        self.plane_axes = generate_planes()

    def _build_activation_factory(self):
        """
        Return a function mapping a rendering-options dict to the density
        activation selected by ``options['clamp_mode']``.
        """
        def activation_factory(options: dict):
            if options['clamp_mode'] == 'softplus':
                return lambda x: F.softplus(x - 1)
            # Raised explicitly (same exception type as the former `assert False`)
            # so the check is not stripped when running under `python -O`.
            raise AssertionError("Renderer only supports `clamp_mode`=`softplus`!")
        return activation_factory

    def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
                      planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
        """
        Evaluate the model at the sample points defined by `depths` along each ray.

        Additional filtering is applied to filter out-of-box samples.
        Modifications made by Zexin He.

        depths is unpacked as (batch_size, num_rays, samples_per_ray, 1).
        Returns (colors, densities) shaped
        (batch_size, num_rays, samples_per_ray, 3) and
        (batch_size, num_rays, samples_per_ray, 1).
        """
        # Work on the device the feature planes live on.
        batch_size, num_rays, samples_per_ray, _ = depths.shape
        device = planes.device
        depths = depths.to(device)
        ray_directions = ray_directions.to(device)
        ray_origins = ray_origins.to(device)

        # One direction / 3D position per sample, flattened over (ray, sample).
        sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
        sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)

        # A sample is in-box only if all three coordinates lie within the sampler bounds.
        mask_inbox = (
            (rendering_options['sampler_bbox_min'] <= sample_coordinates)
            & (sample_coordinates <= rendering_options['sampler_bbox_max'])
        ).all(-1)

        # The model is evaluated on every sample; out-of-box results are discarded below.
        _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)

        # Out-of-box samples keep zero colour and a very negative (but finite)
        # pre-activation density: the most negative finite value of the dtype
        # divided by SAFE_GUARD.
        # NOTE(review): the division presumably leaves headroom against overflow
        # in later arithmetic - confirm against the ray marcher.
        SAFE_GUARD = 3
        DATA_TYPE = _out['sigma'].dtype
        colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
        densities_pass = torch.nan_to_num(
            torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)
        ) / SAFE_GUARD
        colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox]

        # Restore the per-ray sample dimension.
        colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
        densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1])

        return colors_pass, densities_pass

    def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options):
        """
        Render the given rays.

        Reads from rendering_options: 'ray_start', 'ray_end', 'box_warp',
        'depth_resolution', 'disparity_space_sampling' and
        'depth_resolution_importance' (plus whatever the helper methods and
        the ray marcher read).

        Returns (rgb_final, depth_final, weights.sum(2)) as produced by the
        ray marcher.
        """
        if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto':
            # Derive per-ray near/far distances from the intersection with the scene box.
            ray_start, ray_end = get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp'])
            is_ray_valid = ray_end > ray_start
            if torch.any(is_ray_valid).item():
                # Rays that miss the box get a placeholder range taken from the valid rays.
                # NOTE(review): ray_end is filled from ray_start[...].max(), not
                # ray_end[...].max(); kept as in the original - confirm intent.
                ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
                ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
            depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
        else:
            # Fixed near/far taken directly from the options.
            depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])

        depths_coarse = depths_coarse.to(planes.device)

        # Coarse pass.
        colors_coarse, densities_coarse = self._forward_pass(
            depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
            planes=planes, decoder=decoder, rendering_options=rendering_options)

        # Optional fine pass, guided by the coarse rendering weights.
        N_importance = rendering_options['depth_resolution_importance']
        if N_importance > 0:
            _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)

            depths_fine = self.sample_importance(depths_coarse, weights, N_importance)

            colors_fine, densities_fine = self._forward_pass(
                depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
                planes=planes, decoder=decoder, rendering_options=rendering_options)

            all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
                                                                       depths_fine, colors_fine, densities_fine)

            # Aggregate coarse and fine samples together.
            rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options)
        else:
            rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options)

        return rgb_final, depth_final, weights.sum(2)

    def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
        """
        Sample tri-plane features at the given coordinates and decode them.

        Returns the decoder's output dict; when options['density_noise'] > 0,
        Gaussian noise scaled by that value is added in place to out['sigma'].
        """
        plane_axes = self.plane_axes.to(planes.device)
        sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])

        out = decoder(sampled_features, sample_directions)
        if options.get('density_noise', 0) > 0:
            out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
        return out

    def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
        """
        Same as run_model, but with the density activation applied to out['sigma'].
        """
        out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
        out['sigma'] = self.activation_factory(options)(out['sigma'])
        return out

    def sort_samples(self, all_depths, all_colors, all_densities):
        """
        Sort the samples of every ray by increasing depth (along dim -2) and
        reorder colours and densities to match.
        """
        _, indices = torch.sort(all_depths, dim=-2)
        all_depths = torch.gather(all_depths, -2, indices)
        all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
        all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
        return all_depths, all_colors, all_densities

    def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
        """
        Concatenate two sets of samples along the sample dimension (dim -2)
        and return them sorted by depth.
        """
        all_depths = torch.cat([depths1, depths2], dim=-2)
        all_colors = torch.cat([colors1, colors2], dim=-2)
        all_densities = torch.cat([densities1, densities2], dim=-2)

        # Reuse the shared sorting helper instead of duplicating its body.
        return self.sort_samples(all_depths, all_colors, all_densities)

    def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
        """
        Return depths of approximately uniformly spaced samples along rays.

        Each of the depth_resolution evenly spaced positions is jittered
        forward by a uniform random fraction of one bin width. ray_start and
        ray_end may be scalars or per-ray tensors (tensors are only handled
        when disparity_space_sampling is False).

        Returns a tensor of shape (N, M, depth_resolution, 1) where
        (N, M, _) = ray_origins.shape.
        """
        N, M, _ = ray_origins.shape
        if disparity_space_sampling:
            depths_coarse = torch.linspace(0,
                                           1,
                                           depth_resolution,
                                           device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
            depth_delta = 1 / (depth_resolution - 1)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta
            # Interpolate linearly in inverse depth, then convert back to depth.
            depths_coarse = 1. / (1. / ray_start * (1. - depths_coarse) + 1. / ray_end * depths_coarse)
        else:
            if isinstance(ray_start, torch.Tensor):
                # Per-ray limits: linspace yields (res, N, M, 1); move the sample axis into place.
                depths_coarse = linspace(ray_start, ray_end, depth_resolution).permute(1, 2, 0, 3)
                depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
                depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
            else:
                depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
                depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
                depths_coarse += torch.rand_like(depths_coarse) * depth_delta

        return depths_coarse

    def sample_importance(self, z_vals, weights, N_importance):
        """
        Return depths of importance sampled points along rays. See NeRF importance sampling for more.

        z_vals is unpacked as (batch_size, num_rays, samples_per_ray, 1).
        Returns a detached tensor of shape (batch_size, num_rays, N_importance, 1).
        """
        with torch.no_grad():
            batch_size, num_rays, samples_per_ray, _ = z_vals.shape

            z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
            weights = weights.reshape(batch_size * num_rays, -1)

            # Smooth the weights (max filter followed by a box filter) and add a small floor.
            weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
            # squeeze(1) drops only the channel dimension; a bare squeeze() would
            # also drop the ray dimension when there is exactly one ray and make
            # the 2D slicing below fail.
            weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze(1)
            weights = weights + 0.01

            z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
            importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
                                                N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
        return importance_z_vals

    def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
        """
        Sample @N_importance samples from @bins with distribution defined by @weights.
        Inputs:
            bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
            weights: (N_rays, N_samples_)
            N_importance: the number of samples to draw from the distribution
            det: deterministic or not
            eps: a small number to prevent division by zero
        Outputs:
            samples: the sampled samples, of shape (N_rays, N_importance)
        """
        N_rays, N_samples_ = weights.shape
        weights = weights + eps  # prevent division by zero (don't do inplace op!)
        pdf = weights / torch.sum(weights, -1, keepdim=True)  # (N_rays, N_samples_)
        cdf = torch.cumsum(pdf, -1)  # (N_rays, N_samples_), cumulative distribution function
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf], -1)  # (N_rays, N_samples_+1), padded to 0~1 inclusive

        if det:
            u = torch.linspace(0, 1, N_importance, device=bins.device)
            u = u.expand(N_rays, N_importance)
        else:
            u = torch.rand(N_rays, N_importance, device=bins.device)
        u = u.contiguous()

        # Locate the CDF interval each u falls into.
        inds = torch.searchsorted(cdf, u, right=True)
        below = torch.clamp_min(inds - 1, 0)
        above = torch.clamp_max(inds, N_samples_)

        inds_sampled = torch.stack([below, above], -1).view(N_rays, 2 * N_importance)
        cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
        bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)

        # A denominator below eps means the interval has (almost) zero weight, in
        # which case it will not be sampled; any value works, so it is set to 1.
        denom = cdf_g[..., 1] - cdf_g[..., 0]
        denom[denom < eps] = 1

        # Linear interpolation inside the selected bin (inverse-CDF sampling).
        samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (bins_g[..., 1] - bins_g[..., 0])
        return samples