| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| |
| import numpy as np |
| import math |
|
|
| |
| |
|
|
|
|
| from transformers import AutoConfig, AutoModel |
| from transformers.modeling_utils import no_init_weights |
|
|
| import warnings |
|
|
| warnings.filterwarnings("ignore") |
|
|
| import os |
| from pathlib import Path |
| from einops import rearrange, repeat |
|
|
| from easydict import EasyDict as edict |
|
|
| from source.vae_hacked import Decoder |
| from source.rendering.utils import sample_importance,unify_attributes, create_voxel |
| from source.rendering.point_representer import PointRepresenter |
| from source.rendering.point_integrator import PointIntegrator |
| from source.rendering.sat2density_transform_eg3d import get_original_coord,Point_sampler_pano,Point_sampler_ortho |
| from source.rendering.transform_perspective import PointSamplerPerspective |
| from source.rendering.mlp_model import MLPNetwork2 |
| from source.sr_module import SuperresolutionHybrid2X |
| from source.xyz2thetaphi import xyz2thetaphi |
|
|
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.modeling_utils import ModelMixin |
| import tqdm |
|
|
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Scale *x* so its mean squared value along *dim* is one (pixel-norm style)."""
    second_moment = x.square().mean(dim=dim, keepdim=True)
    return x * torch.rsqrt(second_moment + eps)
|
|
|
|
def resolve_backbone_candidates(backbone):
    """Return the ordered list of checkpoint ids/paths to try for *backbone*.

    A non-empty SAT3DGEN_*_PATH environment override, when present, is
    placed before the built-in Hugging Face hub defaults.

    Raises:
        NotImplementedError: if *backbone* is not a supported key.
    """
    env_override_map = {
        "dinov2-base": "SAT3DGEN_DINOV2_BASE_PATH",
        "dinov2-large": "SAT3DGEN_DINOV2_LARGE_PATH",
        "dinov3-large-sat": "SAT3DGEN_DINOV3_SAT_PATH",
        "dinov3-large-lvd": "SAT3DGEN_DINOV3_LVD_PATH",
    }
    default_candidate_map = {
        "dinov2-base": [
            "facebook/dinov2-base",
        ],
        "dinov2-large": [
            "facebook/dinov2-large",
        ],
        "dinov3-large-sat": [
            "facebook/dinov3-vitl16-pretrain-sat493m",
        ],
        "dinov3-large-lvd": [
            "facebook/dinov3-vitl16-pretrain-lvd1689m",
        ],
    }
    if backbone not in default_candidate_map:
        raise NotImplementedError(f"Unsupported backbone: {backbone}")

    override = os.environ.get(env_override_map[backbone])
    # Empty-string overrides are treated as unset, same as a missing variable.
    prefix = [override] if override else []
    return prefix + list(default_candidate_map[backbone])
|
|
|
|
| |
| |
| |
# Minimal Hugging Face config kwargs for each supported backbone, consumed by
# load_backbone_model(..., skip_weights=True) to build the model structure
# fully offline via AutoConfig.for_model.  The weights are expected to be
# overwritten later (e.g. by Sat3DGen.from_pretrained), so only the
# architecture-defining fields are listed here.
_BACKBONE_CONFIGS = {
    "dinov2-base": {
        "model_type": "dinov2",
        "hidden_size": 768,
        "num_hidden_layers": 12,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "patch_size": 14,
        "image_size": 518,
        "num_channels": 3,
        "num_register_tokens": 0,
    },
    "dinov2-large": {
        "model_type": "dinov2",
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "patch_size": 14,
        "image_size": 518,
        "num_channels": 3,
        "num_register_tokens": 0,
    },
    "dinov3-large-sat": {
        "model_type": "dinov3_vit",
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "patch_size": 16,
        "image_size": 224,
        "num_channels": 3,
        "num_register_tokens": 4,
        "hidden_act": "gelu",
        "attention_dropout": 0.0,
        "drop_path_rate": 0.0,
        "initializer_range": 0.02,
        "layer_norm_eps": 1e-05,
        "layerscale_value": 1.0,
        "key_bias": False,
        "mlp_bias": True,
        "proj_bias": True,
        "query_bias": True,
        "value_bias": True,
        "use_gated_mlp": False,
        "rope_theta": 100.0,
        "pos_embed_rescale": 2.0,
    },
    "dinov3-large-lvd": {
        "model_type": "dinov3_vit",
        "hidden_size": 1024,
        "num_hidden_layers": 24,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "patch_size": 16,
        "image_size": 224,
        "num_channels": 3,
        "num_register_tokens": 4,
        "hidden_act": "gelu",
        "attention_dropout": 0.0,
        "drop_path_rate": 0.0,
        "initializer_range": 0.02,
        "layer_norm_eps": 1e-05,
        "layerscale_value": 1.0,
        "key_bias": False,
        "mlp_bias": True,
        "proj_bias": True,
        "query_bias": True,
        "value_bias": True,
        "use_gated_mlp": False,
        "rope_theta": 100.0,
        "pos_embed_rescale": 2.0,
    },
}
|
|
def load_backbone_model(backbone, skip_weights=False):
    """Load (or create) the backbone vision model.

    When *skip_weights* is ``True`` the model structure is instantiated
    from a built-in config dict **without** any network access. This is
    useful when the caller will overwrite all parameters later (e.g. via
    ``Sat3DGen.from_pretrained``), avoiding a redundant multi-GB
    download of the backbone checkpoint.
    """
    if skip_weights:
        # Offline path: build the architecture only; parameters stay
        # uninitialised (no_init_weights) and frozen.
        if backbone not in _BACKBONE_CONFIGS:
            raise NotImplementedError(f"No built-in config for backbone: {backbone}")
        print(f"Creating backbone structure from built-in config (skip weights): {backbone}")
        backbone_config = AutoConfig.for_model(**_BACKBONE_CONFIGS[backbone])
        with no_init_weights():
            skeleton = AutoModel.from_config(backbone_config)
        return skeleton.eval().requires_grad_(False)

    # Online/local path: try each candidate in priority order, remembering
    # why each one failed so the final error message is actionable.
    failures = []
    for raw_candidate in resolve_backbone_candidates(backbone):
        expanded = os.path.expanduser(raw_candidate)
        # Prefer the expanded local path when it exists on disk; otherwise
        # fall back to the raw candidate (typically a hub model id).
        name_or_path = expanded if Path(expanded).exists() else raw_candidate
        try:
            print("Trying pretrained_model_name_or_path:", name_or_path)
            return AutoModel.from_pretrained(name_or_path).eval().requires_grad_(False)
        except Exception as exc:
            failures.append(f"{name_or_path}: {exc}")

    joined_failures = "\n".join(failures)
    raise RuntimeError(
        f"Failed to load the backbone `{backbone}`.\n"
        f"Tried the following candidates:\n{joined_failures}\n"
        "You can override the lookup with the corresponding SAT3DGEN_*_PATH environment variable."
    )
|
|
class MappingNetwork(torch.nn.Module):
    """MLP that maps a latent code ``z`` to an intermediate style vector ``w``.

    ``num_layers`` Linear layers (z_dim -> ... -> z_dim -> w_dim), each —
    including the last — followed by LeakyReLU(0.2). When ``norm`` is True the
    input is second-moment normalized (in float32) before the MLP.
    """

    def __init__(self,
                 z_dim,
                 w_dim,
                 num_layers = 8,
                 norm = True,
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim
        self.num_layers = num_layers
        self.norm = norm

        widths = [z_dim] * (num_layers) + [w_dim]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend([nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)])
        self.mapping = nn.Sequential(*modules)

    def forward(self, z):
        # Optional input normalization (cast to fp32 for a stable moment).
        if self.norm:
            z = normalize_2nd_moment(z.to(torch.float32))
        return self.mapping(z)
|
|
class dino_3d_model(nn.Module):
    """Frozen DINO ViT backbone followed by a trainable convolutional Decoder.

    Selected hidden states of the backbone (and, optionally, its CLS token
    broadcast over the spatial grid) are concatenated channel-wise and decoded
    into a 2D feature map with ``output_ch`` channels; downstream code reshapes
    that map into plane features.
    """
    def __init__(self,output_ch=192,ch_mult=[1,2,4,4,4],pad = False, with_attn=True,backbone='dinov2-base',no_hidden_states=False, no_cls_token=False, skip_backbone_weights=False):
        # NOTE(review): ch_mult is a mutable default argument; it is only read
        # here, so sharing is harmless, but a tuple would be safer.
        super().__init__()
        # The backbone is loaded frozen (eval() + requires_grad_(False)).
        self.dino_model = load_backbone_model(backbone, skip_weights=skip_backbone_weights)
        if backbone == 'dinov2-base':
            # 4 hidden states x 768 ch = 3072, doubled to 6144 when the CLS
            # token is also broadcast and concatenated.
            z_channels = 6144 if not no_cls_token else 6144//2
            # Indices into output.hidden_states (entry 0 is the embedding
            # output, so these pick layers 3/6/9/12).
            self.feature_list = [3,6,9,12] if not no_hidden_states else []
            if self.feature_list == []:
                # Only last_hidden_state is used -> a quarter of the channels.
                z_channels = z_channels//4
        elif backbone in ['dinov2-large',"dinov3-large-sat","dinov3-large-lvd"]:
            # 4 hidden states x 1024 ch (x2 with the CLS token).
            z_channels = 8192 if not no_cls_token else 8192//2
            self.feature_list = [6,12,18,24] if not no_hidden_states else []
            if self.feature_list == []:
                z_channels = z_channels//4
        self.backbone = backbone
        self.no_cls_token = no_cls_token
        self.decoder = Decoder(ch=128,out_ch=output_ch,ch_mult=ch_mult,num_res_blocks=2,attn_resolutions=[],z_channels=z_channels,resolution=256,in_channels=None,with_attn=with_attn)
        # Symmetric zero-padding factor applied to the token grid before
        # decoding; False disables padding.
        self.pad = pad
        self.patch_size = self.dino_model.config.to_dict()['patch_size']
        # DINOv3 configs declare register tokens (inserted after CLS);
        # DINOv2 configs do not, hence the defensive key check.
        self.num_register_tokens = self.dino_model.config.to_dict()['num_register_tokens'] if 'num_register_tokens' in self.dino_model.config.to_dict().keys() else 0

    def forward(self, inputs):
        """Encode ``inputs`` (B, 3, H, W with H == 16 * patch_size) into a
        (B, output_ch, h', w') feature map."""
        _h,_w = inputs.shape[-2:]
        assert _h == 16 * self.patch_size
        output = self.dino_model(inputs,output_hidden_states=True)
        out_put_list = []
        if self.feature_list == []:
            out_put_list.append(output.last_hidden_state)
        else:
            for i in self.feature_list:
                out_put_list.append(output.hidden_states[i])

        # Concatenate the selected layers along the embedding dim:
        # (B, tokens, C_total).
        x = torch.cat(out_put_list,dim=2)
        # Drop the CLS (+ register) tokens and fold patch tokens back into a
        # spatial grid.
        dino_feature = rearrange(x[:,1+self.num_register_tokens:], 'b (h w) c -> b c h w', h=_h//self.patch_size , w=_w//self.patch_size)
        if not self.no_cls_token:
            # Broadcast the CLS token over the grid and append it channel-wise.
            cls_token = x[:,0]
            dino_feature = torch.cat([dino_feature, cls_token.unsqueeze(-1).unsqueeze(-1).repeat(1,1,_h//self.patch_size,_w//self.patch_size)],dim=1)

        if self.pad:
            # pad is interpreted as a fraction of the current spatial size;
            # it must yield an integer pixel count.
            ori_size = dino_feature.size(-1)
            pad_size = ori_size*self.pad

            assert pad_size == int(pad_size), 'pad_size should be int'
            pad_size = int(pad_size)
            dino_feature = F.pad(dino_feature,(pad_size,pad_size,pad_size,pad_size),'constant', 0)
        output = self.decoder(dino_feature)
        return output
|
|
|
|
|
|
|
|
def convert_to_easydict(d):
    """Recursively wrap plain dicts in EasyDict; non-dict values pass through."""
    if not isinstance(d, dict):
        return d
    return edict({key: convert_to_easydict(value) for key, value in d.items()})
|
|
class Sat3DGen(ModelMixin, ConfigMixin):
    """Satellite-image-to-3D generator.

    Encodes a satellite image into a (tri/one-)plane representation via a
    frozen DINO backbone + decoder, then volume-renders satellite, panorama
    and perspective views, optionally compositing a decoded sky behind the
    scene and super-resolving the panorama/perspective outputs.
    """

    # When True, the DINO backbone is built structurally without downloading
    # pretrained weights; set by callers that will overwrite all parameters
    # afterwards (e.g. from_pretrained-style loading).
    _skip_backbone_weights: bool = False

    @register_to_config
    def __init__(self, opt):
        """Build the generator from a nested option dict/EasyDict *opt*."""
        super().__init__()
        # NOTE(review): immediately overwritten below; presumably kept so the
        # raw dict is captured by @register_to_config — confirm.
        self.opt = opt

        # Re-wrap for attribute-style access and backfill legacy defaults.
        self.opt = convert_to_easydict(opt)
        if 'sr_padding_mode' not in self.opt.keys():
            self.opt.sr_padding_mode = 'zeros'
        if 'representation_type' not in self.opt.keys():
            self.opt.representation_type = 'triplane'
        self.sat_mapping_mode = 'v2' if not hasattr(self.opt.network, 'sat_mapping_mode') else self.opt.network.sat_mapping_mode
        # NOTE(review): the message mentions v1 but only 'v2' is accepted here.
        assert self.sat_mapping_mode in ['v2'], 'sat_mapping_mode should be v1 or v2'
        self.sr_factor = 1 if not hasattr(self.opt.network, 'sr_factor') else self.opt.network.sr_factor
        # Sky style always goes through a mapping network in this version.
        self.if_w_sky_mapping = True
        self.backbone = 'dinov2-base' if not hasattr(self.opt, 'backbone') else self.opt.backbone
        if self.if_w_sky_mapping:
            self.z_dim = 270
            self.w_dim = 512
            self.sky_mapping = MappingNetwork(self.z_dim,self.w_dim,norm=False)
        else:
            self.z_dim = 270
            self.w_dim = 270
        assert self.sr_factor in [1,2] , 'sr_factor should be 1 or 2'
        self.image_size = self.opt.network.image_size
        self.latent_size = self.opt.network.latent_size
        self.latent_channel = self.opt.network.latent_channel
        if 'pad' in self.opt.keys():
            # Padded mode: the scene occupies the central 1/(2*pad+1) of the
            # cube, and the configured scale factor must stay at 1.
            self.pad = self.opt.pad
            self.position_scale_factor = 1 / (self.pad*2+1)
            assert self.opt.network.position_scale_factor ==1, 'position_scale_factor should be 1,not used in this version.'
        else:
            self.position_scale_factor = self.opt.network.position_scale_factor
            self.pad = False

        color_channels = 32 if not hasattr(self.opt.network, 'color_channels') else self.opt.network.color_channels
        self.sr_module = SuperresolutionHybrid2X(color_channels, 3,padding_mode=self.opt.sr_padding_mode,v2=True)
        # Output channels of the plane encoder: 3 planes for 'triplane',
        # 1 plane + 1 line (2x dim) for the one-plane variants.
        if self.opt.representation_type == 'triplane':
            output_ch = self.opt.network.triplane.dim*3
        elif self.opt.representation_type in ['oneplane','oneplane_multi']:
            output_ch = self.opt.network.triplane.dim*2
        self.with_sky = True
        self.sky_input_dim = 2

        if self.with_sky:
            # Decodes the w_dim sky style vector into a 2D sky feature map.
            self.sky_decoder = Decoder(ch=32,out_ch=color_channels,ch_mult=[1,2,2,4,4,4,4],num_res_blocks=2,attn_resolutions=[],z_channels=self.w_dim ,resolution=256,in_channels=None,with_attn=False,pano_pad=True)
        self.unet_model = dino_3d_model(output_ch = output_ch,
                                        ch_mult = self.opt.network.triplane.ch_mult if hasattr(self.opt.network.triplane, 'ch_mult') else [1,2,4,4,4],
                                        pad = self.pad,
                                        with_attn = self.opt.network.with_attn if hasattr(self.opt.network, 'with_attn') else True,
                                        backbone = self.backbone,
                                        no_hidden_states=self.opt.network.no_hidden_states if hasattr(self.opt.network, 'no_hidden_states') else False,
                                        no_cls_token=self.opt.network.no_cls_token if hasattr(self.opt.network, 'no_cls_token') else False,
                                        skip_backbone_weights=self._skip_backbone_weights,
                                        )

        self.num_importance = self.opt.network.point_sampling_kwargs.num_importance
        # num_importance is consumed here; pop it so the remaining kwargs can
        # be forwarded verbatim to the point samplers (mutates self.opt).
        self.opt.network.point_sampling_kwargs.pop('num_importance')
        if self.opt.representation_type == 'oneplane':
            input_dim_mlp = self.opt.network.triplane.dim*2
        elif self.opt.representation_type in ['triplane','oneplane_multi']:
            input_dim_mlp = self.opt.network.triplane.dim
        self.mlp = MLPNetwork2(input_dim=input_dim_mlp,
                               hidden_dim=64,
                               output_dim=color_channels,
                               style_dim=self.w_dim,
                               )

        self.point_representer = PointRepresenter(
            representation_type=self.opt.representation_type,
            triplane_axes=None,
            coordinate_scale=None,
        )
        self.point_integrator = PointIntegrator(**self.opt.network.ray_marching_kwargs)
        # Strip legacy sampling kwargs that the current samplers don't accept.
        unused_parameter = ['max_height','origin_height','realworld_scale']
        for i in unused_parameter:
            if i in self.opt.network.point_sampling_kwargs.keys():
                self.opt.network.point_sampling_kwargs.pop(i)
        if self.sr_factor ==2:
            self.opt.render_size = 256
        self.point_sampler_definition(self.opt.render_size if hasattr(self.opt, 'render_size') else 256)
|
|
| def point_sampler_definition(self, render_size=256): |
| pano_size = np.array([render_size*2,render_size//2]) / self.sr_factor |
| pano_dir = get_original_coord(W=int(pano_size[0]),H=int(pano_size[1]),full=True).unsqueeze(0).float() |
| |
| |
| if torch.cuda.is_available(): |
| pano_dir = pano_dir.cuda() |
| self.pano_direction = pano_dir |
| |
| self.point_sampler = Point_sampler_pano(pano_direction=self.pano_direction,**self.opt.network.point_sampling_kwargs) |
| self.point_sampler_per = PointSamplerPerspective(num_points=self.opt.network.point_sampling_kwargs.num_points,aabb_strict=True,render_size=[render_size// self.sr_factor,render_size// self.sr_factor]) |
| if render_size==256 and self.sr_factor == 2: |
| self.point_sampler_sat = Point_sampler_ortho(num_points=self.opt.network.point_sampling_kwargs.num_points,position_scale_factor=self.position_scale_factor,render_size=render_size// self.sr_factor) |
| else: |
| self.point_sampler_sat = Point_sampler_ortho(num_points=self.opt.network.point_sampling_kwargs.num_points,position_scale_factor=self.position_scale_factor,resolution=int(render_size*1.5),render_size=render_size) |
| print('render size:', render_size, 'sr_factor:', self.sr_factor) |
| |
|
|
| def from_sat_to_triplane(self,x): |
| planes_feature = self.unet_model(x) |
| if self.opt.representation_type == 'triplane': |
| triplane_ori = rearrange(planes_feature, 'b (n c) h w -> b n c h w',n=3) |
| elif self.opt.representation_type in ['oneplane','oneplane_multi']: |
| one_plane_ori = planes_feature[:,:self.opt.network.triplane.dim] |
| one_plane_ori = rearrange(one_plane_ori, 'b (n c) h w -> b n c h w',n=1) |
| one_line_ori = planes_feature[:,self.opt.network.triplane.dim:] |
| one_line_ori = torch.mean(one_line_ori, dim=2, keepdim=False) |
| triplane_ori = [one_plane_ori,one_line_ori] |
| return triplane_ori |
| |
| def c2w_prepare(self, c2w): |
| if c2w is not None: |
| c2w[:,:3, 3] = c2w[:,:3, 3] * self.position_scale_factor |
| return c2w |
|
|
| def w_sky_prepare(self, z_ill): |
| if z_ill is not None: |
| if self.if_w_sky_mapping: |
| w_sky = self.sky_mapping(z_ill) |
| else: |
| w_sky = z_ill |
| else: |
| w_sky = None |
| return w_sky |
| |
| def w_sky2sky_feature_2D(self, w_sky, z_ill=None): |
| sky_feature_2D = None |
| if self.with_sky and z_ill is not None: |
| sky_feature_2D = repeat(w_sky, 'b c -> b c h w', h=8, w=8) |
| sky_feature_2D = self.sky_decoder(sky_feature_2D) |
| sky_feature_2D = torch.sigmoid(sky_feature_2D) |
| |
| b,c,h,w = sky_feature_2D.shape |
| zero_pad_sky = torch.zeros((b,c,h,int(w*0.8)),device=sky_feature_2D.device) |
| sky_feature_2D = torch.cat([sky_feature_2D,zero_pad_sky],dim=3) |
| return sky_feature_2D |
| |
    def from_3D_to_results(self,
                           triplane_ori,
                           c2w=None,
                           w_sky=None,
                           sky_feature_2D=None,
                           syn_sat=False,
                           random_sat_crop=True,
                           syn_pano=True,
                           syn_per=False,
                           same_histo=False,
                           intrinsics=None,
                           coordinates=None):
        """Volume-render the plane representation into the requested views.

        Produces ``results.sat_output`` / ``results.str_output`` (panorama) /
        ``results.per_output`` depending on the ``syn_*`` flags. During
        training with more than one view, all views are concatenated along
        the batch dimension and rendered in a single pass, then split back.
        """
        results = edict()
        point_sampling_result = []  # one sampling result per requested view
        w_list = []                 # matching style vector per view
        syn_sign = []               # view tags: 'sat' / 'pano' / 'pespective'
        if type(triplane_ori) is list:
            N = triplane_ori[0].shape[0]
        else:
            N = triplane_ori.shape[0]

        if syn_sat:
            # Orthographic top-down (satellite) rays.
            point_sampling_result_sat = self.point_sampler_sat(batch_size=N,random_crop=random_sat_crop,crop_type='crop')
            point_sampling_result.append(point_sampling_result_sat)
            if not same_histo:
                # The satellite view uses a neutral (zero) style unless the
                # caller asks it to share the panorama's sky style.
                w_sat = torch.zeros([N,self.w_dim], device=triplane_ori.device if type(triplane_ori) is not list else triplane_ori[0].device)
            else:
                w_sat = w_sky

            w_list.append(w_sat)
            syn_sign.append('sat')

        if syn_pano:
            resize_for_pano = False
            point_sampling_result_pano = self.point_sampler(batch_size=N,position=c2w[:,:3, 3])
            if self.training:
                if point_sampling_result_pano.rays_world.size(1) != point_sampling_result_pano.rays_world.size(2):
                    # Fold the 2:1 panorama into a square (h*2, w/2) layout so
                    # it can be batch-concatenated with the square views; the
                    # inverse rearrange is applied to the outputs below.
                    resize_for_pano = True

                    point_sampling_result_pano.rays_world = rearrange(point_sampling_result_pano.rays_world, 'b h (w d) c -> b (h d) w c', d=2)
                    point_sampling_result_pano.ray_origins = rearrange(point_sampling_result_pano.ray_origins, 'b h (w d) c -> b (h d) w c', d=2)
                    point_sampling_result_pano.points_world = rearrange(point_sampling_result_pano.points_world, 'b h (w d) n c -> b (h d) w n c', d=2)
                    point_sampling_result_pano.radii = rearrange(point_sampling_result_pano.radii, 'b h (w d) c -> b (h d) w c', d=2)
            point_sampling_result.append(point_sampling_result_pano)
            w_list.append(w_sky)
            syn_sign.append('pano')
        if syn_per:
            point_sampling_result_per = self.point_sampler_per(intrinsics=intrinsics, c2w=c2w)
            point_sampling_result.append(point_sampling_result_per)
            w_list.append(w_sky)
            syn_sign.append('pespective')

        if self.training and len(point_sampling_result) >1:
            # Training fast path: concatenate every requested view along the
            # batch dim and render them all in one pass.
            point_sampling_result_cat = edict()
            point_sampling_result_cat.rays_world = torch.cat([i.rays_world for i in point_sampling_result],dim=0)
            point_sampling_result_cat.ray_origins = torch.cat([i.ray_origins for i in point_sampling_result],dim=0)
            point_sampling_result_cat.points_world = torch.cat([i.points_world for i in point_sampling_result],dim=0)
            point_sampling_result_cat.radii = torch.cat([i.radii for i in point_sampling_result],dim=0)

            w_input = torch.cat(w_list,dim=0)
            # The plane features are shared by every view, so tile them once
            # per requested view.
            if self.opt.representation_type == 'triplane':
                feature_input = triplane_ori.repeat(len(point_sampling_result),1,1,1,1)

            elif self.opt.representation_type in ['oneplane','oneplane_multi']:
                feature_input = [triplane_ori[0].repeat(len(point_sampling_result),1,1,1,1),triplane_ori[1].repeat(len(point_sampling_result),1,1)]
            output = self.from_point_sampling2result(point_sampling_result_cat,
                                                     feature_input,
                                                     w_sky=w_input,
                                                     )
        else:
            # Render each view with its own pass.
            for i in range(len(point_sampling_result)):
                if syn_sign[i] == 'sat':
                    results.sat_output = self.from_point_sampling2result(point_sampling_result[i],
                                                                         triplane_ori,
                                                                         w_sky=w_list[i],
                                                                         )
                elif syn_sign[i] == 'pano':
                    results.str_output = self.from_point_sampling2result(point_sampling_result[i],
                                                                         triplane_ori,
                                                                         w_sky=w_list[i],
                                                                         )
                elif syn_sign[i] == 'pespective':
                    results.per_output = self.from_point_sampling2result(point_sampling_result[i],
                                                                         triplane_ori,
                                                                         w_sky=w_list[i],
                                                                         )
        if 'sat' in syn_sign:
            if self.training and len(point_sampling_result) >1:
                # Split the concatenated render: the first N entries are
                # always the satellite view.
                results.sat_output = edict()
                results.sat_output.feature_raw = output.feature_raw[:N]
                results.sat_output.alpha_raw = output.alpha_raw[:N]
                results.sat_output.image_depth = output.image_depth[:N]
                results.sat_output.image_radii = output.image_radii[:N]
            if 'idx' in point_sampling_result_sat.keys():
                results.sat_output.idx = point_sampling_result_sat['idx']

        if 'pano' in syn_sign:
            if self.training and len(point_sampling_result) >1:
                # Panorama occupies the second batch slice.
                results.str_output = edict()
                results.str_output.feature_raw = output.feature_raw[N:2*N]
                results.str_output.alpha_raw = output.alpha_raw[N:2*N]
                results.str_output.image_depth = output.image_depth[N:2*N]
                results.str_output.image_radii = output.image_radii[N:2*N]

            if resize_for_pano:
                # Unfold the square layout back to the 2:1 panorama.
                results.str_output.feature_raw = rearrange(results.str_output.feature_raw, 'b c (h d) w -> b c h (w d)', d=2)
                results.str_output.alpha_raw = rearrange(results.str_output.alpha_raw, 'b c (h d) w -> b c h (w d)', d=2)
                results.str_output.image_depth = rearrange(results.str_output.image_depth, 'b c (h d) w -> b c h (w d)', d=2)
                results.str_output.image_radii = rearrange(results.str_output.image_radii, 'b c (h d) w -> b c h (w d)', d=2)
            if 'idx' in point_sampling_result_pano.keys():
                results.str_output.idx = point_sampling_result_pano['idx']

            results.str_output.ray_direction = point_sampling_result_pano.rays_world

            if self.with_sky:
                # Composite the decoded sky behind the scene, using the
                # rendered opacity as the blend weight.
                ray_direction = xyz2thetaphi(results.str_output.ray_direction)
                sky_img = F.grid_sample(sky_feature_2D, ray_direction,align_corners=True)
                sky_img = torch.clamp(sky_img, 0, 1)
                if resize_for_pano:
                    sky_img = rearrange(sky_img, 'b c (h d) w -> b c h (w d)', d=2)
                rgb_feature_compo = results.str_output.feature_raw * results.str_output.alpha_raw + sky_img * (1 - results.str_output.alpha_raw)
                results.str_output.sky_img = sky_img
                results.str_output.image_raw_compo = rgb_feature_compo
                if self.sr_factor == 2:
                    results.str_output.sr_image = self.sr_module(rgb_feature_compo)

        if 'pespective' in syn_sign:
            if self.training and len(point_sampling_result) >1:
                # Perspective is always the last batch slice.
                results.per_output = edict()
                results.per_output.feature_raw = output.feature_raw[-N:]
                results.per_output.alpha_raw = output.alpha_raw[-N:]
                results.per_output.image_depth = output.image_depth[-N:]
                results.per_output.image_radii = output.image_radii[-N:]

            if 'idx' in point_sampling_result_per.keys():
                results.per_output.idx = point_sampling_result_per['idx']

            results.per_output.ray_direction = point_sampling_result_per.rays_world

            if self.with_sky:
                # Same sky compositing as the panorama branch (no pano resize).
                ray_direction = xyz2thetaphi(results.per_output.ray_direction)
                sky_img = F.grid_sample(sky_feature_2D, ray_direction,align_corners=True)
                sky_img = torch.clamp(sky_img, 0, 1)
                rgb_feature_compo = results.per_output.feature_raw * results.per_output.alpha_raw + sky_img * (1 - results.per_output.alpha_raw)
                results.per_output.sky_img = sky_img
                results.per_output.image_raw_compo = rgb_feature_compo
                if self.sr_factor == 2:
                    results.per_output.sr_image = self.sr_module(rgb_feature_compo)

        if coordinates is not None:
            # Optional density samples at caller-provided points (used for
            # regularization losses).
            results.density = self.density_reg(coordinates,triplane_ori)
        return results
| |
| def density_reg(self,coordinates,triplane_ori,sample_color=False,w_sky=None): |
| |
| assert coordinates is not None |
| sample_result = self.sample_mixed(coordinates, |
| triplane_ori, |
| sample_color=sample_color, |
| w_sky=w_sky, |
| ) |
| sample_density = sample_result['density'] |
| color_result = sample_result['color'][...,:3] if sample_color==True else None |
| if self.opt.network.ray_marching_kwargs.density_clamp_mode == 'mipnerf': |
| sample_density = F.softplus(sample_density - 1) |
| elif self.opt.network.ray_marching_kwargs.density_clamp_mode == 'relu': |
| sample_density = F.relu(sample_density + 3) |
| else: |
| raise NotImplementedError |
| if sample_color: |
| return color_result |
| return sample_density |
|
|
|
|
| def forward(self, |
| x, |
| z_ill=None, |
| syn_sat=False, |
| random_sat_crop=True, |
| syn_pano=True, |
| syn_per=False, |
| same_histo=False, |
| intrinsics=None, |
| c2w=None, |
| coordinates=None, |
| ): |
| c2w = self.c2w_prepare(c2w) |
|
|
| triplane_ori = self.from_sat_to_triplane(x) |
|
|
| w_sky = self.w_sky_prepare(z_ill) |
|
|
| sky_feature_2D = self.w_sky2sky_feature_2D(w_sky,z_ill) |
|
|
|
|
| results = self.from_3D_to_results(triplane_ori, |
| c2w, |
| w_sky, |
| sky_feature_2D, |
| syn_sat=syn_sat, |
| random_sat_crop=random_sat_crop, |
| syn_pano=syn_pano, |
| syn_per=syn_per, |
| same_histo=same_histo, |
| intrinsics=intrinsics, |
| coordinates=coordinates) |
| results.triplane = triplane_ori |
|
|
| return results |
| |
| |
|
|
|
|
|
|
|
|
| def sample_mixed(self, |
| coordinates, |
| triplanes, |
| sample_color=False, |
| w_sky=None |
| ): |
|
|
|
|
|
|
| point_features = self.point_representer( |
| coordinates, ref_representation=triplanes) |
| color_density_result = self.mlp(point_features,only_density=not sample_color,style=w_sky) |
|
|
| return color_density_result |
|
|
|
|
| |
    def from_point_sampling2result(self,
                                   point_sampling_result,
                                   triplanes,
                                   w_sky=None,
                                   **synthesis_kwarg
                                   ):
        """Volume-render one batch of sampled rays against *triplanes*.

        Runs a coarse MLP pass over the sampled points, optionally a fine
        (importance-sampled) pass when ``self.num_importance > 0``, then
        integrates along each ray and reshapes the composites back to image
        layout. Returns an edict with feature/alpha/depth/radii maps.
        """
        points = point_sampling_result['points_world']
        ray_dirs = point_sampling_result['rays_world']
        radii_coarse = point_sampling_result['radii']
        ray_origins = point_sampling_result['ray_origins']

        # Flatten the image grid: R rays, K coarse samples per ray.
        _, H, W, K, _ = points.shape
        R = H * W
        points_coarse = rearrange(points, 'n h w k c -> n (h w) k c')
        points = rearrange(points, 'n h w k c -> n (h w k) c')
        ray_dirs = rearrange(ray_dirs, 'n h w c -> n (h w) c')
        if len(ray_origins.shape) == 4:
            ray_origins = rearrange(ray_origins, 'n h w c -> n (h w) c')
        elif len(ray_origins.shape) == 2:
            # A single origin per batch element: broadcast it to every ray.
            ray_origins = repeat(ray_origins, 'n c -> n (h w) c', h=R, w=1)
        radii_coarse = rearrange(radii_coarse, 'n h w k -> n (h w) k 1')

        # Coarse pass: query representation + MLP at every sampled point.
        point_features = self.point_representer(
            points, ref_representation=triplanes)
        color_density_result = self.mlp(point_features,w_sky)

        densities_coarse = color_density_result['density']
        colors_coarse = color_density_result['color']
        densities_coarse = rearrange(densities_coarse, 'n (r k) c -> n r k c', r=R, k=K)
        colors_coarse = rearrange(colors_coarse, 'n (r k) c -> n r k c', r=R, k=K)

        if self.num_importance > 0:
            # Integrate the coarse pass to obtain per-sample weights for
            # importance sampling.
            rendering_result = self.point_integrator(colors_coarse,
                                                     densities_coarse,
                                                     radii_coarse)
            weights = rendering_result['weight']

            radii_fine = sample_importance(radii_coarse,
                                           weights,
                                           self.num_importance,
                                           smooth_weights=True)
            # Re-project the fine radial distances back to world points.
            points = ray_origins.unsqueeze(
                -2) + radii_fine * ray_dirs.unsqueeze(
                    -2)
            points_fine = points
            points = rearrange(points, 'n r k c -> n (r k) c')

            # Fine pass at the importance-sampled locations.
            point_features = self.point_representer(
                points, ref_representation=triplanes)
            color_density_result = self.mlp(point_features,w_sky)

            densities_fine = color_density_result['density']
            colors_fine = color_density_result['color']
            densities_fine = rearrange(densities_fine, 'n (r k) c -> n r k c', r=R, k=self.num_importance)
            colors_fine = rearrange(colors_fine, 'n (r k) c -> n r k c', r=R, k=self.num_importance)

            # Combine coarse and fine samples (see unify_attributes) and
            # integrate over the union.
            (all_radiis, all_colors, all_densities,
             all_points) = unify_attributes(radii_coarse,
                                            colors_coarse,
                                            densities_coarse,
                                            radii_fine,
                                            colors_fine,
                                            densities_fine,
                                            points1=points_coarse,
                                            points2=points_fine)

            rendering_result = self.point_integrator(all_colors,
                                                     all_densities,
                                                     all_radiis)

        else:
            # No hierarchical sampling: integrate the coarse pass directly.
            rendering_result = self.point_integrator(colors_coarse,
                                                     densities_coarse,
                                                     radii_coarse)

        feature_samples = rendering_result['composite_color']
        radii_samples = rendering_result['composite_radial_dist']

        # Back to image layout: (n, c, H, W).
        feature_image = rearrange(feature_samples, 'n (h w) c -> n c h w', h=H, w=W).contiguous()
        image_radii = rearrange(radii_samples, 'n (h w) c -> n c h w', h=H, w=W).contiguous()

        image_alpha = rearrange(rendering_result['opacity'], 'n (h w) c -> n c h w', h=H, w=W).contiguous()
        image_depth = rearrange(rendering_result['composite_radial_dist'], 'n (h w) c -> n c h w', h=H, w=W).contiguous()

        result = edict()
        result.feature_raw = feature_image
        result.alpha_raw = image_alpha
        # NOTE(review): image_depth and image_radii both derive from
        # 'composite_radial_dist', so they carry the same values — confirm
        # whether a distinct radii composite was intended.
        result.image_depth = image_depth
        result.image_radii = image_radii
        result.ray_origin = ray_origins
        if 'idx' in point_sampling_result.keys():
            result.idx = point_sampling_result['idx']
        return result
|
|
    @torch.no_grad()
    def forward_grid(self, planes, grid_size=256,position_scale_factor=1,crop=False):
        """Evaluate the clamped density field on a regular 3D grid.

        Returns a (grid_size, grid_size, grid_size) numpy volume (assumes a
        batch size of 1 — see the reshape below).

        NOTE(review): the ``position_scale_factor`` argument is never used —
        create_voxel is called with position_scale_factor=1 and border
        handling reads ``self.position_scale_factor`` instead; confirm
        whether the parameter should be wired through.
        """
        # Chunk size for MLP evaluation; large enough that typical grids
        # (<= 320^3 ≈ 3.3e7 points) take only a few chunks.
        max_batch = 15000000

        device = planes[0].device if isinstance(planes, (list, tuple)) else planes.device
        voxel_grid = create_voxel(N=grid_size,position_scale_factor=1)['voxel_grid'].to(device)
        # One scalar density per voxel.
        densities = torch.zeros(
            (voxel_grid.shape[0], voxel_grid.shape[1], 1)).to(device)

        head = 0
        with tqdm.tqdm(total=voxel_grid.shape[1]) as pbar:
            with torch.no_grad():
                while head < voxel_grid.shape[1]:
                    density = self.density_reg(coordinates=voxel_grid[:, head:head + max_batch],triplane_ori=planes)

                    densities[:, head:head + max_batch] = density
                    head = head + max_batch
                    pbar.update(max_batch)

        densities = densities.reshape(
            (grid_size, grid_size, grid_size)).cpu().numpy()

        # When the scene occupies only the central portion of the cube,
        # either zero out or crop away the border band.
        if self.position_scale_factor < 1:
            pad = int(np.round(((1-self.position_scale_factor)*densities.shape[0]/2)))
            if not crop:
                # Only the low side of the last axis is zeroed — presumably
                # the vertical axis where the top is kept; TODO confirm.
                pad_value = 0
                densities[:pad] = pad_value
                densities[-pad:] = pad_value
                densities[:, :pad] = pad_value
                densities[:, -pad:] = pad_value
                densities[:, :, :pad] = pad_value
            else:
                densities = densities[pad:-pad, pad:-pad, pad:]
        return densities
|
|
|
|
| @torch.no_grad() |
| def save_shape_from_sat(self, sat_img, position_scale_factor=1,crop=False,grid_size=320): |
| planes = self.from_sat_to_triplane(sat_img) |
|
|
| return self.forward_grid(planes,position_scale_factor=1,crop=crop,grid_size=grid_size) |
|
|
    @torch.no_grad()
    def save_shape(self, planes,position_scale_factor=1,save_type='density',crop=False):
        """Extract the density grid and export it to disk in one of several
        formats ('density' -> .mrc, 'mesh' -> mesh.ply, 'pointcloud' /
        'voxel' -> .ply via open3d). Files are written to the CWD.

        Returns 0 on completion.
        """
        densities = self.forward_grid(planes,position_scale_factor=position_scale_factor)

        if save_type == 'density':
            try:
                import mrcfile
            except ImportError:
                raise ImportError("mrcfile is required for density export. Install via: pip install mrcfile")
            # mrc_mode=2 stores 32-bit floats.
            with mrcfile.new_mmap(f'0000.mrc',
                                  overwrite=True,
                                  shape=densities.shape,
                                  mrc_mode=2) as mrc:
                mrc.data[:] = densities
            print('save density done')

        try:
            import open3d as o3d
        except ImportError:
            raise ImportError("open3d is required for 3D shape export. Install via: pip install open3d")
        if save_type == 'mesh':
            from skimage import measure
            import trimesh

            # Iso-surface at density level 4.5.
            verts, faces, _, _ = measure.marching_cubes(densities, level=4.5)

            mesh = trimesh.Trimesh(vertices=verts, faces=faces)

            # Accessing vertex_normals forces trimesh to compute them before
            # export.
            mesh.vertex_normals

            mesh.export('mesh.ply')

        if save_type in ['pointcloud','voxel']:

            def efficient_filter_numpy(densities, threshold=5):
                # Zero out interior voxels: a voxel whose 6 axis neighbours
                # are all >= threshold is not on the surface. Mutates and
                # returns *densities*.
                size = densities.shape[0]

                high_density = np.where(densities >= threshold, 1, 0)

                x_sum = high_density[:-2, 1:-1, 1:-1] +high_density[1:-1, 1:-1, 1:-1] + high_density[2:, 1:-1, 1:-1]
                y_sum = high_density[1:-1, :-2, 1:-1] + high_density[1:-1, 1:-1, 1:-1] + high_density[1:-1, 2:, 1:-1]
                z_sum = high_density[1:-1, 1:-1, :-2] + high_density[1:-1, 1:-1, 1:-1] + high_density[1:-1, 1:-1, 2:]

                mask = (x_sum == 3) & (y_sum == 3) & (z_sum == 3)

                densities[1:-1, 1:-1, 1:-1][mask] = 0

                return densities

            print('the number of voxels >= 5 before filtering:', np.sum(densities >= 5))
            densities = efficient_filter_numpy(densities)

            print('the number of voxels >= 5 after filtering:', np.sum(densities >= 5))

            # NOTE(review): `size` is only defined inside
            # efficient_filter_numpy, so the use below looks like a latent
            # NameError (should presumably be densities.shape[0]) — confirm.
            points = np.array(np.where(densities >= 5)).T
            points = (points / size) *2 - 1

            point_cloud = o3d.geometry.PointCloud()
            point_cloud.points = o3d.utility.Vector3dVector(points)

            if save_type == 'pointcloud':

                o3d.io.write_point_cloud("point_cloud.ply", point_cloud)

            if save_type == 'voxel':
                voxel_size = (1 / size)* 2
                voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(point_cloud, voxel_size)
                o3d.io.write_voxel_grid("voxel_grid.ply", voxel_grid)

        return 0
| |
| |
    def extract_mesh(
        self,
        planes: torch.Tensor,
        mesh_resolution: int = 320,
        mesh_threshold: float = 5.0,
        w_sky = None,
        **kwargs,
    ):
        '''
        Extract a 3D mesh from triplane nerf. Only support batch_size 1.
        :param planes: triplane features
        :param mesh_resolution: marching cubes resolution
        :param mesh_threshold: iso-surface threshold
        :param w_sky: optional sky/illumination style forwarded to the color MLP
        :param kwargs: accepted for interface compatibility; currently unused
        :return: (vertices, faces, vertices_colors) with vertices in
            [-1, 1]^3 and per-vertex uint8 RGB colors
        '''
        print('mesh_resolution:', mesh_resolution)
        device = planes.device if type(planes) is not list else planes[0].device

        grid_out = self.forward_grid(
            planes=planes,
            grid_size=mesh_resolution,
        )
        try:
            import mcubes
        except ImportError:
            raise ImportError("PyMCubes is required for mesh extraction. Install via: pip install PyMCubes")
        vertices, faces = mcubes.marching_cubes(
            grid_out,
            mesh_threshold,
        )
        # Map voxel indices to normalized coordinates in [-1, 1].
        vertices = vertices / (mesh_resolution - 1) * 2 - 1

        # Query colors at the vertex positions (batch dim of 1 added/removed).
        vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0)
        vertices_colors = self.density_reg(vertices_tensor,planes,sample_color=True,w_sky=w_sky)
        vertices_colors = (vertices_colors * 255).squeeze(0).cpu().numpy().astype(np.uint8)
        return vertices, faces, vertices_colors
| |
| |
|
|
class EMANorm(nn.Module):
    """Normalize activations by an exponential moving average of E[x^2].

    The running estimate lives in a buffer (so it is saved with the model)
    and is updated only in training mode; the input is always multiplied by
    its inverse square root.
    """

    def __init__(self, beta):
        super().__init__()
        # Running mean-squared magnitude; starts at 1.0 (identity gain).
        self.register_buffer('magnitude_ema', torch.ones([]))
        self.beta = beta

    def forward(self, x):
        if self.training:
            current_sq = x.detach().to(torch.float32).square().mean()
            # new_ema = current + beta * (old_ema - current)
            self.magnitude_ema.copy_(current_sq.lerp(self.magnitude_ema, self.beta))
        gain = self.magnitude_ema.rsqrt()
        return x.mul(gain)
|
|
|
|
| |
| |
# Legacy alias — presumably kept so existing callers/configs that import the
# old name keep working; TODO confirm it is still referenced before removal.
VAE_finetune = Sat3DGen
|
|
| |
|
|