| import dataclasses |
| import os |
| from typing import Any, Callable, Optional, Tuple, List |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from absl import flags |
| import gin |
| from internal import utils |
|
|
# Add 'configs/' to the locations gin searches when resolving config files.
gin.add_config_file_search_path('configs/')

# External callables made available to gin config files, keyed by the module
# name they are registered under.
configurables = {
    'torch': [torch.reciprocal, torch.log, torch.log1p, torch.exp, torch.sqrt, torch.square],
}

# The loop names are deliberately distinct from `configurables`: reusing that
# name as a loop target would rebind the registry dict to one of its own
# values while it is being iterated.
for module_name, module_fns in configurables.items():
    for fn in module_fns:
        gin.config.external_configurable(fn, module=module_name)
|
|
|
|
@gin.configurable()
@dataclasses.dataclass
class Config:
    """Configuration flags for everything.

    Every annotated attribute is a dataclass field, so it can be passed to
    the constructor and, because the class is gin-configurable, bound from a
    gin config file or binding.

    The section comments below group fields by name only; they are navigation
    aids, not a specification. Check the code that reads a field for its
    exact meaning.
    """

    # --- Dataset / batching (grouping inferred from field names) ---
    dataset_loader: str = 'llff'
    batching: str = 'all_images'
    batch_size: int = 2 ** 16
    patch_size: int = 1
    factor: int = 4
    multiscale: bool = False
    multiscale_levels: int = 4

    forward_facing: bool = False
    render_path: bool = False
    llffhold: int = 8

    llff_use_all_images_for_training: bool = False
    llff_use_all_images_for_testing: bool = False
    use_tiffs: bool = False
    compute_disp_metrics: bool = False
    compute_normal_metrics: bool = False
    disable_multiscale_loss: bool = False
    randomized: bool = True
    near: float = 2.
    far: float = 6.
    exp_name: str = "test"
    # NOTE(review): machine-specific absolute path as the default; confirm
    # this is intended rather than always overridden via gin.
    data_dir: Optional[str] = "/SSD_DISK/datasets/360_v2/bicycle"
    vocab_tree_path: Optional[str] = None
    render_chunk_size: int = 65536
    num_showcase_images: int = 5
    deterministic_showcase: bool = True
    vis_num_rays: int = 16

    vis_decimate: int = 0

    # --- Training schedule and loss weights (inferred from field names) ---
    max_steps: int = 25000
    early_exit_steps: Optional[int] = None
    checkpoint_every: int = 5000
    resume_from_checkpoint: bool = True
    checkpoints_total_limit: int = 1
    gradient_scaling: bool = False
    print_every: int = 100
    train_render_every: int = 500
    data_loss_type: str = 'charb'
    charb_padding: float = 0.001
    data_loss_mult: float = 1.0
    data_coarse_loss_mult: float = 0.
    interlevel_loss_mult: float = 0.0
    anti_interlevel_loss_mult: float = 0.01
    orientation_loss_mult: float = 0.0
    orientation_coarse_loss_mult: float = 0.0

    orientation_loss_target: str = 'normals_pred'
    predicted_normal_loss_mult: float = 0.0

    predicted_normal_coarse_loss_mult: float = 0.0
    hash_decay_mults: float = 0.1

    # --- Optimizer settings (inferred from field names) ---
    lr_init: float = 0.01
    lr_final: float = 0.001
    lr_delay_steps: int = 5000
    lr_delay_mult: float = 1e-8
    adam_beta1: float = 0.9
    adam_beta2: float = 0.99
    adam_eps: float = 1e-15
    grad_max_norm: float = 0.
    grad_max_val: float = 0.
    distortion_loss_mult: float = 0.005
    opacity_loss_mult: float = 0.

    # --- Evaluation (inferred from field names) ---
    eval_only_once: bool = True
    eval_save_output: bool = True
    eval_save_ray_data: bool = False
    eval_render_interval: int = 1
    # Default is the largest int32 value.
    eval_dataset_limit: int = np.iinfo(np.int32).max
    eval_quantize_metrics: bool = True
    eval_crop_borders: int = 0

    # --- Rendering (inferred from field names) ---
    render_video_fps: int = 60
    render_video_crf: int = 18
    render_path_frames: int = 120
    z_variation: float = 0.
    z_phase: float = 0.
    render_dist_percentile: float = 0.5
    render_dist_curve_fn: Callable[..., Any] = np.log
    render_path_file: Optional[str] = None
    render_resolution: Optional[Tuple[int, int]] = None

    render_focal: Optional[float] = None
    render_camtype: Optional[str] = None
    render_spherical: bool = False
    render_save_async: bool = True

    render_spline_keyframes: Optional[str] = None

    render_spline_n_interp: int = 30
    render_spline_degree: int = 5
    render_spline_smoothness: float = .03

    render_spline_interpolate_exposure: bool = False

    # --- Raw-image options (inferred from field names) ---
    rawnerf_mode: bool = False
    exposure_percentile: float = 97.
    num_border_pixels_to_mask: int = 0

    apply_bayer_mask: bool = False
    autoexpose_renders: bool = False

    eval_raw_affine_cc: bool = False

    zero_glo: bool = False

    # --- Mesh extraction (inferred from field names) ---
    valid_weight_thresh: float = 0.05
    isosurface_threshold: float = 20
    mesh_voxels: int = 512 ** 3
    visibility_resolution: int = 512
    mesh_radius: float = 1.0
    mesh_max_radius: float = 10.0
    std_value: float = 0.0
    compute_visibility: bool = False
    extract_visibility: bool = True
    decimate_target: int = -1
    vertex_color: bool = True
    vertex_projection: bool = True

    # --- TSDF options (inferred from field names) ---
    tsdf_radius: float = 2.0
    tsdf_resolution: int = 512
    truncation_margin: float = 5.0
    tsdf_max_radius: float = 10.0

    # `seed` and `pulse_width` used to be un-annotated class attributes, which
    # `@dataclass` ignores: they were not constructor parameters and therefore
    # could not be set through gin either. They are now real fields, declared
    # last so the positional order of every pre-existing field is unchanged.
    seed: int = 0
    # default_factory gives each instance its own list instead of one mutable
    # list shared through the class.
    pulse_width: List[float] = dataclasses.field(
        default_factory=lambda: [0.03, 0.003])
