| import math |
| import random |
| from einops import rearrange |
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| import numpy as np |
| from tqdm import trange |
|
|
| from functools import partial |
|
|
| from nsr.networks_stylegan2 import Generator as StyleGAN2Backbone |
| from nsr.volumetric_rendering.renderer import ImportanceRenderer, ImportanceRendererfg_bg |
| from nsr.volumetric_rendering.ray_sampler import RaySampler |
| from nsr.triplane import OSGDecoder, Triplane, Triplane_fg_bg_plane |
| |
| |
| from vit.vision_transformer import TriplaneFusionBlockv4_nested, TriplaneFusionBlockv4_nested_init_from_dino_lite, TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, VisionTransformer, TriplaneFusionBlockv4_nested_init_from_dino |
|
|
| from .vision_transformer import Block, VisionTransformer |
| from .utils import trunc_normal_ |
|
|
| from guided_diffusion import dist_util, logger |
|
|
| from pdb import set_trace as st |
|
|
| from ldm.modules.diffusionmodules.model import Encoder, Decoder |
| from utils.torch_utils.components import PixelShuffleUpsample, ResidualBlock, Upsample, PixelUnshuffleUpsample, Conv3x3TriplaneTransformation |
| from utils.torch_utils.distributions.distributions import DiagonalGaussianDistribution |
| from nsr.superresolution import SuperresolutionHybrid2X, SuperresolutionHybrid4X |
|
|
| from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer |
|
|
| from nsr.common_blks import ResMlp |
| from .vision_transformer import * |
|
|
| from dit.dit_models import get_2d_sincos_pos_embed |
| from torch import _assert |
| from itertools import repeat |
| import collections.abc |
|
|
|
|
| |
| def _ntuple(n): |
|
|
| def parse(x): |
| if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): |
| return tuple(x) |
| return tuple(repeat(x, n)) |
|
|
| return parse |
|
|
|
|
| to_1tuple = _ntuple(1) |
| to_2tuple = _ntuple(2) |
|
|
|
|
class PatchEmbedTriplane(nn.Module):
    """Patch-embed a channel-stacked triplane with a 3-group convolution.

    A strided ``nn.Conv2d`` with ``groups=3`` maps ``in_chans`` input
    channels to ``3 * embed_dim`` output channels (so ``in_chans`` must be
    divisible by 3). The result is reshaped to (B, embed_dim, 3, H', W') and,
    when ``flatten`` is set, turned into a (B, 3 * H' * W', embed_dim) token
    sequence.

    Args:
        img_size: spatial input size, int or (H, W).
        patch_size: conv kernel size and stride, int or (ph, pw).
        in_chans: number of input channels.
        embed_dim: output width per token.
        norm_layer: optional constructor called with ``embed_dim``; when
            falsy, ``nn.Identity`` is used.
        flatten: whether to flatten the spatial/plane axes into tokens.
        bias: whether the projection conv has a bias.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=2,
        in_chans=4,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
    ):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.flatten = flatten

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim * 3,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=bias,
            groups=3,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed ``x`` (B, C, H, W) and return tokens (or a 5-D grid if not flattening)."""
        B, _, H, W = x.shape
        _assert(
            H == self.img_size[0],
            f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
        )
        _assert(
            W == self.img_size[1],
            f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
        )
        tokens = self.proj(x)  # (B, 3 * embed_dim, H', W')
        _, proj_chans, grid_h, grid_w = tokens.shape

        # NOTE(review): with groups=3 the conv emits channels plane-major
        # ([plane0 | plane1 | plane2]), whereas this (C // 3, 3) reshape reads
        # them as channel-major, i.e. it interleaves the groups. Kept as-is
        # for checkpoint compatibility — confirm this is intended.
        tokens = tokens.reshape(B, proj_chans // 3, 3, grid_h, grid_w)

        if self.flatten:
            # (B, C, 3, H', W') -> (B, 3 * H' * W', C)
            tokens = tokens.flatten(2).transpose(1, 2)
        return self.norm(tokens)
|
|
|
|
class PatchEmbedTriplaneRodin(PatchEmbedTriplane):
    """``PatchEmbedTriplane`` whose projection is a ``RodinRollOutConv3D_GroupConv``.

    The parent constructor builds an ``nn.Conv2d`` projection, which is then
    replaced here. All other attributes and ``forward`` are inherited.
    ``bias`` only reaches the parent's (discarded) conv; it is not forwarded
    to the replacement projection.
    """

    def __init__(self,
                 img_size=32,
                 patch_size=2,
                 in_chans=4,
                 embed_dim=768,
                 norm_layer=None,
                 flatten=True,
                 bias=True):
        super().__init__(img_size=img_size,
                         patch_size=patch_size,
                         in_chans=in_chans,
                         embed_dim=embed_dim,
                         norm_layer=norm_layer,
                         flatten=flatten,
                         bias=bias)
        rollout_proj = RodinRollOutConv3D_GroupConv(in_chans,
                                                    embed_dim * 3,
                                                    kernel_size=patch_size,
                                                    stride=patch_size,
                                                    padding=0)
        self.proj = rollout_proj
|
|
|
|
class ViTTriplaneDecomposed(nn.Module):
    """ViT decoder that turns a token sequence into triplane features.

    Combines three parts:
    - ``vit_decoder``: its blocks are regrouped by :meth:`create_fusion_blks`
      and its ``pos_embed`` is replaced.
    - ``decoder_pred``: a linear head.
    - ``triplane_decoder``: consumes the unpatchified planes.
    """

    def __init__(
        self,
        vit_decoder,
        triplane_decoder: Triplane,
        cls_token=False,
        decoder_pred_size=-1,
        unpatchify_out_chans=-1,
        channel_multiplier=4,
        use_fusion_blk=True,
        fusion_blk_depth=4,
        fusion_blk=TriplaneFusionBlock,
        fusion_blk_start=0,
        ldm_z_channels=4,
        ldm_embed_dim=4,
        vae_p=2,
        token_size=None,
        w_avg=torch.zeros([512]),
        patch_size=None,
        **kwargs,
    ) -> None:
        super().__init__()

        self.superresolution = nn.ModuleDict({})
        self.decomposed_IN = False
        self.decoder_pred_3d = None
        self.transformer_3D_blk = None
        self.logvar = None
        self.channel_multiplier = channel_multiplier

        self.cls_token = cls_token
        self.vit_decoder = vit_decoder
        self.triplane_decoder = triplane_decoder

        # Default to the ViT's own patch size; tuples collapse to their first
        # entry (square patches assumed).
        if patch_size is None:
            self.patch_size = self.vit_decoder.patch_embed.patch_size
        else:
            self.patch_size = patch_size
        if isinstance(self.patch_size, tuple):
            self.patch_size = self.patch_size[0]

        if unpatchify_out_chans == -1:
            self.unpatchify_out_chans = self.triplane_decoder.out_chans
        else:
            self.unpatchify_out_chans = unpatchify_out_chans

        if decoder_pred_size == -1:
            decoder_pred_size = self.patch_size**2 * self.triplane_decoder.out_chans

        self.decoder_pred = nn.Linear(self.vit_decoder.embed_dim,
                                      decoder_pred_size,
                                      bias=True)

        self.plane_n = 3

        self.ldm_z_channels = ldm_z_channels
        self.ldm_embed_dim = ldm_embed_dim
        self.vae_p = vae_p
        # NOTE(review): the ``token_size`` argument is ignored and 16 is
        # hard-coded; subclasses also hard-code 16 * 16 token grids.
        self.token_size = 16
        self.vae_res = self.vae_p * self.token_size

        # One positional embedding per token of each of the 3 planes (plus a
        # per-plane cls slot when cls_token is truthy).
        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1, 3 * (self.token_size**2 + self.cls_token),
                        vit_decoder.embed_dim))

        self.fusion_blk_start = fusion_blk_start
        self.create_fusion_blks(fusion_blk_depth, use_fusion_blk, fusion_blk)

        # Clone so that instances constructed with the (shared, evaluated-once)
        # default tensor do not alias one another's buffer; in-place updates
        # such as load_state_dict would otherwise leak across instances.
        self.register_buffer('w_avg', w_avg.clone())
        self.rendering_kwargs = self.triplane_decoder.rendering_kwargs

    @torch.inference_mode()
    def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**16):
        """Query the triplane decoder at arbitrary points, ``chunk_size`` at a time.

        Args:
            planes: triplane features; a 4-D tensor is reshaped to
                (B, 3, C, H, W) by splitting its channel axis into 3 planes.
            points: query coordinates, chunked along dim 1 (presumably
                (B, P, 3) — TODO confirm against the renderer).
            chunk_size: number of points per ``renderer._run_model`` call.

        Returns:
            dict mapping each key returned by ``_run_model`` to the per-chunk
            outputs concatenated along dim 1.

        Raises:
            ValueError: if ``points`` contains no points along dim 1.
        """
        if planes.ndim == 4:
            planes = planes.reshape(len(planes), 3, -1, planes.shape[-2],
                                    planes.shape[-1])

        outs = []
        for i in trange(0, points.shape[1], chunk_size):
            chunk_points = points[:, i:i + chunk_size]
            chunk_out = self.triplane_decoder.renderer._run_model(
                planes=planes,
                decoder=self.triplane_decoder.decoder,
                sample_coordinates=chunk_points,
                sample_directions=torch.zeros_like(chunk_points),
                options=self.rendering_kwargs,
            )
            outs.append(chunk_out)
            # Release cached blocks between chunks to bound peak memory.
            torch.cuda.empty_cache()

        if not outs:
            raise ValueError(
                'forward_points received no query points (points.shape[1] == 0)')

        return {
            k: torch.cat([out[k] for out in outs], dim=1)
            for k in outs[0].keys()
        }

    def triplane_decode_grid(self, vit_decode_out, grid_size, aabb: torch.Tensor = None, **kwargs):
        """Evaluate the decoder on a dense ``grid_size ** 3`` lattice per sample.

        Args:
            vit_decode_out: dict holding the planes under 'latent_after_vit'.
            grid_size: number of samples along each axis.
            aabb: optional per-sample box of shape (B, 2, 3) ([min, max] rows).
                By default a cube built from rendering_kwargs is used:
                sampler_bbox_min/max when present, else +-box_warp / 2.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            dict of decoder outputs reshaped to
            (B, grid_size, grid_size, grid_size, -1).
        """
        assert isinstance(vit_decode_out, dict)
        planes = vit_decode_out['latent_after_vit']

        if aabb is None:
            if 'sampler_bbox_min' in self.rendering_kwargs:
                aabb = torch.tensor([
                    [self.rendering_kwargs['sampler_bbox_min']] * 3,
                    [self.rendering_kwargs['sampler_bbox_max']] * 3,
                ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
            else:
                aabb = torch.tensor([
                    [-self.rendering_kwargs['box_warp'] / 2] * 3,
                    [self.rendering_kwargs['box_warp'] / 2] * 3,
                ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)

        assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
        N = planes.shape[0]

        grid_points = []
        for i in range(N):
            axes = [
                torch.linspace(aabb[i, 0, d], aabb[i, 1, d], grid_size, device=planes.device)
                for d in range(3)
            ]
            grid_points.append(
                torch.stack(torch.meshgrid(*axes, indexing='ij'), dim=-1).reshape(-1, 3))
        cube_grid = torch.stack(grid_points, dim=0).to(planes.device)

        features = self.forward_points(planes, cube_grid)

        return {
            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
            for k, v in features.items()
        }

    def create_uvit_arch(self):
        """Attach a zero-initialised ``skip_linear`` (2D -> D) to the second half of the blocks.

        With zero weight and bias the skip projection outputs zeros at init,
        so a residual ``x + skip_linear(...)`` starts as the identity.
        """
        logger.log(
            f'length of vit_decoder.blocks: {len(self.vit_decoder.blocks)}')
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            blk.skip_linear = nn.Linear(2 * self.vit_decoder.embed_dim,
                                        self.vit_decoder.embed_dim)
            nn.init.constant_(blk.skip_linear.weight, 0)
            if isinstance(blk.skip_linear, nn.Linear) and blk.skip_linear.bias is not None:
                nn.init.constant_(blk.skip_linear.bias, 0)

    def vit_decode_backbone(self, latent, img_size):
        """Run the ViT backbone on ``latent`` (delegates to ``forward_vit_decoder``)."""
        return self.forward_vit_decoder(latent, img_size)

    def init_weights(self):
        """Fill ``vit_decoder.pos_embed`` with fixed 2D sin-cos embeddings.

        The three planes are laid out as one (3 * token_size, token_size) grid.
        """
        p = self.token_size
        D = self.vit_decoder.pos_embed.shape[-1]
        grid_size = (3 * p, p)
        pos_embed = get_2d_sincos_pos_embed(D, grid_size).reshape(3 * p * p, D)
        self.vit_decoder.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))
        logger.log('init pos_embed with sincos')

    def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk):
        """Regroup ``vit_decoder.blocks`` into fusion blocks.

        The first ``fusion_blk_start`` blocks are kept unchanged. The rest
        are wrapped ``fusion_blk_depth`` at a time as
        ``fusion_blk(group, num_heads, embed_dim, use_fusion_blk)``.
        """
        vit_decoder_blks = self.vit_decoder.blocks
        assert len(vit_decoder_blks) == 12, 'ViT-B by default'

        nh = self.vit_decoder.blocks[0].attn.num_heads
        dim = self.vit_decoder.embed_dim

        fusion_blk_start = self.fusion_blk_start
        triplane_fusion_vit_blks = nn.ModuleList()

        if fusion_blk_start != 0:
            for i in range(0, fusion_blk_start):
                triplane_fusion_vit_blks.append(vit_decoder_blks[i])

        for i in range(fusion_blk_start, len(vit_decoder_blks), fusion_blk_depth):
            vit_blks_group = vit_decoder_blks[i:i + fusion_blk_depth]
            triplane_fusion_vit_blks.append(
                fusion_blk(vit_blks_group, nh, dim, use_fusion_blk))

        self.vit_decoder.blocks = triplane_fusion_vit_blks

    def triplane_decode(self, latent, c):
        """Decode ``latent`` planes with camera ``c``; the result also carries 'latent'."""
        ret_dict = self.triplane_decoder(latent, c)
        ret_dict.update({'latent': latent})
        return ret_dict

    def triplane_renderer(self, latent, coordinates, directions):
        """Run the renderer's ``run_model`` on ``latent`` viewed as (B, 3, C_in, H, W)."""
        planes = latent.view(len(latent), 3,
                             self.triplane_decoder.decoder_in_chans,
                             latent.shape[-2],
                             latent.shape[-1])
        ret_dict = self.triplane_decoder.renderer.run_model(
            planes, self.triplane_decoder.decoder, coordinates, directions,
            self.triplane_decoder.rendering_kwargs)
        return ret_dict

    def unpatchify_triplane(self, x, p=None, unpatchify_out_chans=None):
        """Fold per-plane patch tokens into a channel-stacked triplane.

        x: (N, [1 +] 3 * h * w, p * p * C) with C = unpatchify_out_chans; the
           leading token is dropped when ``cls_token`` is truthy.
        returns: (N, 3 * C, h * p, w * p)
        """
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.unpatchify_out_chans // 3
        if self.cls_token:
            x = x[:, 1:]

        if p is None:
            p = self.patch_size
        h = w = int((x.shape[1] // 3)**.5)
        assert h * w * 3 == x.shape[1]

        x = x.reshape(shape=(x.shape[0], 3, h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('ndhwpqc->ndchpwq', x)
        triplanes = x.reshape(shape=(x.shape[0], unpatchify_out_chans * 3,
                                     h * p, w * p))
        return triplanes

    def interpolate_pos_encoding(self, x, w, h):
        """Return ``vit_decoder.pos_embed`` unchanged.

        No interpolation is performed; ``x``, ``w`` and ``h`` are accepted only
        for interface compatibility.
        """
        return self.vit_decoder.pos_embed

    def forward_vit_decoder(self, x, img_size=None):
        """Add positional embeddings, run all blocks, then the final norm.

        NOTE(review): ``self.img_size`` is not assigned in this class; the
        ``img_size=None`` path relies on it being set elsewhere.
        """
        if img_size is None:
            img_size = self.img_size

        if self.cls_token:
            x = x + self.vit_decoder.interpolate_pos_encoding(
                x, img_size, img_size)[:, :]
        else:
            x = x + self.vit_decoder.interpolate_pos_encoding(
                x, img_size, img_size)[:, 1:]

        for blk in self.vit_decoder.blocks:
            x = blk(x)
        x = self.vit_decoder.norm(x)

        return x

    def unpatchify(self, x, p=None, unpatchify_out_chans=None):
        """Fold patch tokens into a single image-like tensor.

        x: (N, [1 +] h * w, p * p * C) with C = unpatchify_out_chans
        returns: (N, C, h * p, w * p)
        """
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.unpatchify_out_chans
        if self.cls_token:
            x = x[:, 1:]

        if p is None:
            p = self.patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], unpatchify_out_chans, h * p, w * p))
        return imgs

    def forward(self, latent, c, img_size):
        """Run the ViT, the linear head and unpatchify, then decode with ``triplane_decoder``."""
        latent = self.forward_vit_decoder(latent, img_size)

        cls_token = latent[:, :1] if self.cls_token else None

        latent = self.decoder_pred(latent)
        latent = self.unpatchify(latent)

        ret_dict = self.triplane_decoder(planes=latent, c=c)
        ret_dict.update({'latent': latent, 'cls_token': cls_token})

        return ret_dict
|
|
|
|
class VAE_LDM_V4_vit3D_v3_conv3D_depth2_xformer_mha_PEinit_2d_sincos_uvit_RodinRollOutConv_4x4_lite_mlp_unshuffle_4XC_final(
        ViTTriplaneDecomposed):
    """ViT-triplane VAE decoder with a U-ViT (long-skip) backbone.

    Author notes (carried over):
    1. reuse attention proj layer from dino
    2. reuse attention; first self then 3D cross attention
    3. 4*4 SR with 2X channels

    The head predicts 4x4 patches with ``channel_multiplier`` x
    (out_chans // 3) channels per plane. These are folded by
    ``unpatchify_triplane`` and refined by ``superresolution['conv_sr']``.

    NOTE(review): ``vae_reparameterization`` calls ``self.unpatchify3D`` and
    ``self.vae_encode``, which are defined neither here nor in
    ``ViTTriplaneDecomposed`` as shown. Confirm they are provided before
    using that path.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane,
            cls_token,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            channel_multiplier=4,
            fusion_blk=TriplaneFusionBlockv3,
            **kwargs) -> None:
        super().__init__(
            vit_decoder,
            triplane_decoder,
            cls_token,
            fusion_blk=fusion_blk,
            use_fusion_blk=use_fusion_blk,
            fusion_blk_depth=fusion_blk_depth,
            channel_multiplier=channel_multiplier,
            # 4x4 patches, channel_multiplier x (out_chans // 3) channels each.
            decoder_pred_size=4**2 *
            int(triplane_decoder.out_chans // 3 * channel_multiplier),
            **kwargs)

        # The 1x1 quant conv is sized by the triplane decoder's channels, not
        # by the ldm_z_channels / ldm_embed_dim values stored on self.
        ldm_z_channels = ldm_embed_dim = triplane_decoder.out_chans

        self.superresolution.update(
            dict(
                after_vit_conv=nn.Conv2d(
                    int(triplane_decoder.out_chans * 2),
                    triplane_decoder.out_chans * 2,
                    3,
                    padding=1),
                quant_conv=torch.nn.Conv2d(2 * ldm_z_channels,
                                           2 * ldm_embed_dim, 1),
                ldm_downsample=nn.Linear(
                    384,
                    self.vae_p * self.vae_p * 3 * self.ldm_z_channels * 2,
                    bias=True),
                ldm_upsample=nn.Linear(self.vae_p * self.vae_p *
                                       self.ldm_z_channels * 1,
                                       vit_decoder.embed_dim,
                                       bias=True),
                quant_mlp=Mlp(2 * self.ldm_z_channels,
                              out_features=2 * self.ldm_embed_dim),
                conv_sr=RodinConv3D4X_lite_mlp_as_residual(
                    int(triplane_decoder.out_chans * channel_multiplier),
                    int(triplane_decoder.out_chans * 1))))

        has_token = bool(self.cls_token)
        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1, 3 * 16 * 16 + has_token, vit_decoder.embed_dim))

        self.init_weights()
        self.reparameterization_soft_clamp = True

        self.create_uvit_arch()

    def vae_reparameterization(self, latent, sample_posterior):
        """Map ViT-encoder tokens to a sampled (or mode) VAE latent.

        Returns a dict with the posterior, its entropy, the latent as tokens
        ('latent_normalized') and as a 2D map ('latent_normalized_2Ddiffusion'),
        plus the matching log-densities.
        """
        latents3D = self.superresolution['ldm_downsample'](latent)

        if self.vae_p > 1:
            latents3D = self.unpatchify3D(
                latents3D,
                p=self.vae_p,
                unpatchify_out_chans=self.ldm_z_channels * 2)
            latents3D = latents3D.reshape(latents3D.shape[0], 3, -1,
                                          latents3D.shape[-1])
        else:
            latents3D = latents3D.reshape(latents3D.shape[0],
                                          latents3D.shape[1], 3,
                                          2 * self.ldm_z_channels)
            latents3D = latents3D.permute(0, 2, 1, 3)

        posterior = self.vae_encode(latents3D)

        if sample_posterior:
            latent = posterior.sample()
        else:
            latent = posterior.mode()

        log_q = posterior.log_p(latent)

        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)
        log_q_2Ddiffusion = log_q.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)
        latent = latent.permute(0, 2, 3, 1)
        latent = latent.reshape(latent.shape[0], -1, latent.shape[-1])

        ret_dict = dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,
            log_q_2Ddiffusion=log_q_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
            latent_name='latent_normalized')

        return ret_dict

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Apply decoder_pred, unpatchify (p=4) and conv_sr, and attach w_avg.

        Updates and returns ``ret_dict`` with 'cls_token', 'latent_after_vit'
        and 'sr_w_code' (``w_avg`` repeated over the batch).
        """
        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None

        latent = self.decoder_pred(latent_from_vit)

        latent = self.unpatchify_triplane(
            latent,
            p=4,
            unpatchify_out_chans=int(
                self.channel_multiplier * self.unpatchify_out_chans // 3))

        latent = self.superresolution['conv_sr'](latent)

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))

        sr_w_code = self.w_avg
        assert sr_w_code is not None
        ret_dict.update(
            dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave(
                latent_from_vit.shape[0], 0), ))

        return ret_dict

    def forward_vit_decoder(self, x, img_size=None):
        """U-ViT forward pass over tokens grouped as (B, 3, L // 3, C).

        The first half of the blocks (minus the middle one) push outputs onto a
        skip stack. The middle block runs plainly. Each block in the second
        half first fuses the popped skip via ``skip_linear``.
        """
        if img_size is None:
            img_size = self.img_size

        x = x + self.interpolate_pos_encoding(x, img_size, img_size)[:, :]

        B, L, C = x.shape
        x = x.view(B, 3, L // 3, C)

        skips = [x]
        assert self.fusion_blk_start == 0

        # in-blocks
        for blk in self.vit_decoder.blocks[0:len(self.vit_decoder.blocks) // 2 - 1]:
            x = blk(x)
            skips.append(x)

        # mid-block
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2 - 1:
                                           len(self.vit_decoder.blocks) // 2]:
            x = blk(x)

        # out-blocks (consume skips last-in, first-out)
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            x = x + blk.skip_linear(torch.cat([x, skips.pop()], dim=-1))
            x = blk(x)

        x = self.vit_decoder.norm(x)

        x = x.view(B, L, C)
        return x

    def triplane_decode(self,
                        vit_decode_out,
                        c,
                        return_raw_only=False,
                        **kwargs):
        """Render the planes in ``vit_decode_out`` (a dict, or a bare planes tensor)."""
        if isinstance(vit_decode_out, dict):
            latent_after_vit, sr_w_code = (vit_decode_out.get(k, None)
                                           for k in ('latent_after_vit',
                                                     'sr_w_code'))
        else:
            latent_after_vit = vit_decode_out
            sr_w_code = None
            vit_decode_out = dict(latent_normalized=latent_after_vit)

        ret_dict = self.triplane_decoder(latent_after_vit,
                                         c,
                                         ws=sr_w_code,
                                         return_raw_only=return_raw_only,
                                         **kwargs)
        ret_dict.update({
            'latent_after_vit': latent_after_vit,
            **vit_decode_out
        })

        return ret_dict

    def vit_decode_backbone(self, latent, img_size):
        """Turn a VAE latent (dict, 4-D map or 3-D tokens) into ViT features.

        NOTE(review): ``ldm_upsample`` expects vae_p**2 * ldm_z_channels input
        features while the assert below pins the last dim to ldm_z_channels,
        so this path only lines up when vae_p == 1 — confirm.
        """
        if isinstance(latent, dict):
            if 'latent_normalized' not in latent:
                latent = latent['latent_normalized_2Ddiffusion']
            else:
                latent = latent['latent_normalized']

        if latent.ndim != 3:
            latent = latent.reshape(latent.shape[0], latent.shape[1] // 3, 3,
                                    (self.vae_p * self.token_size)**2).permute(
                                        0, 2, 3, 1)
            latent = latent.reshape(latent.shape[0], -1, latent.shape[-1])

        assert latent.shape == (
            latent.shape[0],
            3 * ((self.vae_p * self.token_size)**2),
            self.ldm_z_channels), f'latent.shape: {latent.shape}'

        latent = self.superresolution['ldm_upsample'](latent)

        return super().vit_decode_backbone(latent, img_size)
|
|
|
|
class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn(
        ViTTriplaneDecomposed):
    """Triplane VAE decoder with a grouped 1x1-conv quantiser and lite SR head.

    Data flow implemented by this class:
      1. Encoder tokens are projected by ``ldm_downsample`` and unpatchified
         to per-plane moments.
      2. ``quant_conv`` builds a diagonal Gaussian posterior, which is sampled
         (or its mode taken) to give the 2D latent.
      3. ``ldm_upsample`` (a ``PatchEmbedTriplane``) embeds the latent into
         tokens for the U-ViT backbone.
      4. ``decoder_pred`` and ``unpatchify_triplane`` (p=4) produce planes,
         which ``conv_sr`` refines.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino,
            channel_multiplier=4,
            ldm_z_channels=4,
            ldm_embed_dim=4,
            vae_p=2,
            **kwargs) -> None:
        # 4x4 patches, each with channel_multiplier x (out_chans // 3) channels.
        pred_size = 4**2 * int(
            triplane_decoder.out_chans // 3 * channel_multiplier)
        super().__init__(vit_decoder,
                         triplane_decoder,
                         cls_token,
                         channel_multiplier=channel_multiplier,
                         use_fusion_blk=use_fusion_blk,
                         fusion_blk_depth=fusion_blk_depth,
                         fusion_blk=fusion_blk,
                         ldm_z_channels=ldm_z_channels,
                         ldm_embed_dim=ldm_embed_dim,
                         vae_p=vae_p,
                         decoder_pred_size=pred_size,
                         **kwargs)

        logger.log(
            f'length of vit_decoder.blocks: {len(self.vit_decoder.blocks)}')

        sr_in = int(triplane_decoder.out_chans * channel_multiplier)
        sr_out = int(triplane_decoder.out_chans * 1)
        self.superresolution.update(
            dict(
                ldm_downsample=nn.Linear(
                    384,
                    self.vae_p * self.vae_p * 3 * self.ldm_z_channels * 2,
                    bias=True),
                ldm_upsample=PatchEmbedTriplane(
                    self.vae_p * self.token_size,
                    self.vae_p,
                    3 * self.ldm_embed_dim,
                    vit_decoder.embed_dim,
                    bias=True),
                # groups=3: each plane's moments are quantised independently.
                quant_conv=nn.Conv2d(2 * 3 * self.ldm_z_channels,
                                     2 * self.ldm_embed_dim * 3,
                                     kernel_size=1,
                                     groups=3),
                conv_sr=RodinConv3D4X_lite_mlp_as_residual_lite(sr_in, sr_out)))

        self.init_weights()
        self.reparameterization_soft_clamp = True
        self.create_uvit_arch()

    def vit_decode(self, latent, img_size, sample_posterior=True):
        """Encoder latent -> posterior -> ViT backbone -> post-processed dict."""
        vae_out = self.vae_reparameterization(latent, sample_posterior)
        backbone_out = self.vit_decode_backbone(vae_out, img_size)
        return self.vit_decode_postprocess(backbone_out, vae_out)

    def unpatchify3D(self, x, p, unpatchify_out_chans, plane_n=3):
        """Fold tokens into per-plane 2D latents.

        x: (N, [1 +] h * w, p * p * plane_n * C) with C = unpatchify_out_chans
        returns: (N, plane_n, h * p, w * p, C)
        """
        if self.cls_token:
            x = x[:, 1:]

        side = int(x.shape[1]**.5)
        assert side * side == x.shape[1]

        batch = x.shape[0]
        x = x.reshape(shape=(batch, side, side, p, p, plane_n,
                             unpatchify_out_chans))
        x = torch.einsum('nhwpqdc->ndhpwqc', x)
        return x.reshape(shape=(batch, plane_n, side * p, side * p,
                                unpatchify_out_chans))

    def vae_encode(self, h):
        """Quantise ``h`` (B, C, H, W) into a DiagonalGaussianDistribution.

        The moments are reshaped to (B, C' // plane_n, plane_n, H * W) before
        being handed to the distribution.
        """
        B, _, H, W = h.shape
        moments = self.superresolution['quant_conv'](h)
        moments = moments.reshape(B, moments.shape[1] // self.plane_n,
                                  self.plane_n, H, W).flatten(-2)
        return DiagonalGaussianDistribution(
            moments, soft_clamp=self.reparameterization_soft_clamp)

    def vae_reparameterization(self, latent, sample_posterior):
        """Map ViT-encoder tokens to a sampled (or mode) VAE latent and stats."""
        downsampled = self.superresolution['ldm_downsample'](latent)

        assert self.vae_p > 1
        planes3D = self.unpatchify3D(
            downsampled,
            p=self.vae_p,
            unpatchify_out_chans=self.ldm_z_channels * 2)

        # (B, 3, H, W, C) -> (B, 3 * C, H, W), plane-major channels.
        B, _, H, W, C = planes3D.shape
        planes2D = planes3D.permute(0, 1, 4, 2, 3).reshape(B, -1, H, W)

        posterior = self.vae_encode(planes2D)
        z = posterior.sample() if sample_posterior else posterior.mode()
        log_q = posterior.log_p(z)

        side = self.token_size * self.vae_p
        z_2d = z.reshape(z.shape[0], -1, side, side)
        log_q_2d = log_q.reshape(z.shape[0], -1, side, side)

        z_tokens = z.permute(0, 2, 3, 1)
        z_tokens = z_tokens.reshape(z_tokens.shape[0], -1, z_tokens.shape[-1])

        return dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=z_tokens,
            latent_normalized_2Ddiffusion=z_2d,
            log_q_2Ddiffusion=log_q_2d,
            log_q=log_q,
            posterior=posterior,
        )

    def vit_decode_backbone(self, latent, img_size):
        """Embed the 2D latent with ``ldm_upsample`` and run the U-ViT backbone."""
        if isinstance(latent, dict):
            latent = latent['latent_normalized_2Ddiffusion']
        tokens = self.superresolution['ldm_upsample'](latent)
        return self.forward_vit_decoder(tokens, img_size)

    def triplane_decode(self,
                        vit_decode_out,
                        c,
                        return_raw_only=False,
                        **kwargs):
        """Render the planes in ``vit_decode_out`` (a dict, or a bare planes tensor)."""
        if not isinstance(vit_decode_out, dict):
            latent_after_vit = vit_decode_out
            sr_w_code = None
            vit_decode_out = dict(latent_normalized=latent_after_vit)
        else:
            latent_after_vit = vit_decode_out.get('latent_after_vit', None)
            sr_w_code = vit_decode_out.get('sr_w_code', None)

        rendered = self.triplane_decoder(latent_after_vit,
                                         c,
                                         ws=sr_w_code,
                                         return_raw_only=return_raw_only,
                                         **kwargs)
        rendered.update({
            'latent_after_vit': latent_after_vit,
            **vit_decode_out
        })
        return rendered

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Apply decoder_pred, unpatchify (p=4) and conv_sr, and attach w_avg."""
        cls_token = latent_from_vit[:, :1] if self.cls_token else None

        patches = self.decoder_pred(latent_from_vit)
        chans_per_plane = int(
            self.channel_multiplier * self.unpatchify_out_chans // 3)
        planes = self.unpatchify_triplane(patches,
                                          p=4,
                                          unpatchify_out_chans=chans_per_plane)
        planes = self.superresolution['conv_sr'](planes)

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=planes))

        sr_w_code = self.w_avg
        assert sr_w_code is not None
        batched_w = sr_w_code.reshape(1, 1, -1).repeat_interleave(
            latent_from_vit.shape[0], 0)
        ret_dict.update(dict(sr_w_code=batched_w))

        return ret_dict

    def forward_vit_decoder(self, x, img_size=None):
        """U-ViT forward pass over tokens grouped as (B, 3, L // 3, C)."""
        if img_size is None:
            img_size = self.img_size

        x = x + self.interpolate_pos_encoding(x, img_size, img_size)[:, :]

        B, L, C = x.shape
        x = x.view(B, 3, L // 3, C)

        blocks = self.vit_decoder.blocks
        half = len(blocks) // 2
        skips = [x]
        assert self.fusion_blk_start == 0

        # in-blocks: record each output for the long skips
        for blk in blocks[:half - 1]:
            x = blk(x)
            skips.append(x)

        # mid-block
        for blk in blocks[half - 1:half]:
            x = blk(x)

        # out-blocks: fuse skips last-in, first-out
        for blk in blocks[half:]:
            x = x + blk.skip_linear(torch.cat([x, skips.pop()], dim=-1))
            x = blk(x)

        x = self.vit_decoder.norm(x)
        return x.view(B, L, C)
|
|
|
|
| |
class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn):
    """Variant that builds the posterior directly from its input latent.

    ``ldm_downsample`` is removed, and ``vae_reparameterization`` passes its
    input straight to ``vae_encode``. ``normalize_feat`` and ``sr_ratio``
    are accepted but not used in this class.
    """

    def __init__(self,
                 vit_decoder: VisionTransformer,
                 triplane_decoder: Triplane_fg_bg_plane,
                 cls_token,
                 normalize_feat=True,
                 sr_ratio=2,
                 use_fusion_blk=True,
                 fusion_blk_depth=2,
                 fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino,
                 channel_multiplier=4,
                 **kwargs) -> None:
        super().__init__(vit_decoder,
                         triplane_decoder,
                         cls_token,
                         use_fusion_blk=use_fusion_blk,
                         fusion_blk_depth=fusion_blk_depth,
                         fusion_blk=fusion_blk,
                         channel_multiplier=channel_multiplier,
                         **kwargs)

        # Not used by this class's vae_reparameterization.
        unused_modules = ('ldm_downsample', )
        for name in unused_modules:
            del self.superresolution[name]

    def vae_reparameterization(self, latent, sample_posterior):
        """Encode ``latent`` directly into a posterior; return latent + stats dict."""
        assert self.vae_p > 1

        posterior = self.vae_encode(latent)
        z = posterior.sample() if sample_posterior else posterior.mode()
        log_q = posterior.log_p(z)

        side = self.token_size * self.vae_p
        z_2d = z.reshape(z.shape[0], -1, side, side)
        log_q_2d = log_q.reshape(z.shape[0], -1, side, side)

        z_tokens = z.permute(0, 2, 3, 1)
        z_tokens = z_tokens.reshape(z_tokens.shape[0], -1, z_tokens.shape[-1])

        return dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=z_tokens,
            latent_normalized_2Ddiffusion=z_2d,
            log_q_2Ddiffusion=log_q_2d,
            log_q=log_q,
            posterior=posterior,
        )
|
|
|
|
class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD_D(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD):
    """_SD variant that swaps the linear head for a convolutional ``Decoder``.

    ``decoder_pred`` is disabled. Each plane's token grid is decoded
    independently by ``superresolution['conv_sr']`` (an ldm ``Decoder``),
    and the three outputs are stacked along channels. Unlike the parent, no
    'sr_w_code' entry is added to the returned dict.
    """

    def __init__(self,
                 vit_decoder: VisionTransformer,
                 triplane_decoder: Triplane_fg_bg_plane,
                 cls_token,
                 normalize_feat=True,
                 sr_ratio=2,
                 use_fusion_blk=True,
                 fusion_blk_depth=2,
                 fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino,
                 channel_multiplier=4,
                 **kwargs) -> None:
        super().__init__(vit_decoder,
                         triplane_decoder,
                         cls_token,
                         normalize_feat=normalize_feat,
                         sr_ratio=sr_ratio,
                         use_fusion_blk=use_fusion_blk,
                         fusion_blk_depth=fusion_blk_depth,
                         fusion_blk=fusion_blk,
                         channel_multiplier=channel_multiplier,
                         **kwargs)

        self.decoder_pred = None

        conv_decoder = Decoder(
            resolution=128,
            in_channels=3,
            ch=32,
            ch_mult=[1, 2, 2, 4],
            num_res_blocks=1,
            dropout=0.0,
            attn_resolutions=[],
            out_ch=32,
            z_channels=vit_decoder.embed_dim,
        )
        self.superresolution.update(dict(conv_sr=conv_decoder))

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Decode per-plane token grids with conv_sr; store the stacked planes."""
        cls_token = latent_from_vit[:, :1] if self.cls_token else None

        def to_plane_batch(tokens, p=None):
            # (B, L, C) -> (B * 3, C', H, W): one image per plane.
            B, L, C = tokens.shape
            tokens = tokens.reshape(B, 3, L // 3, C)

            if self.cls_token:
                tokens = tokens[:, :, 1:]

            h = w = int((tokens.shape[2])**.5)
            assert h * w == tokens.shape[2]

            if p is not None:
                tokens = tokens.reshape(shape=(B, 3, h, w, p, p, -1))
                return rearrange(
                    tokens, 'b n h w p1 p2 c->(b n) c (h p1) (w p2)')

            tokens = tokens.reshape(shape=(B, 3, h, w, -1))
            return rearrange(tokens, 'b n h w c->(b n) c h w')

        per_plane = to_plane_batch(latent_from_vit)
        decoded = self.superresolution['conv_sr'](per_plane)
        planes = rearrange(decoded, '(b n) c h w->b (n c) h w', n=3)

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=planes))
        return ret_dict
|
|
| |
|
|
|
|
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_lite3DAttn(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD):
    """Lite 3D-attention variant of the fusion-v5 SD triplane decoder.

    On top of the parent constructor this class:

    * adds ``decoder_pred``, a ``Linear(embed_dim // 3, 2048)`` applied to
      each per-plane slice of the decoder tokens before unpatchifying;
    * installs ``superresolution['ldm_upsample']`` (a
      ``PatchEmbedTriplaneRodin``), used to turn the 2D-diffusion latent back
      into ViT tokens;
    * replaces ``vit_decoder.pos_embed`` with a zero parameter sized for
      ``16 * 16`` tokens (plus one slot when ``cls_token`` is truthy);
    * runs the decoder blocks as a U-shaped stack with long skip connections
      (see ``forward_vit_decoder``).
    """

    def __init__(self,
                 vit_decoder: VisionTransformer,
                 triplane_decoder: Triplane_fg_bg_plane,
                 cls_token,
                 normalize_feat=True,
                 sr_ratio=2,
                 use_fusion_blk=True,
                 fusion_blk_depth=2,
                 fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite,
                 channel_multiplier=4,
                 **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token,
                         normalize_feat, sr_ratio, use_fusion_blk,
                         fusion_blk_depth, fusion_blk, channel_multiplier,
                         **kwargs)

        # per-plane token projection consumed by vit_decode_postprocess
        self.decoder_pred = nn.Linear(self.vit_decoder.embed_dim // 3,
                                      2048,
                                      bias=True)

        # embeds the 2D-diffusion latent as tokens for the ViT decoder
        # (argument meanings follow PatchEmbedTriplaneRodin's signature)
        self.superresolution.update(
            dict(ldm_upsample=PatchEmbedTriplaneRodin(
                self.vae_p * self.token_size,
                self.vae_p,
                3 * self.ldm_embed_dim,
                vit_decoder.embed_dim // 3,
                bias=True)))

        # 16 * 16 token slots, plus one when a cls token is in use
        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1, 16 * 16 + bool(self.cls_token),
                        vit_decoder.embed_dim))

    def forward(self, latent, c, img_size):
        """Decode ``latent`` with the ViT, then render the triplane with ``c``."""
        decoded = self.vit_decode(latent, img_size)
        return self.triplane_decode(decoded, c)

    def vae_reparameterization(self, latent, sample_posterior):
        """Encode ``latent`` into a posterior and reshape the drawn code.

        Args:
            latent: input handed to ``self.vae_encode``.
            sample_posterior: draw a sample when true, otherwise take the mode.

        Returns:
            dict with ``normal_entropy``, ``latent_normalized`` (last axis moved
            to position 1, then the trailing axes flattened),
            ``latent_normalized_2Ddiffusion`` / ``log_q_2Ddiffusion`` (reshaped
            to ``(B, -1, S, S)`` with ``S = token_size * vae_p``), ``log_q`` and
            the ``posterior`` object.
        """
        assert self.vae_p > 1

        posterior = self.vae_encode(latent)
        latent = posterior.sample() if sample_posterior else posterior.mode()
        log_q = posterior.log_p(latent)

        side = self.token_size * self.vae_p
        batch = latent.shape[0]
        latent_normalized_2Ddiffusion = latent.reshape(batch, -1, side, side)
        log_q_2Ddiffusion = log_q.reshape(batch, -1, side, side)

        # move the last axis next to the batch axis, flatten the rest
        tokens = latent.permute(0, 3, 1, 2)
        tokens = tokens.reshape(*tokens.shape[:2], -1)

        return dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=tokens,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,
            log_q_2Ddiffusion=log_q_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
        )

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Project decoder tokens per plane, unpatchify and refine them.

        Args:
            latent_from_vit: decoder output unpacked as ``(B, N, C)``; the
                channel axis is split as ``(C // 3, 3)`` with the plane index
                last.
            ret_dict: updated in place with ``cls_token``,
                ``latent_after_vit`` and ``sr_w_code`` (``self.w_avg``
                reshaped to ``(1, 1, -1)`` and repeated ``B`` times).

        Returns:
            The same ``ret_dict`` object.
        """
        cls_token = latent_from_vit[:, :1] if self.cls_token else None

        B, N, C = latent_from_vit.shape
        # (B, N, C) -> (B, 3, N, C // 3)
        per_plane = latent_from_vit.reshape(B, N, C // 3, 3).permute(0, 3, 1, 2)

        patches = self.decoder_pred(per_plane).reshape(B, 3 * N, -1)
        out_chans = int(self.channel_multiplier * self.unpatchify_out_chans // 3)
        latent = self.unpatchify_triplane(patches,
                                          p=4,
                                          unpatchify_out_chans=out_chans)
        latent = self.superresolution['conv_sr'](latent)

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))

        sr_w_code = self.w_avg
        assert sr_w_code is not None
        ret_dict.update(
            dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave(B, 0), ))
        return ret_dict

    def vit_decode_backbone(self, latent, img_size):
        """Turn the 2D-diffusion latent into tokens and run the ViT decoder.

        ``latent`` may be the dict from ``vae_reparameterization``, in which
        case its ``'latent_normalized_2Ddiffusion'`` entry is used.
        """
        if isinstance(latent, dict):
            latent = latent['latent_normalized_2Ddiffusion']

        tokens = self.superresolution['ldm_upsample'](latent)

        B, N3, C = tokens.shape
        # (B, 3 * L, C) -> (B, L, C, 3) -> (B, L, 3 * C), plane index innermost
        tokens = tokens.reshape(B, 3, N3 // 3, C).permute(0, 2, 3, 1)
        tokens = tokens.reshape(*tokens.shape[:2], -1)

        return self.forward_vit_decoder(tokens, img_size)

    def forward_vit_decoder(self, x, img_size=None):
        """Run the decoder blocks as an encoder / middle / decoder stack.

        The first ``len(blocks) // 2 - 1`` blocks push their outputs (and the
        input) onto a stack. One middle block follows. Each remaining block
        first merges the current state with the most recent stacked
        activation through its ``skip_linear``.
        """
        if img_size is None:
            img_size = self.img_size

        x = x + self.interpolate_pos_encoding(x, img_size, img_size)
        B, L, C = x.shape

        blocks = self.vit_decoder.blocks
        mid = len(blocks) // 2

        skips = [x]
        assert self.fusion_blk_start == 0

        for blk in blocks[:mid - 1]:
            x = blk(x)
            skips.append(x)

        for blk in blocks[mid - 1:mid]:
            x = blk(x)

        for blk in blocks[mid:]:
            x = x + blk.skip_linear(torch.cat([x, skips.pop()], dim=-1))
            x = blk(x)

        x = self.vit_decoder.norm(x)
        return x.view(B, L, C)

    def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk):
        """Regroup the 12 decoder blocks into triplane fusion blocks.

        Blocks before ``self.fusion_blk_start`` are kept as they are. From
        there on, consecutive groups of ``fusion_blk_depth`` blocks are each
        wrapped by ``fusion_blk(group, num_heads // 3, embed_dim // 3,
        use_fusion_blk)``.
        """
        vit_decoder_blks = self.vit_decoder.blocks
        assert len(vit_decoder_blks) == 12, 'ViT-B by default'

        nh = self.vit_decoder.blocks[0].attn.num_heads // 3
        dim = self.vit_decoder.embed_dim // 3

        start = self.fusion_blk_start
        fused = nn.ModuleList()

        for i in range(start):
            fused.append(vit_decoder_blks[i])

        for i in range(start, len(vit_decoder_blks), fusion_blk_depth):
            group = vit_decoder_blks[i:i + fusion_blk_depth]
            fused.append(fusion_blk(group, nh, dim, use_fusion_blk))

        self.vit_decoder.blocks = fused
| |
|
|
|
|
| |
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder_S(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn):
    """Triplane decoder that runs ``vit_decoder`` as a plain backbone.

    As set up in ``__init__``:

    * the parent is always built with ``patch_size=-1`` and ``token_size=2``.
      Passing either again through ``**kwargs`` raises ``TypeError``
      (duplicate keyword);
    * ``normalize_feat`` and ``sr_ratio`` are accepted for signature
      compatibility with sibling classes but are NOT forwarded to the parent;
    * ``superresolution['ldm_downsample']`` is removed, ``decoder_pred`` is
      set to ``None`` and ``superresolution['conv_sr']`` becomes a
      ``Decoder``;
    * ``skip_linear`` is deleted from the second half of the decoder blocks,
      ``create_fusion_blks`` is a no-op and ``forward_vit_decoder`` simply
      calls ``self.vit_decoder``.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            normalize_feat=True,
            sr_ratio=2,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout,
            channel_multiplier=4,
            **kwargs) -> None:
        super().__init__(
            vit_decoder,
            triplane_decoder,
            cls_token,
            use_fusion_blk=use_fusion_blk,
            fusion_blk_depth=fusion_blk_depth,
            fusion_blk=fusion_blk,
            channel_multiplier=channel_multiplier,
            patch_size=-1,
            token_size=2,
            **kwargs)
        # When True, the three planes are laid side by side along the width
        # axis and decoded as one image instead of three separate images.
        self.D_roll_out_input = False

        del self.superresolution['ldm_downsample']

        self.decoder_pred = None
        self.superresolution.update(
            dict(conv_sr=Decoder(
                resolution=128,
                in_channels=3,
                ch=32,
                ch_mult=[1, 2, 2, 4],
                num_res_blocks=1,
                dropout=0.0,
                attn_resolutions=[],
                out_ch=32,
                z_channels=vit_decoder.embed_dim,
            ), ))

        # the backbone is used without long skip connections
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            del blk.skip_linear

    @torch.inference_mode()
    def forward_points(self,
                       planes,
                       points: torch.Tensor,
                       chunk_size: int = 2**16):
        """Evaluate the triplane decoder at arbitrary query points.

        Points are processed in chunks of ``chunk_size`` along dim 1, and the
        CUDA cache is emptied after every chunk. Runs under
        ``torch.inference_mode``, so no autograd graph is recorded.

        Args:
            planes: triplane features. A 4-D tensor is reshaped to
                ``(B, 3, -1, H, W)``; other ranks are passed through as is.
            points: query coordinates indexed as ``(B, P, ...)``
                (presumably ``(B, P, 3)`` -- TODO confirm).
            chunk_size: number of points evaluated per renderer call.

        Returns:
            dict mapping each key produced by the renderer's ``_run_model``
            to the per-chunk outputs concatenated along dim 1.

        Raises:
            ValueError: if ``points`` holds no points (``P == 0``). Without
                this check the final gather would fail with an opaque
                ``IndexError``.
        """
        if planes.ndim == 4:
            planes = planes.reshape(
                len(planes),
                3,
                -1,
                planes.shape[-2],
                planes.shape[-1])

        num_points = points.shape[1]
        if num_points == 0:
            raise ValueError(
                'forward_points() got an empty set of query points')

        outs = []
        for start in trange(0, num_points, chunk_size):
            chunk_points = points[:, start:start + chunk_size]
            chunk_out = self.triplane_decoder.renderer._run_model(
                planes=planes,
                decoder=self.triplane_decoder.decoder,
                sample_coordinates=chunk_points,
                sample_directions=torch.zeros_like(chunk_points),
                options=self.rendering_kwargs,
            )
            outs.append(chunk_out)
            torch.cuda.empty_cache()

        return {
            k: torch.cat([out[k] for out in outs], dim=1)
            for k in outs[0].keys()
        }

    def triplane_decode_grid(self,
                             vit_decode_out,
                             grid_size,
                             aabb: torch.Tensor = None,
                             **kwargs):
        """Sample triplane features on a dense ``grid_size ** 3`` lattice.

        Args:
            vit_decode_out: dict whose ``'latent_after_vit'`` entry holds the
                triplane.
            grid_size: number of samples along each axis.
            aabb: per-item boxes of shape ``(N, 2, 3)``: row 0 is the min
                corner, row 1 the max corner. A single ``(2, 3)`` box (tensor
                or nested sequence) is broadcast over the batch. When ``None``
                the box comes from ``rendering_kwargs``:
                ``sampler_bbox_min``/``sampler_bbox_max`` if present, otherwise
                ``[-box_warp / 2, box_warp / 2]`` on every axis.
            **kwargs: ignored.

        Returns:
            dict mapping each ``forward_points`` key to a tensor of shape
            ``(N, grid_size, grid_size, grid_size, -1)``.
        """
        assert isinstance(vit_decode_out, dict)
        planes = vit_decode_out['latent_after_vit']
        N = planes.shape[0]

        if aabb is None:
            if 'sampler_bbox_min' in self.rendering_kwargs:
                lo = self.rendering_kwargs['sampler_bbox_min']
                hi = self.rendering_kwargs['sampler_bbox_max']
            else:
                lo = -self.rendering_kwargs['box_warp'] / 2
                hi = self.rendering_kwargs['box_warp'] / 2
            aabb = torch.tensor([[lo] * 3, [hi] * 3],
                                device=planes.device,
                                dtype=planes.dtype).unsqueeze(0).repeat(
                                    N, 1, 1)
        else:
            aabb = torch.as_tensor(aabb, device=planes.device)
            if aabb.ndim == 2:
                # one (2, 3) box shared by every item in the batch
                aabb = aabb.unsqueeze(0).expand(N, -1, -1)

        assert N == aabb.shape[0], "Batch size mismatch for planes and aabb"

        grid_points = []
        for i in range(N):
            axes = [
                torch.linspace(aabb[i, 0, d],
                               aabb[i, 1, d],
                               grid_size,
                               device=planes.device) for d in range(3)
            ]
            grid_points.append(
                torch.stack(torch.meshgrid(*axes, indexing='ij'),
                            dim=-1).reshape(-1, 3))
        cube_grid = torch.stack(grid_points, dim=0)

        features = self.forward_points(planes, cube_grid)

        return {
            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
            for k, v in features.items()
        }

    def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk):
        """Leave ``self.vit_decoder.blocks`` untouched (no triplane fusion)."""

    def forward_vit_decoder(self, x, img_size=None):
        """Run the decoder backbone directly; ``img_size`` is ignored."""
        return self.vit_decoder(x)

    def vit_decode_backbone(self, latent, img_size):
        """Embed the 2D-diffusion latent with ``ldm_upsample`` and decode it.

        ``latent`` may be the dict from ``vae_reparameterization``, in which
        case its ``'latent_normalized_2Ddiffusion'`` entry is used.
        """
        if isinstance(latent, dict):
            latent = latent['latent_normalized_2Ddiffusion']

        latent = self.superresolution['ldm_upsample'](latent)
        return self.forward_vit_decoder(latent, img_size)

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Fold decoder tokens into a triplane, refine it with ``conv_sr``.

        The ``(B, L, C)`` tokens are split into three groups of ``L // 3``.
        When ``self.cls_token`` is truthy, the first token of each group is
        dropped. Each group is folded into a square grid, which must be a
        perfect square. With ``D_roll_out_input`` the three planes are
        concatenated along width and decoded as one image; otherwise they are
        decoded as three images. The result is stored as a
        ``(B, 3 * C', H, W)`` tensor.

        Args:
            latent_from_vit: decoder tokens unpacked as ``(B, L, C)``.
            ret_dict: updated in place with ``cls_token`` and
                ``latent_after_vit``.

        Returns:
            The same ``ret_dict`` object.
        """
        cls_token = latent_from_vit[:, :1] if self.cls_token else None

        B, L, C = latent_from_vit.shape
        x = latent_from_vit.reshape(B, 3, L // 3, C)
        if self.cls_token:
            x = x[:, :, 1:]

        h = w = int((x.shape[2])**.5)
        assert h * w == x.shape[2]
        x = x.reshape(B, 3, h, w, -1)

        if self.D_roll_out_input:
            x = rearrange(x, 'b n h w c->b c h (n w)')
            x = self.superresolution['conv_sr'](x)
            latent = rearrange(x, 'b c h (n w)->b (n c) h w', n=3)
        else:
            x = rearrange(x, 'b n h w c->(b n) c h w')
            x = self.superresolution['conv_sr'](x)
            latent = rearrange(x, '(b n) c h w->b (n c) h w', n=3)

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))
        return ret_dict

    def vae_reparameterization(self, latent, sample_posterior):
        """Encode ``latent`` into a posterior and reshape the drawn code.

        Args:
            latent: input handed to ``self.vae_encode``.
            sample_posterior: draw a sample when true, otherwise take the mode.

        Returns:
            dict with ``normal_entropy``, ``latent_normalized`` (dim 1 moved
            last, then reshaped to ``(B, -1, last_dim)``),
            ``latent_normalized_2Ddiffusion`` / ``log_q_2Ddiffusion``
            (reshaped to ``(B, -1, S, S)`` with ``S = token_size * vae_p``),
            ``log_q`` and the ``posterior`` object.
        """
        assert self.vae_p > 1

        posterior = self.vae_encode(latent)
        latent = posterior.sample() if sample_posterior else posterior.mode()
        log_q = posterior.log_p(latent)

        side = self.token_size * self.vae_p
        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, side, side)
        log_q_2Ddiffusion = log_q.reshape(latent.shape[0], -1, side, side)

        latent = latent.permute(0, 2, 3, 1)
        latent = latent.reshape(latent.shape[0], -1, latent.shape[-1])

        return dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,
            log_q_2Ddiffusion=log_q_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
        )
