| import collections |
| import functools |
|
|
| import torch.optim |
| from internal import camera_utils |
| from internal import configs |
| from internal import datasets |
| from internal import image |
| from internal import math |
| from internal import models |
| from internal import ref_utils |
| from internal import stepfun |
| from internal import utils |
| import numpy as np |
| from torch.utils._pytree import tree_map, tree_flatten |
| from torch_scatter import segment_coo |
|
|
|
|
class GradientScaler(torch.autograd.Function):
    """Identity in the forward pass; rescales gradients in the backward pass.

    The forward pass hands `colors` and `sigmas` back untouched. In the
    backward pass both incoming gradients are multiplied by
    `clamp(ray_dist ** 2, 0, 1)`; `ray_dist` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, colors, sigmas, ray_dist):
        """Stash `ray_dist` for the backward pass and return the inputs as-is."""
        ctx.save_for_backward(ray_dist)
        return colors, sigmas

    @staticmethod
    def backward(ctx, grad_output_colors, grad_output_sigmas):
        """Scale both gradients by the squared ray distance, capped to [0, 1]."""
        ray_dist = ctx.saved_tensors[0]
        factor = ray_dist.square().clamp(0, 1)
        # The color gradient carries one extra trailing axis compared with
        # `ray_dist`, so the factor is expanded before multiplying.
        scaled_colors = grad_output_colors * factor[..., None]
        scaled_sigmas = grad_output_sigmas * factor
        return scaled_colors, scaled_sigmas, None
|
|
|
|
def tree_reduce(fn, tree, initializer=0):
    """Left-fold `fn` over the leaves of `tree`, seeded with `initializer`.

    Leaves are obtained from `tree_flatten`, in the order it returns them.
    """
    leaves, _ = tree_flatten(tree)
    return functools.reduce(fn, leaves, initializer)
|
|
|
|
def tree_sum(tree):
    """Return the sum of all leaves of `tree`, starting from 0."""

    def _add(running_total, leaf):
        return running_total + leaf

    return tree_reduce(_add, tree, initializer=0)
|
|
|
|
def tree_norm_sq(tree):
    """Return the sum of squared entries over every leaf of `tree`.

    Each leaf is passed through `torch.sum(leaf ** 2)`, then the per-leaf
    results are added together.
    """

    def _sum_of_squares(leaf):
        return torch.sum(leaf ** 2)

    per_leaf = tree_map(_sum_of_squares, tree)
    return tree_sum(per_leaf)
|
|
|
|
def tree_norm(tree):
    """Return the square root of `tree_norm_sq(tree)` (a global L2 norm)."""
    squared_norm = tree_norm_sq(tree)
    return torch.sqrt(squared_norm)
|
|
|
|
def tree_abs_max(tree):
    """Return the largest absolute entry found in any leaf of `tree`.

    The running maximum starts at 0 and each leaf contributes
    `torch.abs(leaf).max().item()`, so the result is a plain Python number.
    """

    def _update_max(best_so_far, leaf):
        leaf_peak = torch.abs(leaf).max().item()
        return max(best_so_far, leaf_peak)

    return tree_reduce(_update_max, tree, initializer=0)
|
|
|
|
def tree_len(tree):
    """Return the total number of entries across all leaves of `tree`.

    Each leaf contributes the product of its `.shape`, computed with
    `np.prod`.
    """

    def _entry_count(leaf):
        return np.prod(leaf.shape)

    counts = tree_map(_entry_count, tree)
    return tree_sum(counts)
|
|
|
|
def summarize_tree(tree, fn, ancestry=(), max_depth=3):
    """Flatten 'tree' while 'fn'-ing values and formatting keys like/this.

    Every entry of `tree` (including intermediate mappings) gets a key built
    by joining its path with '/', and a value of `fn(entry)`. Mappings are
    descended into until the path reaches `max_depth` components.
    """
    summary = {}
    # Recursion depth depends only on how deep we already are.
    may_descend = len(ancestry) < (max_depth - 1)
    for key, value in tree.items():
        path = ancestry + (key,)
        summary['/'.join(path)] = fn(value)
        if may_descend and hasattr(value, 'items'):
            children = summarize_tree(
                value, fn, ancestry=path, max_depth=max_depth)
            summary.update(children)
    return summary
|
|
|
|
def compute_data_loss(batch, renderings, config):
    """Computes data loss terms for RGB, normal, and depth outputs.

    Args:
        batch: mapping holding 'rgb' (only the first three channels of the last
            axis are used) and 'lossmult' (a weight broadcast to the RGB
            shape). 'disps' is read when `config.compute_disp_metrics` is set;
            'alphas' and 'normals' are read when
            `config.compute_normal_metrics` is set and the rendering contains
            'normals'.
        renderings: sequence of per-level mappings with an 'rgb' entry. The
            last element is weighted by `config.data_loss_mult`, all earlier
            ones by `config.data_coarse_loss_mult`.
        config: object exposing `disable_multiscale_loss`, `data_loss_type`
            ('mse', 'charb' or 'rawnerf'), `charb_padding`,
            `compute_disp_metrics`, `compute_normal_metrics`,
            `data_coarse_loss_mult` and `data_loss_mult`.

    Returns:
        A `(loss, stats)` pair. `stats` maps 'mses' (and, when enabled,
        'disparity_mses' / 'normal_maes') to a NumPy array with one entry per
        rendering.

    Raises:
        ValueError: if `config.data_loss_type` is not a supported name.
        IndexError: if `renderings` is empty.
    """
    data_losses = []
    stats = collections.defaultdict(list)

    # `lossmult` acts as a per-element weight on the residuals.
    lossmult = batch['lossmult']
    target_rgb = batch['rgb'][..., :3]
    lossmult = torch.broadcast_to(lossmult, target_rgb.shape)
    if config.disable_multiscale_loss:
        lossmult = torch.ones_like(lossmult)
    # The normaliser does not depend on the rendering, so compute it once.
    denom = lossmult.sum()

    for rendering in renderings:
        resid_sq = (rendering['rgb'] - target_rgb) ** 2
        stats['mses'].append(((lossmult * resid_sq).sum() / denom).item())

        if config.data_loss_type == 'mse':
            # Plain squared error.
            data_loss = resid_sq
        elif config.data_loss_type == 'charb':
            # Charbonnier form: sqrt(residual^2 + padding^2).
            data_loss = torch.sqrt(resid_sq + config.charb_padding ** 2)
        elif config.data_loss_type == 'rawnerf':
            # Cap the rendered values at 1 before taking the residual.
            rgb_render_clip = rendering['rgb'].clamp_max(1)
            resid_sq_clip = (rgb_render_clip - target_rgb) ** 2
            # Reweight by 1 / (1e-3 + render); detached so the weight itself
            # contributes no gradient.
            scaling_grad = 1. / (1e-3 + rgb_render_clip.detach())
            data_loss = resid_sq_clip * scaling_grad ** 2
        else:
            # An explicit exception (rather than `assert False`) so the check
            # survives `python -O`.
            raise ValueError(
                f'Unknown data_loss_type: {config.data_loss_type!r}')
        data_losses.append((lossmult * data_loss).sum() / denom)

        if config.compute_disp_metrics:
            # Disparity is derived from the rendering's mean distance.
            disp = 1 / (1 + rendering['distance_mean'])
            stats['disparity_mses'].append(
                ((disp - batch['disps']) ** 2).mean().item())

        if config.compute_normal_metrics:
            if 'normals' in rendering:
                weights = rendering['acc'] * batch['alphas']
                normalized_normals_gt = ref_utils.l2_normalize(batch['normals'])
                normalized_normals = ref_utils.l2_normalize(rendering['normals'])
                normal_mae = ref_utils.compute_weighted_mae(
                    weights, normalized_normals, normalized_normals_gt).item()
            else:
                # No normals rendered: record NaN. `torch.nan` is a plain
                # Python float with no `.item()`, so use a float directly.
                normal_mae = float('nan')
            stats['normal_maes'].append(normal_mae)

    loss = (
        config.data_coarse_loss_mult * sum(data_losses[:-1]) +
        config.data_loss_mult * data_losses[-1])

    stats = {k: np.array(v) for k, v in stats.items()}
    return loss, stats
|
|
|
|
def interlevel_loss(ray_history, config):
    """Computes the interlevel loss defined in mip-NeRF 360.

    The final entry of `ray_history` supplies the target ('sdist' and
    'weights', both detached so no gradient reaches it). Every earlier entry
    is compared against that target with `stepfun.lossfun_outer`, the mean of
    each comparison is accumulated, and the total is scaled by
    `config.interlevel_loss_mult`.
    """
    target = ray_history[-1]
    target_sdist = target['sdist'].detach()
    target_weights = target['weights'].detach()

    per_level = (
        stepfun.lossfun_outer(
            target_sdist, target_weights, level['sdist'], level['weights']
        ).mean()
        for level in ray_history[:-1]
    )
    return config.interlevel_loss_mult * sum(per_level, 0.)
|
|
|
|
def anti_interlevel_loss(ray_history, config):
    """Computes an interlevel loss against a blurred copy of the final level.

    NOTE(review): the previous docstring was a verbatim copy of
    `interlevel_loss`'s; this looks like the blurred / anti-aliased variant --
    confirm the intended reference.

    The final entry of `ray_history` supplies detached 'sdist' and 'weights'.
    For each earlier level `i`:
      1. the final weights, divided by their interval widths, are blurred
         with `stepfun.blur_stepfun` using `config.pulse_width[i]`;
      2. the blurred curve is integrated with the trapezoid rule into a
         running cumulative sum that starts at zero;
      3. that cumulative curve is resampled at the level's own 'sdist' via
         `math.sorted_interp_quad` and differenced into per-interval weights;
      4. only the excess of those weights over the level's weights is
         penalised, squared and divided by `weights + 1e-5`.
    The per-level means are summed and scaled by
    `config.anti_interlevel_loss_mult`.
    """
    final_level = ray_history[-1]
    fine_sdist = final_level['sdist'].detach()
    fine_weights = final_level['weights'].detach()
    # Weight per unit of s-distance within each interval of the final level.
    fine_density = fine_weights / torch.diff(fine_sdist, dim=-1)

    total = 0.
    for level_idx, level in enumerate(ray_history[:-1]):
        level_sdist = level['sdist']
        level_weights = level['weights']
        blur_x, blur_y = stepfun.blur_stepfun(
            fine_sdist, fine_density, config.pulse_width[level_idx])

        # Trapezoid area under the blurred curve between adjacent knots.
        mean_height = 0.5 * (blur_y[..., 1:] + blur_y[..., :-1])
        knot_spacing = blur_x[..., 1:] - blur_x[..., :-1]
        area = mean_height * knot_spacing

        # Cumulative area, with a leading zero so it lines up with the knots.
        leading_zero = torch.zeros_like(area[..., :1])
        cdf = torch.cat([leading_zero, torch.cumsum(area, dim=-1)], dim=-1)

        # Resample the cumulative curve at this level's interval edges and
        # difference it back into per-interval weights.
        cdf_at_level = math.sorted_interp_quad(level_sdist, blur_x, blur_y, cdf)
        resampled_weights = torch.diff(cdf_at_level, dim=-1)

        excess = (resampled_weights - level_weights).clamp_min(0)
        total += (excess ** 2 / (level_weights + 1e-5)).mean()
    return config.anti_interlevel_loss_mult * total
|
|
|
|
def distortion_loss(ray_history, config):
    """Computes the distortion loss regularizer defined in mip-NeRF 360.

    Only the final entry of `ray_history` is used: its 'sdist' and 'weights'
    go through `stepfun.lossfun_distortion`, and the mean of the result is
    scaled by `config.distortion_loss_mult`.
    """
    final_level = ray_history[-1]
    sdist = final_level['sdist']
    weights = final_level['weights']
    mean_distortion = stepfun.lossfun_distortion(sdist, weights).mean()
    return config.distortion_loss_mult * mean_distortion
|
|
|
|
def orientation_loss(batch, model, ray_history, config):
    """Computes the orientation loss regularizer defined in ref-NeRF.

    For each level, the normals stored under `config.orientation_loss_target`
    are dotted with the negated `batch['viewdirs']`. Positive dot products
    are squared, weighted by the level's 'weights', summed over the last axis
    and averaged. Levels before the last (`i < model.num_levels - 1`) are
    scaled by `config.orientation_coarse_loss_mult`, the rest by
    `config.orientation_loss_mult`.

    Raises:
        ValueError: if a level's target normals are None.
    """
    total_loss = 0.
    for level_idx, level in enumerate(ray_history):
        weights = level['weights']
        normals = level[config.orientation_loss_target]
        if normals is None:
            raise ValueError('Normals cannot be None if orientation loss is on.')

        # Negate the view directions; an extra axis lets them broadcast
        # against the per-sample normals.
        neg_viewdirs = -1. * batch['viewdirs']
        cos_angle = (normals * neg_viewdirs[..., None, :]).sum(dim=-1)
        penalty = cos_angle.clamp_min(0) ** 2
        level_loss = (weights * penalty).sum(dim=-1).mean()

        is_coarse = level_idx < model.num_levels - 1
        mult = (config.orientation_coarse_loss_mult if is_coarse
                else config.orientation_loss_mult)
        total_loss += mult * level_loss
    return total_loss
|
|
|
|
def hash_decay_loss(ray_history, config):
    """Sums each level's 'loss_hash_decay', scaled by `config.hash_decay_mults`.

    Returns 0.0 when `ray_history` is empty.
    """
    scaled_terms = (
        config.hash_decay_mults * level['loss_hash_decay']
        for level in ray_history
    )
    return sum(scaled_terms, 0.)
|
|
|
|
def opacity_loss(renderings, config):
    """Sums an entropy-style penalty on each rendering's 'acc' values.

    Every rendering contributes `mean(-acc * log(acc + 1e-5))`, scaled by
    `config.opacity_loss_mult`. Returns 0.0 when `renderings` is empty.
    """
    total_loss = 0.
    for rendering in renderings:
        acc = rendering['acc']
        # The 1e-5 keeps the logarithm finite when acc is exactly zero.
        entropy_term = (-acc * torch.log(acc + 1e-5)).mean()
        total_loss += config.opacity_loss_mult * entropy_term
    return total_loss
|
|
|
|
def predicted_normal_loss(model, ray_history, config):
    """Computes the predicted normal supervision loss defined in ref-NeRF.

    For each level the loss is the mean over rays of
    `sum(weights * (1 - <normals, normals_pred>))`. Levels before the last
    (`i < model.num_levels - 1`) are scaled by
    `config.predicted_normal_coarse_loss_mult`, the rest by
    `config.predicted_normal_loss_mult`.

    Raises:
        ValueError: if a level's 'normals' or 'normals_pred' is None.
    """
    total_loss = 0.
    for level_idx, level in enumerate(ray_history):
        weights = level['weights']
        normals = level['normals']
        normals_pred = level['normals_pred']
        if normals is None or normals_pred is None:
            raise ValueError(
                'Predicted normals and gradient normals cannot be None if '
                'predicted normal loss is on.')

        # Dot product between the two normal fields along the last axis.
        alignment = torch.sum(normals * normals_pred, dim=-1)
        per_ray = (weights * (1.0 - alignment)).sum(dim=-1)
        level_loss = torch.mean(per_ray)

        is_coarse = level_idx < model.num_levels - 1
        mult = (config.predicted_normal_coarse_loss_mult if is_coarse
                else config.predicted_normal_loss_mult)
        total_loss += mult * level_loss
    return total_loss
|
|
|
|
def clip_gradients(model, accelerator, config):
    """Clips gradients of MLP based on norm and max value.

    Args:
        model: object whose `parameters()` yields the parameters to process.
        accelerator: object exposing `sync_gradients`, `clip_grad_norm_` and
            `clip_grad_value_`.
        config: object exposing `grad_max_norm` and `grad_max_val`; a value
            of 0 or less disables the corresponding clipping step.

    Clipping only runs when `accelerator.sync_gradients` is truthy. The final
    `nan_to_num_` pass runs unconditionally and sanitises every existing
    gradient in place.
    """
    if config.grad_max_norm > 0 and accelerator.sync_gradients:
        accelerator.clip_grad_norm_(model.parameters(), config.grad_max_norm)

    if config.grad_max_val > 0 and accelerator.sync_gradients:
        accelerator.clip_grad_value_(model.parameters(), config.grad_max_val)

    for param in model.parameters():
        # Parameters that are frozen or did not take part in the backward
        # pass have `grad is None`; skip them instead of raising.
        if param.grad is not None:
            param.grad.nan_to_num_()
|
|
|
|
def create_optimizer(config: configs.Config, model):
    """Creates the Adam optimizer and learning-rate function for training.

    Args:
        config: supplies `adam_beta1`, `adam_beta2`, `adam_eps`, `lr_init`,
            `lr_final`, `max_steps`, `lr_delay_steps` and `lr_delay_mult`.
        model: object whose `parameters()` are handed to the optimizer.

    Returns:
        `(optimizer, lr_fn_main)`: a `torch.optim.Adam` created with
        `lr=config.lr_init`, and a function mapping a step to the value of
        `math.learning_rate_decay` for that step (`math` being the project's
        `internal.math` module).
    """
    schedule_kwargs = dict(
        max_steps=config.max_steps,
        lr_delay_steps=config.lr_delay_steps,
        lr_delay_mult=config.lr_delay_mult,
    )

    def lr_fn_main(step):
        # `config` is read when the schedule is called, not when it is built.
        return math.learning_rate_decay(
            step,
            lr_init=config.lr_init,
            lr_final=config.lr_final,
            **schedule_kwargs)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.lr_init,
        betas=[config.adam_beta1, config.adam_beta2],
        eps=config.adam_eps,
    )
    return optimizer, lr_fn_main
|
|