|
|
|
|
def define_common_flags():
    """Define the absl command-line flags used with this module.

    Registers two string flags (`mode`, `base_folder`) whose help text marks
    them as unused placeholders, and two multi-string flags (`gin_bindings`,
    `gin_configs`) that `load_config` reads. All four default to None.
    """
    string_flags = (
        ('mode', 'Required by GINXM, not used.'),
        ('base_folder', 'Required by GINXM, not used.'),
    )
    multi_string_flags = (
        ('gin_bindings', 'Gin parameter bindings.'),
        ('gin_configs', 'Gin config files.'),
    )
    # Registration order matches the original: strings first, then multi-strings.
    for flag_name, help_text in string_flags:
        flags.DEFINE_string(flag_name, None, help_text)
    for flag_name, help_text in multi_string_flags:
        flags.DEFINE_multi_string(flag_name, None, help_text)
|
|
|
|
def load_config():
    """Parse the gin inputs named by the absl flags and build a Config.

    Reads `flags.FLAGS.gin_configs` and `flags.FLAGS.gin_bindings`, so those
    flags must already be defined and parsed (see `define_common_flags`).

    Returns:
        A `Config` instance constructed after gin parsing, so any gin
        bindings for `Config` are applied to it.
    """
    parsed_flags = flags.FLAGS
    config_files = parsed_flags.gin_configs
    bindings = parsed_flags.gin_bindings
    # skip_unknown=True is forwarded to gin so that references to
    # configurables it does not know about do not abort parsing.
    gin.parse_config_files_and_bindings(config_files, bindings, skip_unknown=True)
    return Config()
|
|