| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from threading import local |
| import torch |
| import torch.nn as nn |
| from utils.torch_utils import persistence |
| from .networks_stylegan2 import Generator as StyleGAN2Backbone |
| from .networks_stylegan2 import ToRGBLayer, SynthesisNetwork, MappingNetwork |
| from .volumetric_rendering.renderer import ImportanceRenderer |
| from .volumetric_rendering.ray_sampler import RaySampler, PatchRaySampler |
| import dnnlib |
| from pdb import set_trace as st |
| import math |
|
|
| import torch.nn.functional as F |
| import itertools |
| from ldm.modules.diffusionmodules.model import SimpleDecoder, Decoder |
|
|
|
|
@persistence.persistent_class
class TriPlaneGenerator(torch.nn.Module):
    """EG3D-style tri-plane generator.

    A StyleGAN2 backbone synthesizes a 96-channel feature map that is
    reshaped into three 32-channel axis-aligned feature planes. The planes
    are volume-rendered by ``ImportanceRenderer`` into a low-resolution
    feature image, whose first 3 channels form the raw RGB; a
    super-resolution module then upsamples to ``img_resolution``.
    """

    def __init__(
        self,
        z_dim,                   # Input latent (Z) dimensionality.
        c_dim,                   # Conditioning label (C) dimensionality (camera params).
        w_dim,                   # Intermediate latent (W) dimensionality.
        img_resolution,          # Final output resolution of the SR module.
        img_channels,            # Number of output color channels.
        sr_num_fp16_res=0,       # Number of fp16 layers in the SR module.
        mapping_kwargs={},       # Arguments for MappingNetwork.
        rendering_kwargs={},     # Volume-rendering / renderer options.
        sr_kwargs={},            # Extra arguments for the SR module.
        bcg_synthesis_kwargs={}, # Arguments for the optional background net.
        **synthesis_kwargs,      # Arguments for the backbone SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.renderer = ImportanceRenderer()
        # BUGFIX: synthesis() calls self.ray_sampler(...), but no sampler was
        # ever constructed here, raising AttributeError at render time.
        # Instantiate the plain per-pixel ray sampler (imported above), as the
        # sibling Triplane class does.
        self.ray_sampler = RaySampler()

        # The backbone emits all three planes stacked channel-wise (32 * 3).
        self.backbone = StyleGAN2Backbone(z_dim,
                                          c_dim,
                                          w_dim,
                                          img_resolution=256,
                                          img_channels=32 * 3,
                                          mapping_kwargs=mapping_kwargs,
                                          **synthesis_kwargs)
        self.superresolution = dnnlib.util.construct_class_by_name(
            class_name=rendering_kwargs['superresolution_module'],
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
            **sr_kwargs)

        # Optional 2D background networks, only built on request.
        if rendering_kwargs.get('use_background', False):
            self.bcg_synthesis = SynthesisNetwork(
                w_dim,
                img_resolution=self.superresolution.input_resolution,
                img_channels=32,
                **bcg_synthesis_kwargs)
            # BUGFIX: the original referenced `self.num_ws`, which is never
            # assigned on this class (AttributeError). The background mapping
            # drives bcg_synthesis, so it must emit one w per layer of that
            # network.
            self.bcg_mapping = MappingNetwork(z_dim=z_dim,
                                              c_dim=c_dim,
                                              w_dim=w_dim,
                                              num_ws=self.bcg_synthesis.num_ws,
                                              **mapping_kwargs)

        # Point-wise MLP turning aggregated plane features into RGB + density.
        self.decoder = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            })
        self.neural_rendering_resolution = 64  # Default ray-march resolution.
        self.rendering_kwargs = rendering_kwargs

        self._last_planes = None  # Backbone-plane cache for cache_backbone/use_cached_backbone.
        self.pool_256 = torch.nn.AdaptiveAvgPool2d((256, 256))

    def mapping(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                update_emas=False):
        """Map latent z (optionally conditioned on camera c) into W+.

        If rendering_kwargs['c_gen_conditioning_zero'] is set, conditioning is
        disabled by zeroing c.
        """
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        # NOTE(review): the default c_scale of 0 also zeroes conditioning
        # unless callers set rendering_kwargs['c_scale'] — confirm intended.
        return self.backbone.mapping(z,
                                     c *
                                     self.rendering_kwargs.get('c_scale', 0),
                                     truncation_psi=truncation_psi,
                                     truncation_cutoff=truncation_cutoff,
                                     update_emas=update_emas)

    def synthesis(self,
                  ws,
                  c,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  cache_backbone=False,
                  use_cached_backbone=False,
                  return_meta=False,
                  return_raw_only=False,
                  **synthesis_kwargs):
        """Render an image from W+ codes and camera parameters c.

        c packs a flattened 4x4 cam2world matrix (first 16 entries) followed
        by a flattened 3x3 intrinsics matrix (next 9 entries).

        Returns a dict with 'image' (super-resolved), 'image_raw',
        'image_depth', 'weights_samples' and 'shape_synthesized' (None unless
        rendering_kwargs['return_sampling_details_flag'] is set); when
        return_meta is True the renderer's feature volume, sample coordinates
        and weights are included as well.
        """
        return_sampling_details_flag = self.rendering_kwargs.get(
            'return_sampling_details_flag', False)

        # Sampling details imply the renderer must also return its metadata.
        if return_sampling_details_flag:
            return_meta = True

        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        intrinsics = c[:, 16:25].view(-1, 3, 3)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        H = W = self.neural_rendering_resolution
        # One ray per pixel of the H x W neural-rendering grid.
        ray_origins, ray_directions = self.ray_sampler(
            cam2world_matrix, intrinsics, neural_rendering_resolution)

        N, M, _ = ray_origins.shape
        if use_cached_backbone and self._last_planes is not None:
            planes = self._last_planes
        else:
            # Only the first backbone.num_ws codes belong to the backbone;
            # the remainder (if any) is reserved for the SR module.
            planes = self.backbone.synthesis(
                ws[:, :self.backbone.num_ws, :],
                update_emas=update_emas,
                **synthesis_kwargs)
        if cache_backbone:
            self._last_planes = planes

        # Split the 96-channel map into 3 planes of 32 channels each.
        planes = planes.view(len(planes), 3, 32, planes.shape[-2],
                             planes.shape[-1])

        rendering_details = self.renderer(
            planes,
            self.decoder,
            ray_origins,
            ray_directions,
            self.rendering_kwargs,
            return_meta=return_meta)

        feature_samples, depth_samples, weights_samples = (
            rendering_details[k]
            for k in ['feature_samples', 'depth_samples', 'weights_samples'])

        if return_sampling_details_flag:
            shape_synthesized = rendering_details['shape_synthesized']
        else:
            shape_synthesized = None

        # Reshape per-ray samples (N, M, C) into image grids (N, C, H, W).
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # First 3 feature channels are the raw RGB image.
        rgb_image = feature_image[:, :3]
        if not return_raw_only:
            sr_image = self.superresolution(
                rgb_image,
                feature_image,
                ws[:, -1:, :],  # SR module conditions on the last w only.
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
                **{
                    k: synthesis_kwargs[k]
                    for k in synthesis_kwargs.keys() if k != 'noise_mode'
                })
        else:
            sr_image = rgb_image

        ret_dict = {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'weights_samples': weights_samples,
            'shape_synthesized': shape_synthesized
        }
        if return_meta:
            ret_dict.update({
                'feature_volume':
                rendering_details['feature_volume'],
                'all_coords':
                rendering_details['all_coords'],
                'weights':
                rendering_details['weights'],
            })

        return ret_dict

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False,
               **synthesis_kwargs):
        """Query RGB/sigma at explicit world coordinates for a latent z.

        Used for evaluating shapes (e.g. marching cubes) rather than images.
        """
        ws = self.mapping(z,
                          c,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        planes = self.backbone.synthesis(ws,
                                         update_emas=update_emas,
                                         **synthesis_kwargs)
        planes = planes.view(len(planes), 3, 32, planes.shape[-2],
                             planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates,
                                       directions, self.rendering_kwargs)

    def sample_mixed(self,
                     coordinates,
                     directions,
                     ws,
                     truncation_psi=1,
                     truncation_cutoff=None,
                     update_emas=False,
                     **synthesis_kwargs):
        """Same as sample(), but takes pre-computed W+ codes directly."""
        planes = self.backbone.synthesis(ws,
                                         update_emas=update_emas,
                                         **synthesis_kwargs)
        planes = planes.view(len(planes), 3, 32, planes.shape[-2],
                             planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates,
                                       directions, self.rendering_kwargs)

    def forward(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                cache_backbone=False,
                use_cached_backbone=False,
                **synthesis_kwargs):
        """Full pipeline: map z to W+, then render with synthesis()."""
        ws = self.mapping(z,
                          c,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        return self.synthesis(
            ws,
            c,
            update_emas=update_emas,
            neural_rendering_resolution=neural_rendering_resolution,
            cache_backbone=cache_backbone,
            use_cached_backbone=use_cached_backbone,
            **synthesis_kwargs)
|
|
|
| from .networks_stylegan2 import FullyConnectedLayer |
|
|
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
@persistence.persistent_class
class OSGDecoder(torch.nn.Module):
    """Two-layer MLP mapping plane-averaged tri-plane features to RGB + density.

    options keys:
        decoder_lr_mul     -- learning-rate multiplier for both FC layers.
        decoder_output_dim -- number of RGB/feature output channels.
        decoder_activation -- 'sigmoid' (default) or 'lrelu' output activation.
    """

    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 64
        self.decoder_output_dim = options['decoder_output_dim']

        lr_mul = options['decoder_lr_mul']
        self.net = torch.nn.Sequential(
            FullyConnectedLayer(n_features,
                                self.hidden_dim,
                                lr_multiplier=lr_mul),
            torch.nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim,
                                1 + options['decoder_output_dim'],
                                lr_multiplier=lr_mul),
        )
        self.activation = options.get('decoder_activation', 'sigmoid')

    def forward(self, sampled_features, ray_directions):
        """Decode (N, n_planes, M, C) plane samples to per-point rgb/sigma."""
        # Aggregate across the plane axis, then run the MLP point-wise.
        feats = sampled_features.mean(1)

        batch, n_points, n_channels = feats.shape
        out = self.net(feats.view(batch * n_points, n_channels))
        out = out.view(batch, n_points, -1)

        sigma = out[..., 0:1]
        rgb = out[..., 1:]
        if self.activation == "sigmoid":
            # Squash to slightly beyond [0, 1], as in the original EG3D.
            rgb = torch.sigmoid(rgb) * (1 + 2 * 0.001) - 0.001
        elif self.activation == "lrelu":
            # Scaled leaky ReLU keeps unit variance (StyleGAN convention).
            rgb = torch.nn.functional.leaky_relu(rgb, 0.2,
                                                 inplace=True) * math.sqrt(2)
        return {'rgb': rgb, 'sigma': sigma}
|
|
|
|
class LRMOSGDecoder(nn.Module):
    """
    Triplane decoder that gives RGB and sigma values from sampled features.
    Using ReLU here instead of Softplus in the original implementation.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """

    def __init__(self, n_features: int,
                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
        super().__init__()
        self.decoder_output_dim = 3
        # Build the MLP layer list explicitly (same layer order as before, so
        # weight initialization consumes the RNG stream identically).
        layers = [nn.Linear(3 * n_features, hidden_dim), activation()]
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(activation())
        layers.append(nn.Linear(hidden_dim, 1 + self.decoder_output_dim))
        self.net = nn.Sequential(*layers)
        # Zero every linear bias.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.zeros_(module.bias)

    def forward(self, sampled_features, ray_directions):
        """Decode (N, n_planes, M, C) samples into per-point rgb and sigma."""
        # Concatenate the three planes' features channel-wise for each point:
        # (N, n_planes, M, C) -> (N, M, n_planes * C).
        _N, n_planes, _M, _C = sampled_features.shape
        flat = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes * _C)

        batch, n_points, n_channels = flat.shape
        out = self.net(flat.contiguous().view(batch * n_points, n_channels))
        out = out.view(batch, n_points, -1)

        sigma = out[..., 0:1]
        # Squash RGB to slightly beyond [0, 1], as in EG3D.
        rgb = torch.sigmoid(out[..., 1:]) * (1 + 2 * 0.001) - 0.001

        return {'rgb': rgb, 'sigma': sigma}
|
|
|
|
class Triplane(torch.nn.Module):
    """Standalone tri-plane volume renderer.

    Unlike TriPlaneGenerator there is no StyleGAN backbone here: feature
    planes are supplied by the caller (or, with create_triplane=True, come
    from a single learned parameter tensor). The planes are volume-rendered
    into a feature image, optionally composited over a background read from
    the renderer's output, and optionally super-resolved.
    """

    def __init__(
        self,
        c_dim=25,                # Camera dim: 16 (cam2world) + 9 (intrinsics).
        img_resolution=128,      # Default neural (ray-march) resolution.
        img_channels=3,
        out_chans=96,            # Channels of the stacked plane tensor (3 * C).
        triplane_size=224,
        rendering_kwargs={},     # Renderer / compositing options.
        decoder_in_chans=32,     # Per-plane feature channels fed to the decoder.
        decoder_output_dim=32,   # Decoder RGB/feature output channels.
        sr_num_fp16_res=0,
        sr_kwargs={},            # Non-empty dict enables the SR module.
        create_triplane=False,   # Learn a single shared plane parameter.
        bcg_synthesis_kwargs={},
        lrm_decoder=False,       # Use LRM-style MLP instead of OSGDecoder.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.triplane_size = triplane_size

        self.decoder_in_chans = decoder_in_chans
        self.out_chans = out_chans

        self.renderer = ImportanceRenderer()

        # NOTE(review): this tests for a *key* named 'PatchRaySampler' in
        # rendering_kwargs (not a value) — confirm callers actually set it
        # as a dict key to opt into patch-based ray sampling.
        if 'PatchRaySampler' in rendering_kwargs:
            self.ray_sampler = PatchRaySampler()
        else:
            self.ray_sampler = RaySampler()

        if lrm_decoder:
            self.decoder = LRMOSGDecoder(
                decoder_in_chans,)
        else:
            self.decoder = OSGDecoder(
                decoder_in_chans,
                {
                    'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                    'decoder_output_dim': decoder_output_dim
                })

        # Neural rendering runs at full img_resolution by default.
        self.neural_rendering_resolution = img_resolution
        self.rendering_kwargs = rendering_kwargs
        self.create_triplane = create_triplane
        if create_triplane:
            # Single learned plane tensor shared across the batch (expanded
            # per-sample in forward()).
            self.planes = nn.Parameter(torch.randn(1, out_chans, 256, 256))

        if bool(sr_kwargs):  # Only build SR when kwargs are provided.
            assert decoder_in_chans == decoder_output_dim, 'tradition'
            if rendering_kwargs['superresolution_module'] in [
                    'utils.torch_utils.components.PixelUnshuffleUpsample',
                    'utils.torch_utils.components.NearestConvSR',
                    'utils.torch_utils.components.NearestConvSR_Residual'
            ]:
                # These lightweight SR modules take a different constructor
                # signature than the StyleGAN-style SR modules below.
                self.superresolution = dnnlib.util.construct_class_by_name(
                    class_name=rendering_kwargs['superresolution_module'],
                    sr_ratio=2,
                    output_dim=decoder_output_dim,
                    num_out_ch=3,
                )
            else:
                self.superresolution = dnnlib.util.construct_class_by_name(
                    class_name=rendering_kwargs['superresolution_module'],
                    channels=decoder_output_dim,
                    img_resolution=img_resolution,
                    sr_num_fp16_res=sr_num_fp16_res,
                    sr_antialias=rendering_kwargs['sr_antialias'],
                    **sr_kwargs)
        else:
            self.superresolution = None

        self.bcg_synthesis = None

    def forward(
            self,
            planes=None,
            c=None,
            ws=None,
            ray_origins=None,
            ray_directions=None,
            z_bcg=None,
            neural_rendering_resolution=None,
            update_emas=False,
            cache_backbone=False,
            use_cached_backbone=False,
            return_meta=False,
            return_raw_only=False,
            sample_ray_only=False,
            fg_bbox=None,
            **synthesis_kwargs):
        """Volume-render the given planes under camera c.

        Three entry modes:
          * sample_ray_only=True — only sample patch rays and return them
            (no rendering at all).
          * ray_directions=None  — sample a full ray grid from c, then render.
          * ray_directions given — render along the caller-supplied rays
            (assumed to form a square H x W grid).

        Returns a dict with 'feature_image', 'image_raw', 'image_depth',
        'weights_samples', 'image_mask', 'shape_synthesized', plus
        'image_sr' when SR ran and renderer metadata when return_meta is set.
        """
        # c packs a flattened 4x4 cam2world matrix then a 3x3 intrinsics.
        cam2world_matrix = c[:, :16].reshape(-1, 4, 4)
        intrinsics = c[:, 16:25].reshape(-1, 3, 3)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        if ray_directions is None:  # No rays supplied: sample them from c.
            H = W = self.neural_rendering_resolution

            if sample_ray_only:
                # Patch-mode sampling; requires the PatchRaySampler variant
                # (extra resolution + bbox arguments, returns ray bboxes).
                ray_origins, ray_directions, ray_bboxes = self.ray_sampler(
                    cam2world_matrix, intrinsics,
                    self.rendering_kwargs.get( 'patch_rendering_resolution' ),
                    self.neural_rendering_resolution, fg_bbox)

                # Early return: caller only wanted the sampled rays.
                ret_dict = {
                    'ray_origins': ray_origins,
                    'ray_directions': ray_directions,
                    'ray_bboxes': ray_bboxes,
                }

                return ret_dict

            else:
                ray_origins, ray_directions, _ = self.ray_sampler(
                    cam2world_matrix, intrinsics, self.neural_rendering_resolution,
                    self.neural_rendering_resolution)

        else:
            assert ray_origins is not None
            # Caller-supplied rays are assumed to form a square image grid.
            H = W = int(ray_directions.shape[1]**
                        0.5)

        if planes is None:
            # Fall back to the learned shared plane, expanded per sample.
            assert self.planes is not None
            planes = self.planes.repeat_interleave(c.shape[0], dim=0)
        return_sampling_details_flag = self.rendering_kwargs.get(
            'return_sampling_details_flag', False)

        # Sampling details imply renderer metadata must be returned too.
        if return_sampling_details_flag:
            return_meta = True

        N, M, _ = ray_origins.shape

        # Twice the channels means foreground + background planes stacked;
        # the renderer then emits 'bg_ret_dict' used for compositing below.
        if planes.shape[1] == 3 * 2 * self.decoder_in_chans:
            triplane_bg = True
        else:
            triplane_bg = False

        # Split the stacked channels into 3 planes of (out_chans / 3) each.
        planes = planes.reshape(
            len(planes),
            3,
            -1,
            planes.shape[-2],
            planes.shape[-1])

        rendering_details = self.renderer(planes,
                                          self.decoder,
                                          ray_origins,
                                          ray_directions,
                                          self.rendering_kwargs,
                                          return_meta=return_meta)

        feature_samples, depth_samples, weights_samples = (
            rendering_details[k]
            for k in ['feature_samples', 'depth_samples', 'weights_samples'])

        if return_sampling_details_flag:
            shape_synthesized = rendering_details['shape_synthesized']
        else:
            shape_synthesized = None

        # Reshape per-ray samples (N, M, C) into image grids (N, C, H, W).
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H,
            W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
        weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Affine-rescale the accumulated weights slightly beyond [0, 1]
        # (to [-0.001, 1.001]), mirroring the RGB squashing convention.
        mask_image = weights_samples * (1 + 2 * 0.001) - 0.001
        if triplane_bg:
            # Composite the renderer's background RGB behind the foreground
            # using the (approximate) foreground alpha mask.
            feature_image = (1 - mask_image) * rendering_details[
                'bg_ret_dict']['rgb_final'] + feature_image

        rgb_image = feature_image[:, :3]

        if self.superresolution is not None and not return_raw_only:

            # Promote 2-D ws to (N, 1, w_dim); SR uses only the last w.
            if ws is not None and ws.ndim == 2:
                ws = ws.unsqueeze(
                    1)[:, -1:, :]

            sr_image = self.superresolution(
                rgb=rgb_image,
                x=feature_image,
                base_x=rgb_image,
                ws=ws,
                noise_mode=self.
                rendering_kwargs['superresolution_noise_mode'],
                **{
                    k: synthesis_kwargs[k]
                    for k in synthesis_kwargs.keys() if k != 'noise_mode'
                })
        else:
            sr_image = None

        if shape_synthesized is not None:
            shape_synthesized.update({
                'image_depth': depth_image,
            })

        ret_dict = {
            'feature_image': feature_image,
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'weights_samples': weights_samples,
            'shape_synthesized': shape_synthesized,
            "image_mask": mask_image,
        }

        if sr_image is not None:
            ret_dict.update({
                'image_sr': sr_image,
            })

        if return_meta:
            ret_dict.update({
                'feature_volume':
                rendering_details['feature_volume'],
                'all_coords':
                rendering_details['all_coords'],
                'weights':
                rendering_details['weights'],
            })

        return ret_dict
|
|
|
|
class Triplane_fg_bg_plane(Triplane):
    """Triplane variant with a separate 2D background plane.

    A lightweight VAE-style Decoder turns a small latent background plane
    into a 32-channel background image, which is composited behind the
    volume-rendered foreground using the accumulated opacity weights.
    """

    def __init__(self,
                 c_dim=25,
                 img_resolution=128,
                 img_channels=3,
                 out_chans=96,
                 triplane_size=224,
                 rendering_kwargs={},
                 decoder_in_chans=32,
                 decoder_output_dim=32,
                 sr_num_fp16_res=0,
                 sr_kwargs={},
                 bcg_synthesis_kwargs={}):
        # BUGFIX: the previous call passed bcg_synthesis_kwargs as the 11th
        # positional argument, which landed on Triplane.__init__'s
        # `create_triplane` parameter — any non-empty dict is truthy and
        # silently created the learned self.planes parameter. Pass it by
        # keyword so it reaches the intended parameter.
        super().__init__(c_dim,
                         img_resolution,
                         img_channels,
                         out_chans,
                         triplane_size,
                         rendering_kwargs,
                         decoder_in_chans,
                         decoder_output_dim,
                         sr_num_fp16_res,
                         sr_kwargs,
                         bcg_synthesis_kwargs=bcg_synthesis_kwargs)

        # Decodes the latent background plane (4 channels) into a 32-channel
        # background image at resolution 64 (upsampled to match later).
        self.bcg_decoder = Decoder(
            ch=64,
            out_ch=32,
            ch_mult=(1, 2),
            num_res_blocks=2,
            dropout=0.0,
            attn_resolutions=(),
            z_channels=4,
            resolution=64,
            in_channels=3,
        )

    def forward(
            self,
            planes,
            bg_plane,
            c,
            ws=None,
            z_bcg=None,
            neural_rendering_resolution=None,
            update_emas=False,
            cache_backbone=False,
            use_cached_backbone=False,
            return_meta=False,
            return_raw_only=False,
            **synthesis_kwargs):
        """Render foreground planes and composite the decoded background.

        bg_plane is the latent background plane fed to self.bcg_decoder.
        Returns the same dict layout as Triplane.forward().
        """
        if planes is None:
            # Fall back to the learned shared plane, expanded per sample.
            assert self.planes is not None
            planes = self.planes.repeat_interleave(c.shape[0], dim=0)
        return_sampling_details_flag = self.rendering_kwargs.get(
            'return_sampling_details_flag', False)

        # Sampling details imply renderer metadata must be returned too.
        if return_sampling_details_flag:
            return_meta = True

        # c packs a flattened 4x4 cam2world matrix then a 3x3 intrinsics.
        cam2world_matrix = c[:, :16].reshape(-1, 4, 4)
        intrinsics = c[:, 16:25].reshape(-1, 3, 3)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        H = W = self.neural_rendering_resolution
        # NOTE(review): only one resolution argument is passed here, whereas
        # Triplane.forward passes two — confirm the configured ray sampler
        # accepts this call shape.
        ray_origins, ray_directions, _ = self.ray_sampler(
            cam2world_matrix, intrinsics, neural_rendering_resolution)

        N, M, _ = ray_origins.shape

        # Split the stacked channels into 3 planes of (out_chans / 3) each.
        planes = planes.view(
            len(planes),
            3,
            -1,
            planes.shape[-2],
            planes.shape[-1])

        rendering_details = self.renderer(planes,
                                          self.decoder,
                                          ray_origins,
                                          ray_directions,
                                          self.rendering_kwargs,
                                          return_meta=return_meta)

        feature_samples, depth_samples, weights_samples = (
            rendering_details[k]
            for k in ['feature_samples', 'depth_samples', 'weights_samples'])

        if return_sampling_details_flag:
            shape_synthesized = rendering_details['shape_synthesized']
        else:
            shape_synthesized = None

        # Reshape per-ray samples (N, M, C) into image grids (N, C, H, W).
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H,
            W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
        weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Decode the background plane and resize it to the feature image.
        bcg_image = self.bcg_decoder(bg_plane)
        bcg_image = torch.nn.functional.interpolate(
            bcg_image,
            size=feature_image.shape[2:],
            mode='bilinear',
            align_corners=False,
            antialias=self.rendering_kwargs['sr_antialias'])

        # Affine-rescale accumulated weights slightly beyond [0, 1].
        mask_image = weights_samples * (1 + 2 * 0.001) - 0.001

        # Composite: background shows through where foreground opacity is low.
        feature_image = feature_image + (1 - weights_samples) * bcg_image

        rgb_image = feature_image[:, :3]

        if self.superresolution is not None and not return_raw_only:

            # Promote 2-D ws to (N, 1, w_dim); SR uses only the last w.
            if ws is not None and ws.ndim == 2:
                ws = ws.unsqueeze(
                    1)[:, -1:, :]

            sr_image = self.superresolution(
                rgb=rgb_image,
                x=feature_image,
                base_x=rgb_image,
                ws=ws,
                noise_mode=self.
                rendering_kwargs['superresolution_noise_mode'],
                **{
                    k: synthesis_kwargs[k]
                    for k in synthesis_kwargs.keys() if k != 'noise_mode'
                })
        else:
            sr_image = None

        if shape_synthesized is not None:
            shape_synthesized.update({
                'image_depth': depth_image,
            })

        ret_dict = {
            'feature_image': feature_image,
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'weights_samples': weights_samples,
            'shape_synthesized': shape_synthesized,
            "image_mask": mask_image,
        }

        if sr_image is not None:
            ret_dict.update({
                'image_sr': sr_image,
            })

        if return_meta:
            ret_dict.update({
                'feature_volume':
                rendering_details['feature_volume'],
                'all_coords':
                rendering_details['all_coords'],
                'weights':
                rendering_details['weights'],
            })

        return ret_dict
|
|