import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from transformers import AutoConfig, AutoModel
from transformers.modeling_utils import no_init_weights
import warnings
warnings.filterwarnings("ignore")
import os
from pathlib import Path
from einops import rearrange, repeat
from easydict import EasyDict as edict
from source.vae_hacked import Decoder
from source.rendering.utils import sample_importance, unify_attributes, create_voxel
from source.rendering.point_representer import PointRepresenter
from source.rendering.point_integrator import PointIntegrator
from source.rendering.sat2density_transform_eg3d import get_original_coord, Point_sampler_pano, Point_sampler_ortho
from source.rendering.transform_perspective import PointSamplerPerspective
from source.rendering.mlp_model import MLPNetwork2
from source.sr_module import SuperresolutionHybrid2X
from source.xyz2thetaphi import xyz2thetaphi
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
import tqdm
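# StyleGAN2-style second-moment normalization: rescales x so that the mean of
# x^2 along `dim` is 1, i.e. projects latent vectors onto a hypersphere.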
def normalize_2nd_moment(x, dim=1, eps=1e-8):
return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
def resolve_backbone_candidates(backbone):
env_override_map = {
"dinov2-base": "SAT3DGEN_DINOV2_BASE_PATH",
"dinov2-large": "SAT3DGEN_DINOV2_LARGE_PATH",
"dinov3-large-sat": "SAT3DGEN_DINOV3_SAT_PATH",
"dinov3-large-lvd": "SAT3DGEN_DINOV3_LVD_PATH",
}
default_candidate_map = {
"dinov2-base": [
"facebook/dinov2-base",
],
"dinov2-large": [
"facebook/dinov2-large",
],
"dinov3-large-sat": [
"facebook/dinov3-vitl16-pretrain-sat493m",
],
"dinov3-large-lvd": [
"facebook/dinov3-vitl16-pretrain-lvd1689m",
],
}
if backbone not in default_candidate_map:
raise NotImplementedError(f"Unsupported backbone: {backbone}")
candidates = []
env_override = os.environ.get(env_override_map[backbone])
if env_override:
candidates.append(env_override)
candidates.extend(default_candidate_map[backbone])
return candidates
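# Usage sketch (the local override path below is hypothetical):
#   os.environ["SAT3DGEN_DINOV2_BASE_PATH"] = "~/ckpts/dinov2-base"
#   resolve_backbone_candidates("dinov2-base")
#   # -> ["~/ckpts/dinov2-base", "facebook/dinov2-base"]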
# Built-in backbone configs so we can create the model structure without any
# network access (the gated DINOv3 repos require authentication even to fetch
# the config file). These are architecture-only settings and never change.
_BACKBONE_CONFIGS = {
"dinov2-base": {
"model_type": "dinov2",
"hidden_size": 768,
"num_hidden_layers": 12,
"num_attention_heads": 12,
"intermediate_size": 3072,
"patch_size": 14,
"image_size": 518,
"num_channels": 3,
"num_register_tokens": 0,
},
"dinov2-large": {
"model_type": "dinov2",
"hidden_size": 1024,
"num_hidden_layers": 24,
"num_attention_heads": 16,
"intermediate_size": 4096,
"patch_size": 14,
"image_size": 518,
"num_channels": 3,
"num_register_tokens": 0,
},
"dinov3-large-sat": {
"model_type": "dinov3_vit",
"hidden_size": 1024,
"num_hidden_layers": 24,
"num_attention_heads": 16,
"intermediate_size": 4096,
"patch_size": 16,
"image_size": 224,
"num_channels": 3,
"num_register_tokens": 4,
"hidden_act": "gelu",
"attention_dropout": 0.0,
"drop_path_rate": 0.0,
"initializer_range": 0.02,
"layer_norm_eps": 1e-05,
"layerscale_value": 1.0,
"key_bias": False,
"mlp_bias": True,
"proj_bias": True,
"query_bias": True,
"value_bias": True,
"use_gated_mlp": False,
"rope_theta": 100.0,
"pos_embed_rescale": 2.0,
},
"dinov3-large-lvd": {
"model_type": "dinov3_vit",
"hidden_size": 1024,
"num_hidden_layers": 24,
"num_attention_heads": 16,
"intermediate_size": 4096,
"patch_size": 16,
"image_size": 224,
"num_channels": 3,
"num_register_tokens": 4,
"hidden_act": "gelu",
"attention_dropout": 0.0,
"drop_path_rate": 0.0,
"initializer_range": 0.02,
"layer_norm_eps": 1e-05,
"layerscale_value": 1.0,
"key_bias": False,
"mlp_bias": True,
"proj_bias": True,
"query_bias": True,
"value_bias": True,
"use_gated_mlp": False,
"rope_theta": 100.0,
"pos_embed_rescale": 2.0,
},
}
def load_backbone_model(backbone, skip_weights=False):
"""Load (or create) the backbone vision model.
When *skip_weights* is ``True`` the model structure is instantiated
from a built-in config dict **without** any network access. This is
useful when the caller will overwrite all parameters later (e.g. via
``Sat3DGen.from_pretrained``), avoiding a redundant multi-GB
download of the backbone checkpoint.
"""
if skip_weights:
if backbone not in _BACKBONE_CONFIGS:
raise NotImplementedError(f"No built-in config for backbone: {backbone}")
print(f"Creating backbone structure from built-in config (skip weights): {backbone}")
config = AutoConfig.for_model(**_BACKBONE_CONFIGS[backbone])
with no_init_weights():
model = AutoModel.from_config(config)
return model.eval().requires_grad_(False)
load_errors = []
for candidate in resolve_backbone_candidates(backbone):
expanded_candidate = os.path.expanduser(candidate)
resolved_candidate = expanded_candidate if Path(expanded_candidate).exists() else candidate
try:
print("Trying pretrained_model_name_or_path:", resolved_candidate)
return AutoModel.from_pretrained(resolved_candidate).eval().requires_grad_(False)
except Exception as exc:
load_errors.append(f"{resolved_candidate}: {exc}")
formatted_errors = "\n".join(load_errors)
raise RuntimeError(
f"Failed to load the backbone `{backbone}`.\n"
f"Tried the following candidates:\n{formatted_errors}\n"
"You can override the lookup with the corresponding SAT3DGEN_*_PATH environment variable."
)
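# Example: build just the backbone graph with uninitialized weights and no
# checkpoint download; the parameters are expected to be overwritten later
# (e.g. by Sat3DGen.from_pretrained):
#   backbone = load_backbone_model("dinov2-base", skip_weights=True)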
class MappingNetwork(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality, 0 = no latent.
w_dim, # Intermediate latent (W) dimensionality.
num_layers = 8, # Number of mapping layers.
norm = True,
):
super().__init__()
self.z_dim = z_dim
self.w_dim = w_dim
self.num_layers = num_layers
self.norm = norm
        features_list = [z_dim] * num_layers + [w_dim]
layers = []
for idx in range(num_layers):
layers.append(nn.Linear(features_list[idx], features_list[idx + 1]))
layers.append(nn.LeakyReLU(0.2))
self.mapping = nn.Sequential(*layers)
def forward(self, z):
        # Optionally normalize z onto the unit hypersphere, then map it through the MLP.
if self.norm:
z = normalize_2nd_moment(z.to(torch.float32)) # normalize z to sphere
x = self.mapping(z)
return x
class dino_3d_model(nn.Module):
    def __init__(self, output_ch=192, ch_mult=[1, 2, 4, 4, 4], pad=False, with_attn=True,
                 backbone='dinov2-base', no_hidden_states=False, no_cls_token=False,
                 skip_backbone_weights=False):
super().__init__()
self.dino_model = load_backbone_model(backbone, skip_weights=skip_backbone_weights)
if backbone == 'dinov2-base':
z_channels = 6144 if not no_cls_token else 6144//2
self.feature_list = [3,6,9,12] if not no_hidden_states else []
if self.feature_list == []:
z_channels = z_channels//4
elif backbone in ['dinov2-large',"dinov3-large-sat","dinov3-large-lvd"]:
z_channels = 8192 if not no_cls_token else 8192//2
self.feature_list = [6,12,18,24] if not no_hidden_states else []
if self.feature_list == []:
z_channels = z_channels//4
self.backbone = backbone
self.no_cls_token = no_cls_token
self.decoder = Decoder(ch=128,out_ch=output_ch,ch_mult=ch_mult,num_res_blocks=2,attn_resolutions=[],z_channels=z_channels,resolution=256,in_channels=None,with_attn=with_attn)
self.pad = pad
        config_dict = self.dino_model.config.to_dict()
        self.patch_size = config_dict['patch_size']
        self.num_register_tokens = config_dict.get('num_register_tokens', 0)
def forward(self, inputs):
_h,_w = inputs.shape[-2:]
assert _h == 16 * self.patch_size
output = self.dino_model(inputs,output_hidden_states=True)
out_put_list = []
if self.feature_list == []:
out_put_list.append(output.last_hidden_state)
else:
for i in self.feature_list:
out_put_list.append(output.hidden_states[i])
            # Note: len(output.hidden_states) is 13 for dinov2-base (embedding
            # output + 12 layers), so index 12 in feature_list already selects
            # the final layer; last_hidden_state is just the layer-normed
            # hidden_states[-1] and must not be appended again.
x = torch.cat(out_put_list,dim=2)
dino_feature = rearrange(x[:,1+self.num_register_tokens:], 'b (h w) c -> b c h w', h=_h//self.patch_size , w=_w//self.patch_size)
if not self.no_cls_token:
cls_token = x[:,0]
dino_feature = torch.cat([dino_feature, cls_token.unsqueeze(-1).unsqueeze(-1).repeat(1,1,_h//self.patch_size,_w//self.patch_size)],dim=1) # [2, 7680, 16, 16]
if self.pad:
ori_size = dino_feature.size(-1)
pad_size = ori_size*self.pad
# make sure pad size is int
assert pad_size == int(pad_size), 'pad_size should be int'
pad_size = int(pad_size)
dino_feature = F.pad(dino_feature,(pad_size,pad_size,pad_size,pad_size),'constant', 0)
        output = self.decoder(dino_feature)  # e.g. 320x320 output after four 2x upsampling stages
return output
def convert_to_easydict(d):
if isinstance(d, dict):
return edict({k: convert_to_easydict(v) for k, v in d.items()})
return d
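# Note: only nested dicts are converted; dicts inside lists or tuples are kept as-is.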
class Sat3DGen(ModelMixin, ConfigMixin):
# When True, skip downloading pretrained backbone weights during __init__.
# The weights will be loaded later by from_pretrained() from the full
# checkpoint (which already contains the backbone parameters), avoiding
# a redundant multi-GB download of the standalone backbone model.
_skip_backbone_weights: bool = False
@register_to_config
def __init__(self, opt):
super().__init__()
        # Convert opt to an EasyDict if it is a plain dict.
        self.opt = convert_to_easydict(opt)
if 'sr_padding_mode' not in self.opt.keys():
self.opt.sr_padding_mode = 'zeros'
if 'representation_type' not in self.opt.keys():
self.opt.representation_type = 'triplane'
self.sat_mapping_mode = 'v2' if not hasattr(self.opt.network, 'sat_mapping_mode') else self.opt.network.sat_mapping_mode
        assert self.sat_mapping_mode in ['v2'], 'sat_mapping_mode should be v2'
self.sr_factor = 1 if not hasattr(self.opt.network, 'sr_factor') else self.opt.network.sr_factor
self.if_w_sky_mapping = True
self.backbone = 'dinov2-base' if not hasattr(self.opt, 'backbone') else self.opt.backbone
if self.if_w_sky_mapping:
self.z_dim = 270
self.w_dim = 512
self.sky_mapping = MappingNetwork(self.z_dim,self.w_dim,norm=False)
else:
self.z_dim = 270
self.w_dim = 270
        assert self.sr_factor in [1, 2], 'sr_factor should be 1 or 2'
self.image_size = self.opt.network.image_size # not used
self.latent_size = self.opt.network.latent_size
self.latent_channel = self.opt.network.latent_channel
if 'pad' in self.opt.keys():
self.pad = self.opt.pad
self.position_scale_factor = 1 / (self.pad*2+1)
            assert self.opt.network.position_scale_factor == 1, 'position_scale_factor should be 1; not used in this version.'
else:
self.position_scale_factor = self.opt.network.position_scale_factor
self.pad = False
color_channels = 32 if not hasattr(self.opt.network, 'color_channels') else self.opt.network.color_channels
self.sr_module = SuperresolutionHybrid2X(color_channels, 3,padding_mode=self.opt.sr_padding_mode,v2=True)
if self.opt.representation_type == 'triplane':
output_ch = self.opt.network.triplane.dim*3
elif self.opt.representation_type in ['oneplane','oneplane_multi']:
output_ch = self.opt.network.triplane.dim*2
self.with_sky = True
self.sky_input_dim = 2
if self.with_sky:
self.sky_decoder = Decoder(ch=32,out_ch=color_channels,ch_mult=[1,2,2,4,4,4,4],num_res_blocks=2,attn_resolutions=[],z_channels=self.w_dim ,resolution=256,in_channels=None,with_attn=False,pano_pad=True)
self.unet_model = dino_3d_model(output_ch = output_ch,
ch_mult = self.opt.network.triplane.ch_mult if hasattr(self.opt.network.triplane, 'ch_mult') else [1,2,4,4,4],
pad = self.pad,
with_attn = self.opt.network.with_attn if hasattr(self.opt.network, 'with_attn') else True,
backbone = self.backbone,
no_hidden_states=self.opt.network.no_hidden_states if hasattr(self.opt.network, 'no_hidden_states') else False,
no_cls_token=self.opt.network.no_cls_token if hasattr(self.opt.network, 'no_cls_token') else False,
skip_backbone_weights=self._skip_backbone_weights,
)
self.num_importance = self.opt.network.point_sampling_kwargs.num_importance
        # num_importance is consumed here; pop it so the remaining kwargs can be
        # passed straight to the point samplers.
self.opt.network.point_sampling_kwargs.pop('num_importance')
if self.opt.representation_type == 'oneplane':
input_dim_mlp = self.opt.network.triplane.dim*2
elif self.opt.representation_type in ['triplane','oneplane_multi']:
input_dim_mlp = self.opt.network.triplane.dim
self.mlp = MLPNetwork2(input_dim=input_dim_mlp,
hidden_dim=64,
output_dim=color_channels,
style_dim=self.w_dim,
)
self.point_representer = PointRepresenter(
representation_type=self.opt.representation_type,
triplane_axes=None,
coordinate_scale=None,
)
self.point_integrator = PointIntegrator(**self.opt.network.ray_marching_kwargs)
unused_parameter = ['max_height','origin_height','realworld_scale']
for i in unused_parameter:
if i in self.opt.network.point_sampling_kwargs.keys():
self.opt.network.point_sampling_kwargs.pop(i)
        if self.sr_factor == 2:
self.opt.render_size = 256
self.point_sampler_definition(self.opt.render_size if hasattr(self.opt, 'render_size') else 256)
def point_sampler_definition(self, render_size=256):
pano_size = np.array([render_size*2,render_size//2]) / self.sr_factor
pano_dir = get_original_coord(W=int(pano_size[0]),H=int(pano_size[1]),full=True).unsqueeze(0).float()
        # pano_direction is a plain attribute rather than a registered buffer, so
        # Module.to(device) will not move it; place it on the GPU eagerly instead.
        if torch.cuda.is_available():
            pano_dir = pano_dir.cuda()
        self.pano_direction = pano_dir
self.point_sampler = Point_sampler_pano(pano_direction=self.pano_direction,**self.opt.network.point_sampling_kwargs)
self.point_sampler_per = PointSamplerPerspective(num_points=self.opt.network.point_sampling_kwargs.num_points,aabb_strict=True,render_size=[render_size// self.sr_factor,render_size// self.sr_factor])
if render_size==256 and self.sr_factor == 2:
self.point_sampler_sat = Point_sampler_ortho(num_points=self.opt.network.point_sampling_kwargs.num_points,position_scale_factor=self.position_scale_factor,render_size=render_size// self.sr_factor)
else:
self.point_sampler_sat = Point_sampler_ortho(num_points=self.opt.network.point_sampling_kwargs.num_points,position_scale_factor=self.position_scale_factor,resolution=int(render_size*1.5),render_size=render_size)
print('render size:', render_size, 'sr_factor:', self.sr_factor)
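    # Encodes a satellite image into the chosen 3D representation: 'triplane'
    # yields three feature planes, while the 'oneplane' variants yield a single
    # plane plus a 1D feature (remaining channels averaged over one spatial axis).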
def from_sat_to_triplane(self,x):
planes_feature = self.unet_model(x)
if self.opt.representation_type == 'triplane':
triplane_ori = rearrange(planes_feature, 'b (n c) h w -> b n c h w',n=3)
elif self.opt.representation_type in ['oneplane','oneplane_multi']:
one_plane_ori = planes_feature[:,:self.opt.network.triplane.dim]
one_plane_ori = rearrange(one_plane_ori, 'b (n c) h w -> b n c h w',n=1)
one_line_ori = planes_feature[:,self.opt.network.triplane.dim:]
one_line_ori = torch.mean(one_line_ori, dim=2, keepdim=False)
triplane_ori = [one_plane_ori,one_line_ori]
return triplane_ori
def c2w_prepare(self, c2w):
if c2w is not None:
c2w[:,:3, 3] = c2w[:,:3, 3] * self.position_scale_factor
return c2w
def w_sky_prepare(self, z_ill):
if z_ill is not None:
if self.if_w_sky_mapping:
w_sky = self.sky_mapping(z_ill)
else:
w_sky = z_ill
else:
w_sky = None
return w_sky
def w_sky2sky_feature_2D(self, w_sky, z_ill=None):
sky_feature_2D = None
if self.with_sky and z_ill is not None:
sky_feature_2D = repeat(w_sky, 'b c -> b c h w', h=8, w=8)
sky_feature_2D = self.sky_decoder(sky_feature_2D)
sky_feature_2D = torch.sigmoid(sky_feature_2D)
# pad to full panorama width
b,c,h,w = sky_feature_2D.shape
zero_pad_sky = torch.zeros((b,c,h,int(w*0.8)),device=sky_feature_2D.device)
sky_feature_2D = torch.cat([sky_feature_2D,zero_pad_sky],dim=3)
return sky_feature_2D
def from_3D_to_results(self,
triplane_ori,
c2w=None,
w_sky=None,
sky_feature_2D=None,
syn_sat=False,
random_sat_crop=True,
syn_pano=True,
syn_per=False,
same_histo=False,
intrinsics=None,
coordinates=None):
results = edict()
point_sampling_result = []
w_list = []
syn_sign = []
        if isinstance(triplane_ori, list):
N = triplane_ori[0].shape[0]
else:
N = triplane_ori.shape[0]
if syn_sat:
point_sampling_result_sat = self.point_sampler_sat(batch_size=N,random_crop=random_sat_crop,crop_type='crop')
point_sampling_result.append(point_sampling_result_sat)
            if not same_histo:
                dev = triplane_ori[0].device if isinstance(triplane_ori, list) else triplane_ori.device
                w_sat = torch.zeros([N, self.w_dim], device=dev)
else:
w_sat = w_sky
w_list.append(w_sat)
syn_sign.append('sat')
if syn_pano:
resize_for_pano = False
point_sampling_result_pano = self.point_sampler(batch_size=N,position=c2w[:,:3, 3])
if self.training:
if point_sampling_result_pano.rays_world.size(1) != point_sampling_result_pano.rays_world.size(2):
resize_for_pano = True
                    # rearrange from [4, 64, 256, 3] to [4, 128, 128, 3]
point_sampling_result_pano.rays_world = rearrange(point_sampling_result_pano.rays_world, 'b h (w d) c -> b (h d) w c', d=2)
point_sampling_result_pano.ray_origins = rearrange(point_sampling_result_pano.ray_origins, 'b h (w d) c -> b (h d) w c', d=2)
point_sampling_result_pano.points_world = rearrange(point_sampling_result_pano.points_world, 'b h (w d) n c -> b (h d) w n c', d=2)
point_sampling_result_pano.radii = rearrange(point_sampling_result_pano.radii, 'b h (w d) c -> b (h d) w c', d=2)
point_sampling_result.append(point_sampling_result_pano)
w_list.append(w_sky)
syn_sign.append('pano')
if syn_per:
point_sampling_result_per = self.point_sampler_per(intrinsics=intrinsics, c2w=c2w)
point_sampling_result.append(point_sampling_result_per)
w_list.append(w_sky)
            syn_sign.append('perspective')
        if self.training and len(point_sampling_result) > 1:
point_sampling_result_cat = edict()
point_sampling_result_cat.rays_world = torch.cat([i.rays_world for i in point_sampling_result],dim=0)
point_sampling_result_cat.ray_origins = torch.cat([i.ray_origins for i in point_sampling_result],dim=0)
point_sampling_result_cat.points_world = torch.cat([i.points_world for i in point_sampling_result],dim=0)
point_sampling_result_cat.radii = torch.cat([i.radii for i in point_sampling_result],dim=0)
w_input = torch.cat(w_list,dim=0)
if self.opt.representation_type == 'triplane':
feature_input = triplane_ori.repeat(len(point_sampling_result),1,1,1,1)
elif self.opt.representation_type in ['oneplane','oneplane_multi']:
feature_input = [triplane_ori[0].repeat(len(point_sampling_result),1,1,1,1),triplane_ori[1].repeat(len(point_sampling_result),1,1)]
output = self.from_point_sampling2result(point_sampling_result_cat,
feature_input,
w_sky=w_input,
)
else:
for i in range(len(point_sampling_result)):
if syn_sign[i] == 'sat':
results.sat_output = self.from_point_sampling2result(point_sampling_result[i],
triplane_ori,
w_sky=w_list[i],
)
elif syn_sign[i] == 'pano':
results.str_output = self.from_point_sampling2result(point_sampling_result[i],
triplane_ori,
w_sky=w_list[i],
)
                elif syn_sign[i] == 'perspective':
results.per_output = self.from_point_sampling2result(point_sampling_result[i],
triplane_ori,
w_sky=w_list[i],
)
if 'sat' in syn_sign:
if self.training and len(point_sampling_result) >1:
results.sat_output = edict()
results.sat_output.feature_raw = output.feature_raw[:N]
results.sat_output.alpha_raw = output.alpha_raw[:N]
results.sat_output.image_depth = output.image_depth[:N]
results.sat_output.image_radii = output.image_radii[:N]
if 'idx' in point_sampling_result_sat.keys():
results.sat_output.idx = point_sampling_result_sat['idx']
if 'pano' in syn_sign:
            if self.training and len(point_sampling_result) > 1:
                # Batches were concatenated in syn_sign order (N samples each), so
                # slice by position instead of assuming 'sat' is always present.
                ofs = syn_sign.index('pano') * N
                results.str_output = edict()
                results.str_output.feature_raw = output.feature_raw[ofs:ofs + N]
                results.str_output.alpha_raw = output.alpha_raw[ofs:ofs + N]
                results.str_output.image_depth = output.image_depth[ofs:ofs + N]
                results.str_output.image_radii = output.image_radii[ofs:ofs + N]
if resize_for_pano:
results.str_output.feature_raw = rearrange(results.str_output.feature_raw, 'b c (h d) w -> b c h (w d)', d=2)
results.str_output.alpha_raw = rearrange(results.str_output.alpha_raw, 'b c (h d) w -> b c h (w d)', d=2)
results.str_output.image_depth = rearrange(results.str_output.image_depth, 'b c (h d) w -> b c h (w d)', d=2)
results.str_output.image_radii = rearrange(results.str_output.image_radii, 'b c (h d) w -> b c h (w d)', d=2)
if 'idx' in point_sampling_result_pano.keys():
results.str_output.idx = point_sampling_result_pano['idx']
results.str_output.ray_direction = point_sampling_result_pano.rays_world
# render sky
if self.with_sky:
ray_direction = xyz2thetaphi(results.str_output.ray_direction)
sky_img = F.grid_sample(sky_feature_2D, ray_direction,align_corners=True)
sky_img = torch.clamp(sky_img, 0, 1)
if resize_for_pano:
sky_img = rearrange(sky_img, 'b c (h d) w -> b c h (w d)', d=2)
rgb_feature_compo = results.str_output.feature_raw * results.str_output.alpha_raw + sky_img * (1 - results.str_output.alpha_raw)
results.str_output.sky_img = sky_img
results.str_output.image_raw_compo = rgb_feature_compo
if self.sr_factor == 2:
results.str_output.sr_image = self.sr_module(rgb_feature_compo)
        if 'perspective' in syn_sign:
if self.training and len(point_sampling_result) >1:
results.per_output = edict()
results.per_output.feature_raw = output.feature_raw[-N:]
results.per_output.alpha_raw = output.alpha_raw[-N:]
results.per_output.image_depth = output.image_depth[-N:]
results.per_output.image_radii = output.image_radii[-N:]
if 'idx' in point_sampling_result_per.keys():
results.per_output.idx = point_sampling_result_per['idx']
results.per_output.ray_direction = point_sampling_result_per.rays_world
# render sky
if self.with_sky:
ray_direction = xyz2thetaphi(results.per_output.ray_direction)
sky_img = F.grid_sample(sky_feature_2D, ray_direction,align_corners=True)
sky_img = torch.clamp(sky_img, 0, 1)
rgb_feature_compo = results.per_output.feature_raw * results.per_output.alpha_raw + sky_img * (1 - results.per_output.alpha_raw)
results.per_output.sky_img = sky_img
results.per_output.image_raw_compo = rgb_feature_compo
if self.sr_factor == 2:
results.per_output.sr_image = self.sr_module(rgb_feature_compo)
if coordinates is not None:
# for density regularization
results.density = self.density_reg(coordinates,triplane_ori)
return results
def density_reg(self,coordinates,triplane_ori,sample_color=False,w_sky=None):
# Only for density regularization in training process.
assert coordinates is not None
sample_result = self.sample_mixed(coordinates,
triplane_ori,
sample_color=sample_color,
w_sky=w_sky,
)
sample_density = sample_result['density']
        color_result = sample_result['color'][..., :3] if sample_color else None
if self.opt.network.ray_marching_kwargs.density_clamp_mode == 'mipnerf':
sample_density = F.softplus(sample_density - 1)
elif self.opt.network.ray_marching_kwargs.density_clamp_mode == 'relu':
sample_density = F.relu(sample_density + 3)
else:
raise NotImplementedError
if sample_color:
return color_result
return sample_density
def forward(self,
x,
z_ill=None,
syn_sat=False,
random_sat_crop=True,
syn_pano=True,
syn_per=False,
same_histo=False,
intrinsics=None,
c2w=None,
coordinates=None,
):
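        """Render from a satellite image: encode it to a 3D representation,
        prepare sky conditioning from z_ill, then synthesize the requested
        views (satellite / panorama / perspective)."""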
c2w = self.c2w_prepare(c2w)
triplane_ori = self.from_sat_to_triplane(x)
w_sky = self.w_sky_prepare(z_ill)
sky_feature_2D = self.w_sky2sky_feature_2D(w_sky,z_ill)
results = self.from_3D_to_results(triplane_ori,
c2w,
w_sky,
sky_feature_2D,
syn_sat=syn_sat,
random_sat_crop=random_sat_crop,
syn_pano=syn_pano,
syn_per=syn_per,
same_histo=same_histo,
intrinsics=intrinsics,
coordinates=coordinates)
results.triplane = triplane_ori
return results
def sample_mixed(self,
coordinates,
triplanes,
sample_color=False,
w_sky=None
):
point_features = self.point_representer(
coordinates, ref_representation=triplanes)
color_density_result = self.mlp(point_features,only_density=not sample_color,style=w_sky) # point_features: B N C
return color_density_result
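    # Volume-renders one batch of sampled rays: query the representation at the
    # coarse sample points, optionally add an importance-sampled fine pass, and
    # alpha-composite colors and densities along each ray.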
def from_point_sampling2result(self,
point_sampling_result,
triplanes,
w_sky=None,
**synthesis_kwarg
):
points = point_sampling_result['points_world'] # [N, H, W, K, 3]
ray_dirs = point_sampling_result['rays_world'] # [N, H, W, 3]
radii_coarse = point_sampling_result['radii'] # [N, H, W, K]
ray_origins = point_sampling_result['ray_origins'] # [N, 3]
_, H, W, K, _ = points.shape
R = H * W
points_coarse = rearrange(points, 'n h w k c -> n (h w) k c') # [N, R, K, 3]
points = rearrange(points, 'n h w k c -> n (h w k) c') # [N, R * K, 3]
ray_dirs = rearrange(ray_dirs, 'n h w c -> n (h w) c')
if len(ray_origins.shape) == 4:
ray_origins = rearrange(ray_origins, 'n h w c -> n (h w) c')
elif len(ray_origins.shape) == 2:
ray_origins = repeat(ray_origins, 'n c -> n (h w) c', h=R, w=1)
radii_coarse = rearrange(radii_coarse, 'n h w k -> n (h w) k 1')
point_features = self.point_representer(
points, ref_representation=triplanes) # [N, R * K, C]
color_density_result = self.mlp(point_features,w_sky) # point_features: B N C
densities_coarse = color_density_result['density'] # [N, R * K, 1]
colors_coarse = color_density_result['color'] # [N, R * K, C]
densities_coarse = rearrange(densities_coarse, 'n (r k) c -> n r k c', r=R, k=K)
colors_coarse = rearrange(colors_coarse, 'n (r k) c -> n r k c', r=R, k=K)
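        # Hierarchical (coarse-to-fine) sampling, as in NeRF: integrate the coarse
        # pass to obtain per-interval weights, then concentrate additional samples
        # where those weights are large.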
if self.num_importance > 0:
# Do the integration along the coarse pass.
rendering_result = self.point_integrator(colors_coarse,
densities_coarse,
radii_coarse)
weights = rendering_result['weight']
# Importance sampling.
radii_fine = sample_importance(radii_coarse,
weights,
self.num_importance,
smooth_weights=True)
            points = ray_origins.unsqueeze(-2) + radii_fine * ray_dirs.unsqueeze(-2)  # [N, R, num_importance, 3]
points_fine = points
points = rearrange(points, 'n r k c -> n (r k) c') # [N, R * num_importance, 3]
point_features = self.point_representer(
points, ref_representation=triplanes)
color_density_result = self.mlp(point_features,w_sky)
densities_fine = color_density_result['density']
colors_fine = color_density_result['color']
densities_fine = rearrange(densities_fine, 'n (r k) c -> n r k c', r=R, k=self.num_importance)
colors_fine = rearrange(colors_fine, 'n (r k) c -> n r k c', r=R, k=self.num_importance)
# Gather coarse and fine results together.
(all_radiis, all_colors, all_densities,
all_points) = unify_attributes(radii_coarse,
colors_coarse,
densities_coarse,
radii_fine,
colors_fine,
densities_fine,
points1=points_coarse,
points2=points_fine)
# Do the integration along the fine pass.
rendering_result = self.point_integrator(all_colors,
all_densities,
all_radiis)
else:
# Only do the integration along the coarse pass.
rendering_result = self.point_integrator(colors_coarse,
densities_coarse,
radii_coarse)
feature_samples = rendering_result['composite_color']
radii_samples = rendering_result['composite_radial_dist']
feature_image = rearrange(feature_samples, 'n (h w) c -> n c h w', h=H, w=W).contiguous() # [N, C, H, W]
image_radii = rearrange(radii_samples, 'n (h w) c -> n c h w', h=H, w=W).contiguous() # [N, 1, H, W]
image_alpha = rearrange(rendering_result['opacity'], 'n (h w) c -> n c h w', h=H, w=W).contiguous()
image_depth = rearrange(rendering_result['composite_radial_dist'], 'n (h w) c -> n c h w', h=H, w=W).contiguous()
result = edict()
result.feature_raw = feature_image
result.alpha_raw = image_alpha
result.image_depth = image_depth
result.image_radii = image_radii
result.ray_origin = ray_origins
if 'idx' in point_sampling_result.keys():
result.idx = point_sampling_result['idx']
return result
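    # The helpers below export 3D shape: forward_grid evaluates density over a
    # dense voxel grid in memory-bounded chunks, and save_shape* / extract_mesh
    # convert the result into a density volume, mesh, point cloud, or voxel grid.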
@torch.no_grad()
    def forward_grid(self, planes, grid_size=256, position_scale_factor=1, crop=False):
        # Note: the position_scale_factor argument is currently unused;
        # self.position_scale_factor is used for border trimming below.
        # Evaluate at most max_batch grid points per chunk to bound memory.
        max_batch = 15000000
device = planes[0].device if isinstance(planes, (list, tuple)) else planes.device
voxel_grid = create_voxel(N=grid_size,position_scale_factor=1)['voxel_grid'].to(device)
densities = torch.zeros(
(voxel_grid.shape[0], voxel_grid.shape[1], 1)).to(device)
head = 0
with tqdm.tqdm(total=voxel_grid.shape[1]) as pbar:
            # @torch.no_grad() on the method already disables autograd here.
            while head < voxel_grid.shape[1]:
                density = self.density_reg(coordinates=voxel_grid[:, head:head + max_batch], triplane_ori=planes)
                densities[:, head:head + max_batch] = density
                head = head + max_batch
                pbar.update(max_batch)
densities = densities.reshape(
(grid_size, grid_size, grid_size)).cpu().numpy()
        # Trim (or zero out) the border of the extracted cube.
if self.position_scale_factor < 1:
pad = int(np.round(((1-self.position_scale_factor)*densities.shape[0]/2)))
if not crop:
pad_value = 0
densities[:pad] = pad_value
densities[-pad:] = pad_value
densities[:, :pad] = pad_value
densities[:, -pad:] = pad_value
                densities[:, :, :pad] = pad_value  # also zero the lower slab along z
else:
densities = densities[pad:-pad, pad:-pad, pad:]
return densities
@torch.no_grad()
    def save_shape_from_sat(self, sat_img, position_scale_factor=1, crop=False, grid_size=320):
        planes = self.from_sat_to_triplane(sat_img)
        return self.forward_grid(planes, position_scale_factor=position_scale_factor, crop=crop, grid_size=grid_size)
@torch.no_grad()
    def save_shape(self, planes, position_scale_factor=1, save_type='density', crop=False):
        densities = self.forward_grid(planes, position_scale_factor=position_scale_factor, crop=crop)
if save_type == 'density':
try:
import mrcfile
except ImportError:
raise ImportError("mrcfile is required for density export. Install via: pip install mrcfile")
            with mrcfile.new_mmap('0000.mrc',
overwrite=True,
shape=densities.shape,
mrc_mode=2) as mrc:
mrc.data[:] = densities
print('save density done')
if save_type == 'mesh':
from skimage import measure
import trimesh
# Extract a mesh with Marching Cubes.
verts, faces, _, _ = measure.marching_cubes(densities, level=4.5)
# Build the Trimesh object.
mesh = trimesh.Trimesh(vertices=verts, faces=faces)
            # Accessing vertex_normals makes trimesh compute and cache the normals.
            _ = mesh.vertex_normals
# Optional mesh visualization.
# mesh.show()
# Export the mesh as a PLY file.
mesh.export('mesh.ply')
        if save_type in ['pointcloud', 'voxel']:
            # open3d is only needed for the point-cloud / voxel exports.
            try:
                import open3d as o3d
            except ImportError:
                raise ImportError("open3d is required for point-cloud/voxel export. Install via: pip install open3d")
            def efficient_filter_numpy(densities, threshold=5):
                # Mark voxels whose density is above the threshold.
                high_density = np.where(densities >= threshold, 1, 0)
                # For each interior voxel, count high-density neighbors in the
                # 3-voxel window along each axis.
                x_sum = high_density[:-2, 1:-1, 1:-1] + high_density[1:-1, 1:-1, 1:-1] + high_density[2:, 1:-1, 1:-1]
                y_sum = high_density[1:-1, :-2, 1:-1] + high_density[1:-1, 1:-1, 1:-1] + high_density[1:-1, 2:, 1:-1]
                z_sum = high_density[1:-1, 1:-1, :-2] + high_density[1:-1, 1:-1, 1:-1] + high_density[1:-1, 1:-1, 2:]
                # Voxels fully supported along all three axes are interior voxels.
                mask = (x_sum == 3) & (y_sum == 3) & (z_sum == 3)
                # Zero out the interior so that only the surface shell remains.
                densities[1:-1, 1:-1, 1:-1][mask] = 0
                return densities
# print the number of voxels >= 5
print('the number of voxels >= 5 before filtering:', np.sum(densities >= 5))
densities = efficient_filter_numpy(densities)
# print the number of voxels >= 5 after filtering
print('the number of voxels >= 5 after filtering:', np.sum(densities >= 5))
            size = densities.shape[0]
            points = np.array(np.where(densities >= 5)).T
            # Map voxel indices to normalized coordinates in [-1, 1].
            points = (points / size) * 2 - 1
point_cloud = o3d.geometry.PointCloud()
point_cloud.points = o3d.utility.Vector3dVector(points)
if save_type == 'pointcloud':
# save point cloud
o3d.io.write_point_cloud("point_cloud.ply", point_cloud)
if save_type == 'voxel':
                voxel_size = (1 / size) * 2
voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(point_cloud, voxel_size)
o3d.io.write_voxel_grid("voxel_grid.ply", voxel_grid)
return 0
def extract_mesh(
self,
planes: torch.Tensor,
mesh_resolution: int = 320,
        mesh_threshold: float = 5.0,
w_sky = None,
**kwargs,
):
        '''
        Extract a 3D mesh from the triplane NeRF. Only supports batch size 1.
        :param planes: triplane features
        :param mesh_resolution: marching-cubes grid resolution
        :param mesh_threshold: iso-surface threshold
        '''
print('mesh_resolution:', mesh_resolution)
        device = planes[0].device if isinstance(planes, list) else planes.device
grid_out = self.forward_grid(
planes=planes,
grid_size=mesh_resolution,
)
try:
import mcubes
except ImportError:
raise ImportError("PyMCubes is required for mesh extraction. Install via: pip install PyMCubes")
vertices, faces = mcubes.marching_cubes(
grid_out,
mesh_threshold,
)
vertices = vertices / (mesh_resolution - 1) * 2 - 1
# query vertex colors
vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0)
vertices_colors = self.density_reg(vertices_tensor,planes,sample_color=True,w_sky=w_sky)
vertices_colors = (vertices_colors * 255).squeeze(0).cpu().numpy().astype(np.uint8)
return vertices, faces, vertices_colors
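# Tracks an exponential moving average (EMA) of the input's mean squared
# magnitude, updated only in training mode, and rescales inputs by the inverse
# RMS so downstream activations stay roughly unit-scale at eval time as well.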
class EMANorm(nn.Module):
def __init__(self, beta):
super().__init__()
self.register_buffer('magnitude_ema', torch.ones([]))
self.beta = beta
def forward(self, x):
if self.training:
magnitude_cur = x.detach().to(torch.float32).square().mean()
self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.beta))
input_gain = self.magnitude_ema.rsqrt()
x = x.mul(input_gain)
return x
# Backward-compatible alias so that existing config.json files with
# "_class_name": "VAE_finetune" (e.g. on HuggingFace) keep working.
VAE_finetune = Sat3DGen