|
|
|
|
| |
|
|
|
|
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD):
    """Fusion-v5 SD decoder defaulting to the ``..._merge_B_3L_C_withrollout`` fusion block.

    The constructor only sets that default for ``fusion_blk``. Every
    argument is forwarded positionally, unchanged, to the parent
    constructor.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            normalize_feat=True,
            sr_ratio=2,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout,
            channel_multiplier=4,
            **kwargs) -> None:
        super().__init__(
            vit_decoder,
            triplane_decoder,
            cls_token,
            normalize_feat,
            sr_ratio,
            use_fusion_blk,
            fusion_blk_depth,
            fusion_blk,
            channel_multiplier,
            **kwargs,
        )
|
|
|
|
| |
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D(
        RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout
):
    """Roll-out decoder whose ``conv_sr`` is a ``Decoder`` instead of a
    token-wise projection.

    ``decoder_pred`` is set to ``None``, and ``vit_decode_postprocess`` folds
    the decoder tokens straight into plane images for ``conv_sr``.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            normalize_feat=True,
            sr_ratio=2,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout,
            channel_multiplier=4,
            **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token,
                         normalize_feat, sr_ratio, use_fusion_blk,
                         fusion_blk_depth, fusion_blk, channel_multiplier,
                         **kwargs)

        self.decoder_pred = None
        self.superresolution.update(
            dict(conv_sr=Decoder(
                resolution=128,
                in_channels=3,
                ch=32,
                ch_mult=[1, 2, 2, 4],
                num_res_blocks=1,
                dropout=0.0,
                attn_resolutions=[],
                out_ch=32,
                z_channels=vit_decoder.embed_dim,
            ), ))
        # When True, the three planes are laid side by side along the width
        # axis and decoded as one image instead of three separate images.
        self.D_roll_out_input = False

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Fold decoder tokens into a triplane, refine it with ``conv_sr``.

        The ``(B, L, C)`` tokens are split into three groups of ``L // 3``.
        When ``self.cls_token`` is truthy, the first token of each group is
        dropped. Each group is folded into a square grid, which must be a
        perfect square. With ``D_roll_out_input`` the three planes are
        concatenated along width and decoded as one image; otherwise they are
        decoded as three images. The result is stored as a
        ``(B, 3 * C', H, W)`` tensor.

        Args:
            latent_from_vit: decoder tokens unpacked as ``(B, L, C)``.
            ret_dict: updated in place with ``cls_token`` and
                ``latent_after_vit``.

        Returns:
            The same ``ret_dict`` object.
        """
        cls_token = latent_from_vit[:, :1] if self.cls_token else None

        B, L, C = latent_from_vit.shape
        x = latent_from_vit.reshape(B, 3, L // 3, C)
        if self.cls_token:
            x = x[:, :, 1:]

        h = w = int((x.shape[2])**.5)
        assert h * w == x.shape[2]
        x = x.reshape(B, 3, h, w, -1)

        if self.D_roll_out_input:
            x = rearrange(x, 'b n h w c->b c h (n w)')
            x = self.superresolution['conv_sr'](x)
            latent = rearrange(x, 'b c h (n w)->b (n c) h w', n=3)
        else:
            x = rearrange(x, 'b n h w c->(b n) c h w')
            x = self.superresolution['conv_sr'](x)
            latent = rearrange(x, '(b n) c h w->b (n c) h w', n=3)

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))
        return ret_dict
|
|
| |
|
|
|
|
class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder(
        RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D
):
    """``withSD_D`` decoder that runs ``vit_decoder`` as a plain backbone.

    The parent is built with ``patch_size=-1``. Passing ``patch_size`` again
    through ``**kwargs`` raises ``TypeError`` (duplicate keyword).
    ``skip_linear`` is deleted from the second half of the decoder blocks,
    ``create_fusion_blks`` is a no-op and ``forward_vit_decoder`` simply
    calls ``self.vit_decoder``. ``vit_decode_backbone``,
    ``vit_decode_postprocess`` and ``vae_reparameterization`` are inherited
    unchanged.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            normalize_feat=True,
            sr_ratio=2,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout,
            channel_multiplier=4,
            **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token,
                         normalize_feat, sr_ratio, use_fusion_blk,
                         fusion_blk_depth, fusion_blk, channel_multiplier,
                         patch_size=-1,
                         **kwargs)

        # the backbone is used without long skip connections
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            del blk.skip_linear

    @torch.inference_mode()
    def forward_points(self,
                       planes,
                       points: torch.Tensor,
                       chunk_size: int = 2**16):
        """Evaluate the triplane decoder at arbitrary query points.

        Points are processed in chunks of ``chunk_size`` along dim 1, and the
        CUDA cache is emptied after every chunk. Runs under
        ``torch.inference_mode``, so no autograd graph is recorded.

        Args:
            planes: triplane features. A 4-D tensor is reshaped to
                ``(B, 3, -1, H, W)``; other ranks are passed through as is.
            points: query coordinates indexed as ``(B, P, ...)``
                (presumably ``(B, P, 3)`` -- TODO confirm).
            chunk_size: number of points evaluated per renderer call.

        Returns:
            dict mapping each key produced by the renderer's ``_run_model``
            to the per-chunk outputs concatenated along dim 1.

        Raises:
            ValueError: if ``points`` holds no points (``P == 0``). Without
                this check the final gather would fail with an opaque
                ``IndexError``.
        """
        if planes.ndim == 4:
            planes = planes.reshape(
                len(planes),
                3,
                -1,
                planes.shape[-2],
                planes.shape[-1])

        num_points = points.shape[1]
        if num_points == 0:
            raise ValueError(
                'forward_points() got an empty set of query points')

        outs = []
        for start in trange(0, num_points, chunk_size):
            chunk_points = points[:, start:start + chunk_size]
            chunk_out = self.triplane_decoder.renderer._run_model(
                planes=planes,
                decoder=self.triplane_decoder.decoder,
                sample_coordinates=chunk_points,
                sample_directions=torch.zeros_like(chunk_points),
                options=self.rendering_kwargs,
            )
            outs.append(chunk_out)
            torch.cuda.empty_cache()

        return {
            k: torch.cat([out[k] for out in outs], dim=1)
            for k in outs[0].keys()
        }

    def triplane_decode_grid(self,
                             vit_decode_out,
                             grid_size,
                             aabb: torch.Tensor = None,
                             **kwargs):
        """Sample triplane features on a dense ``grid_size ** 3`` lattice.

        Args:
            vit_decode_out: dict whose ``'latent_after_vit'`` entry holds the
                triplane.
            grid_size: number of samples along each axis.
            aabb: per-item boxes of shape ``(N, 2, 3)``: row 0 is the min
                corner, row 1 the max corner. A single ``(2, 3)`` box (tensor
                or nested sequence) is broadcast over the batch. When ``None``
                the box comes from ``rendering_kwargs``:
                ``sampler_bbox_min``/``sampler_bbox_max`` if present, otherwise
                ``[-box_warp / 2, box_warp / 2]`` on every axis.
            **kwargs: ignored.

        Returns:
            dict mapping each ``forward_points`` key to a tensor of shape
            ``(N, grid_size, grid_size, grid_size, -1)``.
        """
        assert isinstance(vit_decode_out, dict)
        planes = vit_decode_out['latent_after_vit']
        N = planes.shape[0]

        if aabb is None:
            if 'sampler_bbox_min' in self.rendering_kwargs:
                lo = self.rendering_kwargs['sampler_bbox_min']
                hi = self.rendering_kwargs['sampler_bbox_max']
            else:
                lo = -self.rendering_kwargs['box_warp'] / 2
                hi = self.rendering_kwargs['box_warp'] / 2
            aabb = torch.tensor([[lo] * 3, [hi] * 3],
                                device=planes.device,
                                dtype=planes.dtype).unsqueeze(0).repeat(
                                    N, 1, 1)
        else:
            aabb = torch.as_tensor(aabb, device=planes.device)
            if aabb.ndim == 2:
                # one (2, 3) box shared by every item in the batch
                aabb = aabb.unsqueeze(0).expand(N, -1, -1)

        assert N == aabb.shape[0], "Batch size mismatch for planes and aabb"

        grid_points = []
        for i in range(N):
            axes = [
                torch.linspace(aabb[i, 0, d],
                               aabb[i, 1, d],
                               grid_size,
                               device=planes.device) for d in range(3)
            ]
            grid_points.append(
                torch.stack(torch.meshgrid(*axes, indexing='ij'),
                            dim=-1).reshape(-1, 3))
        cube_grid = torch.stack(grid_points, dim=0)

        features = self.forward_points(planes, cube_grid)

        return {
            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
            for k, v in features.items()
        }

    def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk):
        """Leave ``self.vit_decoder.blocks`` untouched (no triplane fusion)."""

    def forward_vit_decoder(self, x, img_size=None):
        """Run the decoder backbone directly; ``img_size`` is ignored."""
        return self.vit_decoder(x)
|
|