| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Mostly copy-paste from timm library. |
| https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py |
| """ |
from copy import deepcopy
import math
from functools import partial

import torch
import torch.nn as nn
from torch import Tensor

from .utils import trunc_normal_
|
|
| try: |
| from xformers.ops import memory_efficient_attention, unbind, fmha |
| from xformers.ops import MemoryEfficientAttentionFlashAttentionOp |
|
|
| XFORMERS_AVAILABLE = True |
| except ImportError: |
| |
| XFORMERS_AVAILABLE = False |
|
|
|
|
| class Attention(nn.Module): |
|
|
| def __init__(self, |
| dim, |
| num_heads=8, |
| qkv_bias=False, |
| qk_scale=None, |
| attn_drop=0., |
| proj_drop=0.): |
| super().__init__() |
| self.num_heads = num_heads |
| head_dim = dim // num_heads |
| self.scale = qk_scale or head_dim**-0.5 |
|
|
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
| self.attn_drop = nn.Dropout(attn_drop) |
| self.proj = nn.Linear(dim, dim) |
| self.proj_drop = nn.Dropout(proj_drop) |
|
|
    def forward(self, x, return_attn=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        if return_attn:
            # expose the attention map so Block(..., return_attention=True) works
            return x, attn
        return x
|
|
|
|
| class MemEffAttention(Attention): |
|
|
| def forward(self, x: Tensor, attn_bias=None) -> Tensor: |
| if not XFORMERS_AVAILABLE: |
| assert attn_bias is None, "xFormers is required for nested tensors usage" |
| return super().forward(x) |
|
|
| B, N, C = x.shape |
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) |
|
|
| q, k, v = unbind(qkv, 2) |
|
|
| x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) |
| |
| x = x.reshape([B, N, C]) |
|
|
| x = self.proj(x) |
| x = self.proj_drop(x) |
| return x |
|
|
|
|
| |
| class CrossAttention(nn.Module): |
|
|
| def __init__(self, |
| dim, |
| num_heads=8, |
| qkv_bias=False, |
| qk_scale=None, |
| attn_drop=0., |
| proj_drop=0.): |
| super().__init__() |
| self.num_heads = num_heads |
| head_dim = dim // num_heads |
| |
| self.scale = qk_scale or head_dim**-0.5 |
|
|
| self.wq = nn.Linear(dim, dim, bias=qkv_bias) |
| self.wk = nn.Linear(dim, dim, bias=qkv_bias) |
| self.wv = nn.Linear(dim, dim, bias=qkv_bias) |
| self.attn_drop = nn.Dropout(attn_drop) |
| self.proj = nn.Linear(dim, dim) |
| self.proj_drop = nn.Dropout(proj_drop) |
|
|
| def forward(self, x): |
|
|
| B, N, C = x.shape |
| q = self.wq(x[:, |
| 0:1, ...]).reshape(B, 1, self.num_heads, |
| C // self.num_heads).permute( |
| 0, 2, 1, |
| 3) |
| k = self.wk(x).reshape(B, N, |
| self.num_heads, C // self.num_heads).permute( |
| 0, 2, 1, 3) |
| v = self.wv(x).reshape(B, N, |
| self.num_heads, C // self.num_heads).permute( |
| 0, 2, 1, 3) |
|
|
| attn = (q @ k.transpose( |
| -2, -1)) * self.scale |
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
|
|
| x = (attn @ v).transpose(1, 2).reshape( |
| B, 1, C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
| return x |
|
|
|
|
| class Conv3D_Aware_CrossAttention(nn.Module): |
|
|
| def __init__(self, |
| dim, |
| num_heads=8, |
| qkv_bias=False, |
| qk_scale=None, |
| attn_drop=0., |
| proj_drop=0.): |
| super().__init__() |
| self.num_heads = num_heads |
| head_dim = dim // num_heads |
| |
| self.scale = qk_scale or head_dim**-0.5 |
|
|
| self.wq = nn.Linear(dim, dim, bias=qkv_bias) |
| self.wk = nn.Linear(dim, dim, bias=qkv_bias) |
| self.wv = nn.Linear(dim, dim, bias=qkv_bias) |
| self.attn_drop = nn.Dropout(attn_drop) |
| self.proj = nn.Linear(dim, dim) |
| self.proj_drop = nn.Dropout(proj_drop) |
|
|
| def forward(self, x): |
|
|
| B, group_size, N, C = x.shape |
| p = int(N**0.5) |
| assert p**2 == N, 'check input dim, no [cls] needed here' |
| assert group_size == 3, 'designed for triplane here' |
|
|
| x = x.reshape(B, group_size, p, p, C) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| q_x = torch.empty( |
| B * group_size * N, |
| 1, |
| |
| |
| C, |
| device=x.device) |
| k_x = torch.empty( |
| B * group_size * N, |
| 2 * p, |
| |
| |
| C, |
| device=x.device) |
| v_x = torch.empty_like(k_x) |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| index_i, index_j = torch.meshgrid(torch.arange(0, p), |
| torch.arange(0, p), |
| indexing='ij') |
| index_mesh_grid = torch.stack([index_i, index_j], 0).to( |
| x.device).unsqueeze(0).repeat_interleave(B, |
| 0).reshape(B, 2, p, |
| p) |
|
|
| for i in range(group_size): |
| q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( |
| 0, 2, 3, 1, 4).reshape(B * N, 1, C) |
|
|
| |
| plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + |
| 1] |
| plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] |
|
|
| assert plane_yz.shape == plane_zx.shape == ( |
| B, 1, p, p, C), 'check sub plane dimensions' |
|
|
| pooling_plane_yz = torch.gather( |
| plane_yz, |
| dim=2, |
| index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand( |
| -1, -1, -1, p, |
| C)).permute(0, 2, 1, 3, 4) |
| pooling_plane_zx = torch.gather( |
| plane_zx, |
| dim=3, |
| index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand( |
| -1, -1, p, -1, |
| C)).permute(0, 3, 1, 2, 4) |
|
|
| k_x[B * i * N:B * (i + 1) * |
| N] = v_x[B * i * N:B * (i + 1) * N] = torch.cat( |
| [pooling_plane_yz, pooling_plane_zx], |
| dim=2).reshape(B * N, 2 * p, |
| C) |
|
|
| |
| |
| |
|
|
| q = self.wq(q_x).reshape(B * group_size * N, 1, |
| self.num_heads, C // self.num_heads).permute( |
| 0, 2, 1, |
| 3) |
| k = self.wk(k_x).reshape(B * group_size * N, 2 * p, self.num_heads, |
| C // self.num_heads).permute(0, 2, 1, 3) |
| v = self.wv(v_x).reshape(B * group_size * N, 2 * p, self.num_heads, |
| C // self.num_heads).permute(0, 2, 1, 3) |
|
|
| attn = (q @ k.transpose( |
| -2, -1)) * self.scale |
| attn = attn.softmax(dim=-1) |
| attn = self.attn_drop(attn) |
|
|
| x = (attn @ v).transpose(1, 2).reshape( |
| B * 3 * N, 1, |
| C) |
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| |
| x = x.reshape(B, 3, N, C) |
|
|
| return x |
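
# Usage sketch (illustrative shapes, assumed): each token of one plane attends to the
# 2*p tokens gathered from the matching row/column of the other two planes, and the
# (B, 3, N, C) layout is preserved, e.g.
#   attn_3d = Conv3D_Aware_CrossAttention(dim=96, num_heads=4)
#   out = attn_3d(torch.randn(2, 3, 16 * 16, 96))  # -> (2, 3, 256, 96)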
|
|
|
|
| class xformer_Conv3D_Aware_CrossAttention(nn.Module): |
| |
| def __init__(self, |
| dim, |
| num_heads=8, |
| qkv_bias=False, |
| qk_scale=None, |
| attn_drop=0., |
| proj_drop=0.): |
| super().__init__() |
|
|
| |
|
|
| self.num_heads = num_heads |
| self.wq = nn.Linear(dim, dim * 1, bias=qkv_bias) |
| self.w_kv = nn.Linear(dim, dim * 2, bias=qkv_bias) |
| self.attn_drop = nn.Dropout(attn_drop) |
|
|
| self.proj = nn.Linear(dim, dim) |
| self.proj_drop = nn.Dropout(proj_drop) |
|
|
| self.index_mesh_grid = None |
|
|
| def forward(self, x, attn_bias=None): |
|
|
| B, group_size, N, C = x.shape |
| p = int(N**0.5) |
| assert p**2 == N, 'check input dim, no [cls] needed here' |
| assert group_size == 3, 'designed for triplane here' |
|
|
| x = x.reshape(B, group_size, p, p, C) |
|
|
| q_x = torch.empty(B * group_size * N, 1, C, device=x.device) |
| context = torch.empty(B * group_size * N, 2 * p, C, |
| device=x.device) |
|
|
| if self.index_mesh_grid is None: |
| index_i, index_j = torch.meshgrid(torch.arange(0, p), |
| torch.arange(0, p), |
| indexing='ij') |
| index_mesh_grid = torch.stack([index_i, index_j], 0).to( |
| x.device).unsqueeze(0).repeat_interleave(B, 0).reshape( |
| B, 2, p, p) |
| self.index_mesh_grid = index_mesh_grid[0:1] |
| else: |
| index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave( |
| B, 0) |
| assert index_mesh_grid.shape == ( |
| B, 2, p, p), 'check index_mesh_grid dimension' |
|
|
| for i in range(group_size): |
| q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( |
| 0, 2, 3, 1, 4).reshape(B * N, 1, C) |
|
|
| |
| plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + |
| 1] |
| plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] |
|
|
| assert plane_yz.shape == plane_zx.shape == ( |
| B, 1, p, p, C), 'check sub plane dimensions' |
|
|
| pooling_plane_yz = torch.gather( |
| plane_yz, |
| dim=2, |
| index=index_mesh_grid[:, 0:1].reshape(B, 1, N, 1, 1).expand( |
| -1, -1, -1, p, |
| C)).permute(0, 2, 1, 3, 4) |
| pooling_plane_zx = torch.gather( |
| plane_zx, |
| dim=3, |
| index=index_mesh_grid[:, 1:2].reshape(B, 1, 1, N, 1).expand( |
| -1, -1, p, -1, |
| C)).permute(0, 3, 1, 2, 4) |
|
|
| context[B * i * N:B * (i + 1) * N] = torch.cat( |
| [pooling_plane_yz, pooling_plane_zx], |
| dim=2).reshape(B * N, 2 * p, |
| C) |
|
|
| |
|
|
| q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads, |
| C // self.num_heads) |
|
|
| kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2, |
| self.num_heads, C // self.num_heads) |
| k, v = unbind(kv, 2) |
|
|
| x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) |
| |
| x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C) |
|
|
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| return x |
|
|
|
|
| class xformer_Conv3D_Aware_CrossAttention_xygrid( |
| xformer_Conv3D_Aware_CrossAttention): |
    """Implementation-wise clearer, but yields results identical to
    xformer_Conv3D_Aware_CrossAttention.
    """
|
|
| def __init__(self, |
| dim, |
| num_heads=8, |
| qkv_bias=False, |
| qk_scale=None, |
| attn_drop=0.0, |
| proj_drop=0.0): |
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, |
| proj_drop) |
|
|
| def forward(self, x, attn_bias=None): |
|
|
| B, group_size, N, C = x.shape |
| p = int(N**0.5) |
| assert p**2 == N, 'check input dim, no [cls] needed here' |
| assert group_size == 3, 'designed for triplane here' |
|
|
| x = x.reshape(B, group_size, p, p, C) |
|
|
| q_x = torch.empty(B * group_size * N, 1, C, device=x.device) |
| context = torch.empty(B * group_size * N, 2 * p, C, |
| device=x.device) |
|
|
| if self.index_mesh_grid is None: |
| index_u, index_v = torch.meshgrid( |
| torch.arange(0, p), torch.arange(0, p), |
| indexing='xy') |
| index_mesh_grid = torch.stack([index_u, index_v], 0).to( |
| x.device).unsqueeze(0).repeat_interleave(B, 0).reshape( |
| B, 2, p, p) |
| self.index_mesh_grid = index_mesh_grid[0:1] |
| else: |
| index_mesh_grid = self.index_mesh_grid.clone().repeat_interleave( |
| B, 0) |
| assert index_mesh_grid.shape == ( |
| B, 2, p, p), 'check index_mesh_grid dimension' |
|
|
| for i in range(group_size): |
| q_x[B * i * N:B * (i + 1) * N] = x[:, i:i + 1].permute( |
| 0, 2, 3, 1, 4).reshape(B * N, 1, C) |
|
|
| |
| plane_yz = x[:, (i + 1) % group_size:(i + 1) % group_size + |
| 1] |
| plane_zx = x[:, (i + 2) % group_size:(i + 2) % group_size + 1] |
|
|
| assert plane_yz.shape == plane_zx.shape == ( |
| B, 1, p, p, C), 'check sub plane dimensions' |
|
|
| pooling_plane_yz = torch.gather( |
| plane_yz, |
| dim=2, |
| index=index_mesh_grid[:, 1:2].reshape(B, 1, N, 1, 1).expand( |
| -1, -1, -1, p, |
| C)).permute(0, 2, 1, 3, 4) |
| pooling_plane_zx = torch.gather( |
| plane_zx, |
| dim=3, |
| index=index_mesh_grid[:, 0:1].reshape(B, 1, 1, N, 1).expand( |
| -1, -1, p, -1, |
| C)).permute(0, 3, 1, 2, 4) |
|
|
| context[B * i * N:B * (i + 1) * N] = torch.cat( |
| [pooling_plane_yz, pooling_plane_zx], |
| dim=2).reshape(B * N, 2 * p, |
| C) |
|
|
| |
| q = self.wq(q_x).reshape(B * group_size * N, 1, self.num_heads, |
| C // self.num_heads) |
|
|
| kv = self.w_kv(context).reshape(B * group_size * N, 2 * p, 2, |
| self.num_heads, C // self.num_heads) |
| k, v = unbind(kv, 2) |
|
|
| x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) |
| |
| x = x.transpose(1, 2).reshape([B * 3 * N, 1, C]).reshape(B, 3, N, C) |
|
|
| x = self.proj(x) |
| x = self.proj_drop(x) |
|
|
| return x |
|
|
|
|
| class xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( |
| xformer_Conv3D_Aware_CrossAttention_xygrid): |
|
|
| def __init__(self, |
| dim, |
| num_heads=8, |
| qkv_bias=False, |
| qk_scale=None, |
| attn_drop=0, |
| proj_drop=0): |
| super().__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, |
| proj_drop) |
|
|
| def forward(self, x, attn_bias=None): |
| |
| B, N, C = x.shape |
| x = x.reshape(B, N, C // 3, 3).permute(0, 3, 1, |
| 2) |
| x_out = super().forward(x, attn_bias) |
| x_out = x_out.permute(0, 2, 3, 1) |
| x_out = x_out.reshape(*x_out.shape[:2], -1) |
| return x_out.contiguous() |
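
# Usage sketch (illustrative, assumes xformers is installed): tokens of shape (B, N, C)
# are interpreted as three C//3-dim plane features stacked along the channel axis,
# fused with the xy-grid 3D-aware cross attention, and returned in the same layout, e.g.
#   attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC(dim=32, num_heads=4)
#   out = attn_3d(torch.randn(2, 16 * 16, 96))  # -> (2, 256, 96)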
|
|
| class self_cross_attn(nn.Module): |
| def __init__(self, dino_attn, cross_attn, *args, **kwargs) -> None: |
| super().__init__(*args, **kwargs) |
| self.dino_attn = dino_attn |
| self.cross_attn = cross_attn |
| |
| def forward(self, x_norm): |
| y = self.dino_attn(x_norm) + x_norm |
| return self.cross_attn(y) |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
|
|
| class RodinRollOutConv3D(nn.Module): |
    """Rodin-style roll-out 3D-aware conv: the three planes are laid out side by side
    along the width axis, each augmented channel-wise with axis-pooled features from
    the other two planes, then processed by a single shared 2D conv.
    """
|
|
| def __init__(self, in_chans, out_chans=None): |
| super().__init__() |
| if out_chans is None: |
| out_chans = in_chans |
|
|
| self.out_chans = out_chans // 3 |
|
|
| self.roll_out_convs = nn.Conv2d(in_chans, |
| self.out_chans, |
| kernel_size=3, |
| padding=1) |
|
|
| def forward(self, x): |
| |
|
|
| B, C3, p, p = x.shape |
| C = C3 // 3 |
| group_size = C3 // C |
| assert group_size == 3 |
|
|
| x = x.reshape(B, 3, C, p, p) |
|
|
| roll_out_x = torch.empty(B, group_size * C, p, 3 * p, |
| device=x.device) |
|
|
| for i in range(group_size): |
| plane_xy = x[:, i] |
|
|
| |
| plane_yz_pooling = x[:, (i + 1) % group_size].mean( |
| dim=-1, keepdim=True).repeat_interleave( |
| p, dim=-1) |
| plane_zx_pooling = x[:, (i + 2) % group_size].mean( |
| dim=-2, keepdim=True).repeat_interleave( |
| p, dim=-2) |
|
|
| roll_out_x[..., i * p:(i + 1) * p] = torch.cat( |
| [plane_xy, plane_yz_pooling, plane_zx_pooling], |
| 1) |
|
|
| x = self.roll_out_convs(roll_out_x) |
|
|
| x = x.reshape(B, self.out_chans, p, 3, p) |
| x = x.permute(0, 3, 1, 2, 4).reshape(B, 3 * self.out_chans, p, |
| p) |
|
|
| return x |
|
|
|
|
| class RodinRollOutConv3D_GroupConv(nn.Module): |
    """Grouped-conv variant of the Rodin roll-out 3D-aware conv: each plane is
    concatenated channel-wise with axis-pooled features from the other two planes
    and processed by a groups=3 2D conv, keeping the three planes separate.
    """
|
|
| def __init__(self, |
| in_chans, |
| out_chans=None, |
| kernel_size=3, |
| stride=1, |
| padding=1): |
| super().__init__() |
| if out_chans is None: |
| out_chans = in_chans |
|
|
| self.roll_out_convs = nn.Conv2d( |
| in_chans * 3, |
| out_chans, |
| kernel_size=kernel_size, |
| groups=3, |
| stride=stride, |
| padding=padding) |
|
|
| |
| def forward(self, x): |
| |
|
|
| B, C3, p, p = x.shape |
| C = C3 // 3 |
| group_size = C3 // C |
| assert group_size == 3 |
|
|
| x = x.reshape(B, 3, C, p, p) |
|
|
| roll_out_x = torch.empty(B, group_size * C * 3, p, p, |
| device=x.device) |
|
|
| for i in range(group_size): |
| plane_xy = x[:, i] |
|
|
| |
| plane_yz_pooling = x[:, (i + 1) % group_size].mean( |
| dim=-1, keepdim=True).repeat_interleave( |
| p, dim=-1) |
| plane_zx_pooling = x[:, (i + 2) % group_size].mean( |
| dim=-2, keepdim=True).repeat_interleave( |
| p, dim=-2) |
|
|
| roll_out_x[:, i * 3 * C:(i + 1) * 3 * C] = torch.cat( |
| [plane_xy, plane_yz_pooling, plane_zx_pooling], |
| 1) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| x = self.roll_out_convs(roll_out_x) |
|
|
| return x |
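
# Usage sketch (illustrative shapes, assumed): the input is a stacked triplane map of
# shape (B, 3*C, p, p); each plane is concatenated channel-wise with axis-pooled copies
# of the other two planes and convolved with groups=3, e.g.
#   conv3d = RodinRollOutConv3D_GroupConv(in_chans=3 * 32, out_chans=3 * 32)
#   out = conv3d(torch.randn(2, 3 * 32, 64, 64))  # -> (2, 96, 64, 64)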
|
|
|
|
| class RodinRollOut_GroupConv_noConv3D(nn.Module): |
    """Roll-out only: a groups=3 2D conv applied to each plane independently,
    without the cross-plane 3D-aware pooling.
    """
|
|
| def __init__(self, |
| in_chans, |
| out_chans=None, |
| kernel_size=3, |
| stride=1, |
| padding=1): |
| super().__init__() |
| if out_chans is None: |
| out_chans = in_chans |
|
|
| self.roll_out_inplane_conv = nn.Conv2d( |
| in_chans, |
| out_chans, |
| kernel_size=kernel_size, |
| groups=3, |
| stride=stride, |
| padding=padding) |
|
|
| def forward(self, x): |
| x = self.roll_out_inplane_conv(x) |
| return x |
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
|
|
| class RodinConv3D_SynthesisLayer_mlp_unshuffle_as_residual(nn.Module): |
|
|
| def __init__(self, in_chans, out_chans) -> None: |
| super().__init__() |
|
|
| self.act = nn.LeakyReLU(inplace=True) |
| self.conv = nn.Sequential( |
| RodinRollOutConv3D_GroupConv(in_chans, out_chans), |
| nn.LeakyReLU(inplace=True), |
| ) |
|
|
| self.out_chans = out_chans |
| if in_chans != out_chans: |
| |
| self.short_cut = nn.Linear( |
| in_chans // 3, |
| out_chans // 3 * 4 * 4, |
| bias=True) |
|
|
| |
| else: |
| self.short_cut = None |
|
|
    def shortcut_unpatchify_triplane(self, x, p=4, unpatchify_out_chans=None):
        """Residual branch: per-plane linear projection followed by a p x p (default 4x4)
        pixel unshuffle. x shape: (B, C3, h, w), with C3 channels shared by the 3 planes.
        """

        assert self.short_cut is not None

        B, C3, h, w = x.shape
        assert h == w
        L = h * w
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.out_chans // 3

        x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, 1)  # B 3 L C

        x = self.short_cut(x)  # B 3 L (p*p*unpatchify_out_chans)

        x = x.reshape(shape=(B, 3, h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('ndhwpqc->ndchpwq', x)
        x = x.reshape(shape=(B, 3 * unpatchify_out_chans, h * p, w * p))
        return x
|
|
| def forward(self, feats): |
|
|
| if self.short_cut is not None: |
| res_feats = self.shortcut_unpatchify_triplane(feats) |
| else: |
| res_feats = feats |
|
|
| |
|
|
| feats = res_feats + self.conv(feats) |
| return self.act(feats) |
|
|
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
|
|
|
|
| |
| class RodinConv3D_SynthesisLayer(nn.Module): |
|
|
| def __init__(self, in_chans, out_chans) -> None: |
| super().__init__() |
| |
| |
|
|
| self.act = nn.LeakyReLU(inplace=True) |
| self.conv = nn.Sequential( |
| RodinRollOutConv3D_GroupConv(in_chans, out_chans), |
| nn.LeakyReLU(inplace=True), |
| ) |
|
|
| if in_chans != out_chans: |
| self.short_cut = RodinRollOutConv3D_GroupConv(in_chans, out_chans) |
| else: |
| self.short_cut = None |
|
|
| def forward(self, feats): |
| feats_out = self.conv(feats) |
| if self.short_cut is not None: |
| |
| feats_out = self.short_cut( |
| feats |
| ) + feats_out |
| |
| else: |
| feats_out = feats_out + feats |
| return feats_out |
|
|
|
|
| class RodinRollOutConv3DSR2X(nn.Module): |
|
|
| def __init__(self, in_chans, **kwargs) -> None: |
| super().__init__() |
| self.conv3D = RodinRollOutConv3D_GroupConv(in_chans) |
| |
| self.act = nn.LeakyReLU(inplace=True) |
| self.input_resolution = 224 |
|
|
| def forward(self, x): |
| |
| B, C3, p, p = x.shape |
| C = C3 // 3 |
| group_size = C3 // C |
|
|
| assert group_size == 3 |
| |
| |
| assert group_size == 3, 'designed for triplane here' |
|
|
| x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
| p) |
|
|
| if x.shape[-1] != self.input_resolution: |
| x = torch.nn.functional.interpolate(x, |
| size=(self.input_resolution, |
| self.input_resolution), |
| mode='bilinear', |
| align_corners=False, |
| antialias=True) |
|
|
| x = x + self.conv3D(x) |
|
|
| return x |
|
|
|
|
| class RodinRollOutConv3DSR4X_lite(nn.Module): |
|
|
    def __init__(self, in_chans, input_resolution=256, **kwargs) -> None:
        super().__init__()
        self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans)
        self.conv3D_1 = RodinRollOutConv3D_GroupConv(in_chans)

        self.act = nn.LeakyReLU(inplace=True)
        self.input_resolution = input_resolution
|
|
| def forward(self, x): |
| |
| B, C3, p, p = x.shape |
| C = C3 // 3 |
| group_size = C3 // C |
|
|
| assert group_size == 3 |
| |
| |
| assert group_size == 3, 'designed for triplane here' |
|
|
| x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
| p) |
|
|
| if x.shape[-1] != self.input_resolution: |
| x = torch.nn.functional.interpolate(x, |
| size=(self.input_resolution, |
| self.input_resolution), |
| mode='bilinear', |
| align_corners=False, |
| antialias=True) |
|
|
| |
| |
| |
|
|
| x = x + self.act(self.conv3D_0(x)) |
| x = x + self.act(self.conv3D_1(x)) |
|
|
| |
|
|
| return x |
|
|
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
|
|
|
|
| class RodinConv3D4X_lite_mlp_as_residual(nn.Module): |
    """Lite 4X version, with an MLP shortcut to change the channel dimension.
    """
|
|
| def __init__(self, |
| in_chans, |
| out_chans, |
| input_resolution=256, |
| interp_mode='bilinear', |
| bcg_triplane=False) -> None: |
| super().__init__() |
|
|
| self.interp_mode = interp_mode |
|
|
| self.act = nn.LeakyReLU(inplace=True) |
|
|
| self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, out_chans) |
| self.conv3D_1 = RodinRollOutConv3D_GroupConv(out_chans, out_chans) |
| self.bcg_triplane = bcg_triplane |
| if bcg_triplane: |
| self.conv3D_1_bg = RodinRollOutConv3D_GroupConv( |
| out_chans, out_chans) |
|
|
| self.act = nn.LeakyReLU(inplace=True) |
| self.input_resolution = input_resolution |
|
|
| self.out_chans = out_chans |
| if in_chans != out_chans: |
| self.short_cut = nn.Linear( |
| in_chans // 3, |
| out_chans // 3, |
| bias=True) |
| else: |
| self.short_cut = None |
|
|
| def shortcut_unpatchify_triplane(self, x, p=None): |
| """separate triplane version; x shape: B (3*257) 768 |
| """ |
|
|
| assert self.short_cut is not None |
|
|
| B, C3, h, w = x.shape |
| assert h == w |
| L = h * w |
| x = x.reshape(B, C3 // 3, 3, L).permute(0, 2, 3, |
| 1) |
|
|
| x = self.short_cut(x) |
|
|
| x = x.permute(0, 1, 3, 2) |
| x = x.reshape(shape=(B, self.out_chans, h, w)) |
|
|
| |
| if w != self.input_resolution: |
| x = torch.nn.functional.interpolate( |
| x, |
| size=(self.input_resolution, self.input_resolution), |
| mode='bilinear', |
| align_corners=False, |
| antialias=True) |
|
|
| return x |
|
|
| def interpolate(self, feats): |
| if self.interp_mode == 'bilinear': |
| return torch.nn.functional.interpolate( |
| feats, |
| size=(self.input_resolution, self.input_resolution), |
| mode='bilinear', |
| align_corners=False, |
| antialias=True) |
| else: |
| return torch.nn.functional.interpolate( |
| feats, |
| size=(self.input_resolution, self.input_resolution), |
| mode='nearest', |
| ) |
|
|
| def forward(self, x): |
|
|
| |
| B, C3, p, p = x.shape |
| C = C3 // 3 |
|
|
| if self.short_cut is not None: |
| res_feats = self.shortcut_unpatchify_triplane(x) |
| else: |
| res_feats = x |
| if res_feats.shape[-1] != self.input_resolution: |
| res_feats = self.interpolate(res_feats) |
| """following forward code copied from lite4x version |
| """ |
| x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
| p) |
|
|
| if x.shape[-1] != self.input_resolution: |
| x = self.interpolate(x) |
|
|
| x0 = res_feats + self.act(self.conv3D_0(x)) |
| x = x0 + self.act(self.conv3D_1(x0)) |
| if self.bcg_triplane: |
| x_bcg = x0 + self.act(self.conv3D_1_bg(x0)) |
| return torch.cat([x, x_bcg], 1) |
| else: |
| return x |
|
|
|
|
| class RodinConv3D4X_lite_mlp_as_residual_litev2( |
| RodinConv3D4X_lite_mlp_as_residual): |
|
|
| def __init__(self, |
| in_chans, |
| out_chans, |
| num_feat=128, |
| input_resolution=256, |
| interp_mode='bilinear', |
| bcg_triplane=False) -> None: |
| super().__init__(in_chans, out_chans, input_resolution, interp_mode, |
| bcg_triplane) |
|
|
| self.conv3D_0 = RodinRollOutConv3D_GroupConv(in_chans, in_chans) |
| self.conv_before_upsample = RodinRollOut_GroupConv_noConv3D( |
| in_chans, num_feat * 3) |
| self.conv3D_1 = RodinRollOut_GroupConv_noConv3D( |
| num_feat * 3, num_feat * 3) |
| self.conv_last = RodinRollOut_GroupConv_noConv3D( |
| num_feat * 3, out_chans) |
| self.short_cut = None |
|
|
| def forward(self, x): |
|
|
| |
| B, C3, p, p = x.shape |
| C = C3 // 3 |
|
|
| |
| |
| |
| |
| |
| |
| """following forward code copied from lite4x version |
| """ |
| x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
| p) |
|
|
| x = x + self.conv3D_0(x) |
| x = self.act(self.conv_before_upsample(x)) |
|
|
| |
| x = self.conv_last(self.act(self.conv3D_1(self.interpolate(x)))) |
|
|
| return x |
|
|
|
|
| class RodinConv3D4X_lite_mlp_as_residual_lite( |
| RodinConv3D4X_lite_mlp_as_residual): |
|
|
| def __init__(self, |
| in_chans, |
| out_chans, |
| input_resolution=256, |
| interp_mode='bilinear') -> None: |
| super().__init__(in_chans, out_chans, input_resolution, interp_mode) |
| """replace the first Rodin Conv 3D with ordinary rollout conv to save memory |
| """ |
| self.conv3D_0 = RodinRollOut_GroupConv_noConv3D(in_chans, out_chans) |
|
|
|
|
| class SR3D(nn.Module): |
| |
| |
|
|
| def __init__(self, *args, **kwargs) -> None: |
| super().__init__(*args, **kwargs) |
|
|
|
|
| class RodinConv3D4X_lite_mlp_as_residual_improved(nn.Module): |
|
|
| def __init__(self, |
| in_chans, |
| num_feat, |
| out_chans, |
| input_resolution=256) -> None: |
| super().__init__() |
|
|
| assert in_chans == 4 * out_chans |
| assert num_feat == 2 * out_chans |
| self.input_resolution = input_resolution |
|
|
| |
| self.upscale = 4 |
|
|
| self.conv_after_body = RodinRollOutConv3D_GroupConv( |
| in_chans, in_chans, 3, 1, 1) |
| self.conv_before_upsample = nn.Sequential( |
| RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1), |
| nn.LeakyReLU(inplace=True)) |
| self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, |
| 1) |
| if self.upscale == 4: |
| self.conv_up2 = RodinRollOutConv3D_GroupConv( |
| num_feat, num_feat, 3, 1, 1) |
| self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, |
| 1) |
| self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3, |
| 1, 1) |
|
|
| self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) |
|
|
| def forward(self, x): |
|
|
| |
| B, C3, p, p = x.shape |
| C = C3 // 3 |
| """following forward code copied from lite4x version |
| """ |
| x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
| p) |
|
|
| |
| x = self.conv_after_body(x) + x |
| x = self.conv_before_upsample(x) |
| x = self.lrelu( |
| self.conv_up1( |
| torch.nn.functional.interpolate( |
| x, |
| scale_factor=2, |
| mode='nearest', |
| |
| |
| ))) |
| if self.upscale == 4: |
| x = self.lrelu( |
| self.conv_up2( |
| torch.nn.functional.interpolate( |
| x, |
| scale_factor=2, |
| mode='nearest', |
| |
| |
| ))) |
| x = self.conv_last(self.lrelu(self.conv_hr(x))) |
|
|
| assert x.shape[-1] == self.input_resolution |
|
|
| return x |
|
|
|
|
| class RodinConv3D4X_lite_improved_lint_withresidual(nn.Module): |
|
|
| def __init__(self, |
| in_chans, |
| num_feat, |
| out_chans, |
| input_resolution=256) -> None: |
| super().__init__() |
|
|
| assert in_chans == 4 * out_chans |
| assert num_feat == 2 * out_chans |
| self.input_resolution = input_resolution |
|
|
| |
| self.upscale = 4 |
|
|
| self.conv_after_body = RodinRollOutConv3D_GroupConv( |
| in_chans, in_chans, 3, 1, 1) |
| self.conv_before_upsample = nn.Sequential( |
| RodinRollOutConv3D_GroupConv(in_chans, num_feat, 3, 1, 1), |
| nn.LeakyReLU(inplace=True)) |
| self.conv_up1 = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, |
| 1) |
| if self.upscale == 4: |
| self.conv_up2 = RodinRollOutConv3D_GroupConv( |
| num_feat, num_feat, 3, 1, 1) |
| self.conv_hr = RodinRollOutConv3D_GroupConv(num_feat, num_feat, 3, 1, |
| 1) |
| self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, out_chans, 3, |
| 1, 1) |
|
|
| self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) |
|
|
| def forward(self, x): |
|
|
| |
| B, C3, p, p = x.shape |
| C = C3 // 3 |
| """following forward code copied from lite4x version |
| """ |
| x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
| p) |
|
|
| |
| x = self.conv_after_body(x) + x |
| x = self.conv_before_upsample(x) |
| x = self.lrelu( |
| self.conv_up1( |
| torch.nn.functional.interpolate( |
| x, |
| scale_factor=2, |
| mode='nearest', |
| |
| |
| ))) |
| if self.upscale == 4: |
| x = self.lrelu( |
| self.conv_up2( |
| torch.nn.functional.interpolate( |
| x, |
| scale_factor=2, |
| mode='nearest', |
| |
| |
| ))) |
| x = self.conv_last(self.lrelu(self.conv_hr(x) + x)) |
|
|
| assert x.shape[-1] == self.input_resolution |
|
|
| return x |
|
|
|
|
| class RodinRollOutConv3DSR_FlexibleChannels(nn.Module): |
|
|
| def __init__(self, |
| in_chans, |
| num_out_ch=96, |
| input_resolution=256, |
| **kwargs) -> None: |
| super().__init__() |
|
|
| self.block0 = RodinConv3D_SynthesisLayer(in_chans, |
| num_out_ch) |
| self.block1 = RodinConv3D_SynthesisLayer(num_out_ch, num_out_ch) |
|
|
| self.input_resolution = input_resolution |
|
|
| def forward(self, x): |
| |
| B, C3, p, p = x.shape |
| C = C3 // 3 |
| |
|
|
| x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
| p) |
|
|
| if x.shape[-1] != self.input_resolution: |
| x = torch.nn.functional.interpolate(x, |
| size=(self.input_resolution, |
| self.input_resolution), |
| mode='bilinear', |
| align_corners=False, |
| antialias=True) |
|
|
| x = self.block0(x) |
| x = self.block1(x) |
|
|
| return x |
|
|
|
|
| |
| class RodinRollOutConv3DSR4X(nn.Module): |
| |
|
|
| def __init__(self, in_chans, **kwargs) -> None: |
| super().__init__() |
| |
| |
|
|
| self.block0 = RodinConv3D_SynthesisLayer(in_chans, 96) |
| self.block1 = RodinConv3D_SynthesisLayer( |
| 96, 96) |
|
|
| self.input_resolution = 64 |
|
|
| def forward(self, x): |
| |
| B, C3, p, p = x.shape |
| C = C3 // 3 |
| |
|
|
| x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
| p) |
|
|
| if x.shape[-1] != self.input_resolution: |
| x = torch.nn.functional.interpolate(x, |
| size=(self.input_resolution, |
| self.input_resolution), |
| mode='bilinear', |
| align_corners=False, |
| antialias=True) |
|
|
| x = self.block0(x) |
| x = self.block1(x) |
|
|
| return x |
|
|
|
|
| class Upsample3D(nn.Module): |
    """Upsample module for triplane features.

    Args:
        scale (int): Scale factor. Supported scales: 2^n.
        num_feat (int): Channel number of intermediate features (must be divisible by 3).
    """
|
|
| def __init__(self, scale, num_feat): |
| super().__init__() |
|
|
| m_convs = [] |
| m_pixelshuffle = [] |
|
|
| assert (scale & (scale - 1)) == 0, 'scale = 2^n' |
| self.scale = scale |
|
|
| for _ in range(int(math.log(scale, 2))): |
| m_convs.append( |
| RodinRollOutConv3D_GroupConv(num_feat, 4 * num_feat, 3, 1, 1)) |
| m_pixelshuffle.append(nn.PixelShuffle(2)) |
|
|
| self.m_convs = nn.ModuleList(m_convs) |
| self.m_pixelshuffle = nn.ModuleList(m_pixelshuffle) |
|
|
| |
| def forward(self, x): |
| for scale_idx in range(int(math.log(self.scale, 2))): |
| x = self.m_convs[scale_idx](x) |
| |
| |
| x = x.reshape(x.shape[0] * 3, x.shape[1] // 3, *x.shape[2:]) |
| x = self.m_pixelshuffle[scale_idx](x) |
| x = x.reshape(x.shape[0] // 3, x.shape[1] * 3, *x.shape[2:]) |
|
|
| return x |
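
# Usage sketch (illustrative shapes, assumed): each 2x stage expands channels 4x with a
# roll-out group conv and pixel-shuffles every plane independently, so num_feat must be
# divisible by 3, e.g.
#   up = Upsample3D(scale=4, num_feat=3 * 32)
#   out = up(torch.randn(2, 3 * 32, 32, 32))  # -> (2, 96, 128, 128)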
|
|
|
|
| class RodinConv3DPixelUnshuffleUpsample(nn.Module): |
|
|
| def __init__(self, |
| output_dim, |
| num_feat=32 * 6, |
| num_out_ch=32 * 3, |
| sr_ratio=4, |
| *args, |
| **kwargs) -> None: |
| super().__init__() |
|
|
| self.conv_after_body = RodinRollOutConv3D_GroupConv( |
| output_dim, output_dim, 3, 1, 1) |
| self.conv_before_upsample = nn.Sequential( |
| RodinRollOutConv3D_GroupConv(output_dim, num_feat, 3, 1, 1), |
| nn.LeakyReLU(inplace=True)) |
| self.upsample = Upsample3D(sr_ratio, num_feat) |
| self.conv_last = RodinRollOutConv3D_GroupConv(num_feat, num_out_ch, 3, |
| 1, 1) |
|
|
| |
| def forward(self, x, input_skip_connection=True, *args, **kwargs): |
| |
| if input_skip_connection: |
| x = self.conv_after_body(x) + x |
| else: |
| x = self.conv_after_body(x) |
|
|
| x = self.conv_before_upsample(x) |
| x = self.upsample(x) |
| x = self.conv_last(x) |
| return x |
|
|
|
|
| class RodinConv3DPixelUnshuffleUpsample_improvedVersion(nn.Module): |
|
|
| def __init__( |
| self, |
| output_dim, |
| num_out_ch=32 * 3, |
| sr_ratio=4, |
| input_resolution=256, |
| ) -> None: |
| super().__init__() |
|
|
| self.input_resolution = input_resolution |
|
|
| |
| |
| self.upsample = Upsample3D(sr_ratio, output_dim) |
| self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, |
| 3, 1, 1) |
|
|
| def forward(self, x, bilinear_upsample=True): |
|
|
| B, C3, p, p = x.shape |
| C = C3 // 3 |
| group_size = C3 // C |
|
|
| assert group_size == 3, 'designed for triplane here' |
|
|
| x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
| p) |
|
|
| if bilinear_upsample and x.shape[-1] != self.input_resolution: |
| x_bilinear_upsample = torch.nn.functional.interpolate( |
| x, |
| size=(self.input_resolution, self.input_resolution), |
| mode='bilinear', |
| align_corners=False, |
| antialias=True) |
| x = self.upsample(x) + x_bilinear_upsample |
| else: |
| |
| x = self.upsample(x) |
|
|
| x = self.conv_last(x) |
|
|
| return x |
|
|
|
|
| class RodinConv3DPixelUnshuffleUpsample_improvedVersion2(nn.Module): |
    """Removes the nearest-neighbour residual connection and adds a conv-layer
    residual connection instead.
    """
|
|
| def __init__( |
| self, |
| output_dim, |
| num_out_ch=32 * 3, |
| sr_ratio=4, |
| input_resolution=256, |
| ) -> None: |
| super().__init__() |
|
|
| self.input_resolution = input_resolution |
|
|
| self.conv_after_body = RodinRollOutConv3D_GroupConv( |
| output_dim, num_out_ch, 3, 1, 1) |
| self.upsample = Upsample3D(sr_ratio, output_dim) |
| self.conv_last = RodinRollOutConv3D_GroupConv(output_dim, num_out_ch, |
| 3, 1, 1) |
|
|
| def forward(self, x, input_skip_connection=True): |
|
|
| B, C3, p, p = x.shape |
| C = C3 // 3 |
| group_size = C3 // C |
|
|
| assert group_size == 3, 'designed for triplane here' |
|
|
| x = x.permute(0, 1, 3, 2).reshape(B, 3 * C, p, |
| p) |
|
|
| if input_skip_connection: |
| x = self.conv_after_body(x) + x |
| else: |
| x = self.conv_after_body(x) |
|
|
| x = self.upsample(x) |
| x = self.conv_last(x) |
|
|
| return x |
|
|
|
|
| class CLSCrossAttentionBlock(nn.Module): |
|
|
| def __init__(self, |
| dim, |
| num_heads, |
| mlp_ratio=4., |
| qkv_bias=False, |
| qk_scale=None, |
| drop=0., |
| attn_drop=0., |
| drop_path=0., |
| act_layer=nn.GELU, |
| norm_layer=nn.LayerNorm, |
| has_mlp=False): |
| super().__init__() |
| self.norm1 = norm_layer(dim) |
| self.attn = CrossAttention(dim, |
| num_heads=num_heads, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| attn_drop=attn_drop, |
| proj_drop=drop) |
| |
| self.drop_path = DropPath( |
| drop_path) if drop_path > 0. else nn.Identity() |
| self.has_mlp = has_mlp |
| if has_mlp: |
| self.norm2 = norm_layer(dim) |
| mlp_hidden_dim = int(dim * mlp_ratio) |
| self.mlp = Mlp(in_features=dim, |
| hidden_features=mlp_hidden_dim, |
| act_layer=act_layer, |
| drop=drop) |
|
|
| def forward(self, x): |
| x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x))) |
| if self.has_mlp: |
| x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
|
| return x |
|
|
|
|
| class Conv3DCrossAttentionBlock(nn.Module): |
|
|
| def __init__(self, |
| dim, |
| num_heads, |
| mlp_ratio=4., |
| qkv_bias=False, |
| qk_scale=None, |
| drop=0., |
| attn_drop=0., |
| drop_path=0., |
| act_layer=nn.GELU, |
| norm_layer=nn.LayerNorm, |
| has_mlp=False): |
| super().__init__() |
| self.norm1 = norm_layer(dim) |
| self.attn = Conv3D_Aware_CrossAttention(dim, |
| num_heads=num_heads, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| attn_drop=attn_drop, |
| proj_drop=drop) |
| |
| self.drop_path = DropPath( |
| drop_path) if drop_path > 0. else nn.Identity() |
| self.has_mlp = has_mlp |
| if has_mlp: |
| self.norm2 = norm_layer(dim) |
| mlp_hidden_dim = int(dim * mlp_ratio) |
| self.mlp = Mlp(in_features=dim, |
| hidden_features=mlp_hidden_dim, |
| act_layer=act_layer, |
| drop=drop) |
|
|
| def forward(self, x): |
| x = x + self.drop_path(self.attn(self.norm1(x))) |
| if self.has_mlp: |
| x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
|
| return x |
|
|
|
|
| class Conv3DCrossAttentionBlockXformerMHA(Conv3DCrossAttentionBlock): |
|
|
| def __init__(self, |
| dim, |
| num_heads, |
| mlp_ratio=4, |
| qkv_bias=False, |
| qk_scale=None, |
| drop=0, |
| attn_drop=0, |
| drop_path=0, |
| act_layer=nn.GELU, |
| norm_layer=nn.LayerNorm, |
| has_mlp=False): |
| super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, |
| attn_drop, drop_path, act_layer, norm_layer, has_mlp) |
| |
| self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid( |
| dim, |
| num_heads=num_heads, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| attn_drop=attn_drop, |
| proj_drop=drop) |
|
|
|
|
| class Conv3DCrossAttentionBlockXformerMHANested( |
| Conv3DCrossAttentionBlockXformerMHA): |
|
|
| def __init__(self, |
| dim, |
| num_heads, |
| mlp_ratio=4, |
| qkv_bias=False, |
| qk_scale=None, |
| drop=0., |
| attn_drop=0., |
| drop_path=0., |
| act_layer=nn.GELU, |
| norm_layer=nn.LayerNorm, |
| has_mlp=False): |
| super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, |
| attn_drop, drop_path, act_layer, norm_layer, has_mlp) |
        """For in-place replacement of the internal attn in a DINO ViT block.
        """
|
|
| def forward(self, x): |
| Bx3, N, C = x.shape |
| B, group_size = Bx3 // 3, 3 |
| x = x.reshape(B, group_size, N, C) |
| x = super().forward(x) |
| return x.reshape(B * group_size, N, |
| C) |
|
|
|
|
| class Conv3DCrossAttentionBlockXformerMHANested_withinC( |
| Conv3DCrossAttentionBlockXformerMHANested): |
|
|
| def __init__(self, |
| dim, |
| num_heads, |
| mlp_ratio=4, |
| qkv_bias=False, |
| qk_scale=None, |
| drop=0, |
| attn_drop=0, |
| drop_path=0, |
| act_layer=nn.GELU, |
| norm_layer=nn.LayerNorm, |
| has_mlp=False): |
| super().__init__(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, |
| attn_drop, drop_path, act_layer, norm_layer, has_mlp) |
| self.attn = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( |
| dim, |
| num_heads=num_heads, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| attn_drop=attn_drop, |
| proj_drop=drop) |
|
|
| def forward(self, x): |
| |
| x = x + self.drop_path(self.attn(self.norm1(x))) |
| if self.has_mlp: |
| x = x + self.drop_path(self.mlp(self.norm2(x))) |
|
|
| return x |
|
|
|
|
| class TriplaneFusionBlock(nn.Module): |
| """4 ViT blocks + 1 CrossAttentionBlock |
| """ |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| cross_attention_blk=CLSCrossAttentionBlock, |
| *args, |
| **kwargs) -> None: |
| super().__init__(*args, **kwargs) |
|
|
| self.num_branches = 3 |
| self.vit_blks = vit_blks |
|
|
| if use_fusion_blk: |
| self.fusion = nn.ModuleList() |
|
|
| |
| nh = num_heads |
| dim = embed_dim |
|
|
| mlp_ratio = 4 |
| qkv_bias = True |
| norm_layer = partial(nn.LayerNorm, eps=1e-6) |
| drop_path_rate = 0.3 |
| attn_drop = proj_drop = 0.0 |
| qk_scale = None |
|
|
| for d in range(self.num_branches): |
| self.fusion.append( |
| cross_attention_blk( |
| dim=dim, |
| num_heads=nh, |
| mlp_ratio=mlp_ratio, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| |
| drop=proj_drop, |
| attn_drop=attn_drop, |
| drop_path=drop_path_rate, |
| norm_layer=norm_layer, |
| has_mlp=False)) |
| else: |
| self.fusion = None |
|
|
| def forward(self, x): |
| |
| """x: B 3 N C, where N = H*W tokens |
| """ |
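        # Fusion flow (descriptive sketch): run the shared ViT blocks on the three
        # flattened branches, then, CrossViT-style, prepend branch i's leading token to
        # branch (i+1)'s remaining tokens, refine it with the cross-attention block, and
        # place the refined token back in front of branch i's own tokens.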
|
|
| |
|
|
| |
| B, group_size, N, C = x.shape |
| assert group_size == 3, 'triplane' |
| x = x.view(B * group_size, N, C) |
|
|
| for blk in self.vit_blks: |
| x = blk(x) |
|
|
| if self.fusion is None: |
| return x.view(B, group_size, N, C) |
|
|
| |
| |
| |
|
|
| outs_b = x.chunk(chunks=3, |
| dim=0) |
|
|
| |
| proj_cls_token = [x[:, 0:1] for x in outs_b] |
| |
| outs = [] |
| for i in range(self.num_branches): |
| tmp = torch.cat( |
| (proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, |
| ...]), |
| dim=1) |
| tmp = self.fusion[i](tmp) |
| |
| reverted_proj_cls_token = tmp[:, 0:1, ...] |
| tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), |
| dim=1) |
| outs.append(tmp) |
| |
| outs = torch.stack(outs, 1) |
| return outs |
|
|
|
|
| class TriplaneFusionBlockv2(nn.Module): |
| """4 ViT blocks + 1 CrossAttentionBlock |
| """ |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| fusion_ca_blk=Conv3DCrossAttentionBlock, |
| *args, |
| **kwargs) -> None: |
| super().__init__(*args, **kwargs) |
|
|
| self.num_branches = 3 |
| self.vit_blks = vit_blks |
|
|
| if use_fusion_blk: |
| |
|
|
| |
| nh = num_heads |
| dim = embed_dim |
|
|
| mlp_ratio = 4 |
| qkv_bias = True |
| norm_layer = partial(nn.LayerNorm, eps=1e-6) |
| drop_path_rate = 0.3 |
| attn_drop = proj_drop = 0.0 |
| qk_scale = None |
|
|
| |
| self.fusion = fusion_ca_blk( |
| dim=dim, |
| num_heads=nh, |
| mlp_ratio=mlp_ratio, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| |
| drop=proj_drop, |
| attn_drop=attn_drop, |
| drop_path=drop_path_rate, |
| norm_layer=norm_layer, |
| has_mlp=False) |
| else: |
| self.fusion = None |
|
|
| def forward(self, x): |
| |
| """x: B 3 N C, where N = H*W tokens |
| """ |
|
|
| |
|
|
| |
| B, group_size, N, C = x.shape |
| assert group_size == 3, 'triplane' |
| x = x.reshape(B * group_size, N, C) |
|
|
| for blk in self.vit_blks: |
| x = blk(x) |
|
|
| if self.fusion is None: |
| return x.reshape(B, group_size, N, C) |
|
|
| x = x.reshape(B, group_size, N, C) |
| |
| return self.fusion(x) |
|
|
|
|
| class TriplaneFusionBlockv3(TriplaneFusionBlockv2): |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA, |
| *args, |
| **kwargs) -> None: |
| super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, |
| fusion_ca_blk, *args, **kwargs) |
|
|
|
|
| class TriplaneFusionBlockv4(TriplaneFusionBlockv3): |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHA, |
| *args, |
| **kwargs) -> None: |
| super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, |
| fusion_ca_blk, *args, **kwargs) |
        """Directly replace the attention here (intended to reduce memory pressure).
        """
|
|
| assert len(vit_blks) == 2 |
| |
| del self.vit_blks[1].attn, self.vit_blks[1].ls1, self.vit_blks[1].norm1 |
|
|
| def ffn_residual_func(self, tx_blk, x: Tensor) -> Tensor: |
| return tx_blk.ls2( |
| tx_blk.mlp(tx_blk.norm2(x)) |
| ) |
|
|
| def forward(self, x): |
| """x: B 3 N C, where N = H*W tokens |
| """ |
| assert self.fusion is not None |
|
|
| B, group_size, N, C = x.shape |
| x = x.reshape(B * group_size, N, C) |
|
|
| |
| x = self.vit_blks[0](x) |
|
|
| |
| x = x + self.fusion(x.reshape(B, group_size, N, C)).reshape( |
| B * group_size, N, C) |
| x = x + self.ffn_residual_func(self.vit_blks[1], x) |
| return x.reshape(B, group_size, N, C) |
|
|
|
|
| class TriplaneFusionBlockv4_nested(nn.Module): |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, |
| *args, |
| **kwargs) -> None: |
| super().__init__() |
|
|
| self.num_branches = 3 |
| self.vit_blks = vit_blks |
|
|
| assert use_fusion_blk |
|
|
| assert len(vit_blks) == 2 |
|
|
| |
| del self.vit_blks[ |
| 1].attn |
|
|
| |
| nh = num_heads |
| dim = embed_dim |
|
|
| mlp_ratio = 4 |
| qkv_bias = True |
| norm_layer = partial(nn.LayerNorm, eps=1e-6) |
| drop_path_rate = 0.3 |
| attn_drop = proj_drop = 0.0 |
| qk_scale = None |
|
|
| self.vit_blks[1].attn = fusion_ca_blk( |
| dim=dim, |
| num_heads=nh, |
| mlp_ratio=mlp_ratio, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| |
| drop=proj_drop, |
| attn_drop=attn_drop, |
| drop_path=drop_path_rate, |
| norm_layer=norm_layer, |
| has_mlp=False) |
|
|
| def forward(self, x): |
| """x: B 3 N C, where N = H*W tokens |
| """ |
|
|
| |
|
|
| |
| B, group_size, N, C = x.shape |
| assert group_size == 3, 'triplane' |
| x = x.reshape(B * group_size, N, C) |
|
|
| for blk in self.vit_blks: |
| x = blk(x) |
|
|
| |
| return x.reshape(B, group_size, N, C) |
|
|
|
|
| class TriplaneFusionBlockv4_nested_init_from_dino(nn.Module): |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, |
| init_from_dino=True, |
| *args, |
| **kwargs) -> None: |
| super().__init__() |
|
|
| self.num_branches = 3 |
| self.vit_blks = vit_blks |
|
|
| assert use_fusion_blk |
|
|
| assert len(vit_blks) == 2 |
|
|
| |
| nh = num_heads |
| dim = embed_dim |
|
|
| mlp_ratio = 4 |
| qkv_bias = True |
| norm_layer = partial(nn.LayerNorm, eps=1e-6) |
| drop_path_rate = 0.3 |
| attn_drop = proj_drop = 0.0 |
| qk_scale = None |
|
|
| attn_3d = fusion_ca_blk( |
| dim=dim, |
| num_heads=nh, |
| mlp_ratio=mlp_ratio, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| |
| drop=proj_drop, |
| attn_drop=attn_drop, |
| drop_path=drop_path_rate, |
| norm_layer=norm_layer, |
| has_mlp=False) |
|
|
| |
| if init_from_dino: |
| merged_qkv_linear = self.vit_blks[1].attn.qkv |
| attn_3d.attn.proj.load_state_dict( |
| self.vit_blks[1].attn.proj.state_dict()) |
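
            # The merged qkv projection stores W_q in its first `dim` output rows and
            # W_k / W_v in the remaining 2*dim rows, so it can be sliced directly into
            # the separate wq / w_kv linears of the 3D-aware attention below.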
|
|
| |
| attn_3d.attn.wq.weight.data = merged_qkv_linear.weight.data[: |
| dim, :] |
| attn_3d.attn.w_kv.weight.data = merged_qkv_linear.weight.data[ |
| dim:, :] |
|
|
| |
| if qkv_bias: |
| attn_3d.attn.wq.bias.data = merged_qkv_linear.bias.data[:dim] |
| attn_3d.attn.w_kv.bias.data = merged_qkv_linear.bias.data[dim:] |
|
|
| del self.vit_blks[1].attn |
| |
| self.vit_blks[1].attn = attn_3d |
|
|
| def forward(self, x): |
| """x: B 3 N C, where N = H*W tokens |
| """ |
|
|
| |
|
|
| |
| B, group_size, N, C = x.shape |
| assert group_size == 3, 'triplane' |
| x = x.reshape(B * group_size, N, C) |
|
|
| for blk in self.vit_blks: |
| x = blk(x) |
|
|
| |
| return x.reshape(B, group_size, N, C) |
|
|
|
|
| class TriplaneFusionBlockv4_nested_init_from_dino_lite(nn.Module): |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| fusion_ca_blk=None, |
| *args, |
| **kwargs) -> None: |
| super().__init__() |
|
|
| self.num_branches = 3 |
| self.vit_blks = vit_blks |
|
|
| assert use_fusion_blk |
|
|
| assert len(vit_blks) == 2 |
|
|
| |
| nh = num_heads |
| dim = embed_dim |
|
|
| mlp_ratio = 4 |
| qkv_bias = True |
| norm_layer = partial(nn.LayerNorm, eps=1e-6) |
| drop_path_rate = 0.3 |
| attn_drop = proj_drop = 0.0 |
| qk_scale = None |
|
|
| attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( |
| dim, |
| num_heads=num_heads, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| attn_drop=attn_drop, |
| proj_drop=proj_drop) |
|
|
| del self.vit_blks[1].attn |
| |
| self.vit_blks[1].attn = attn_3d |
|
|
| def forward(self, x): |
| """x: B N C, where N = H*W tokens. Just raw ViT forward pass |
| """ |
|
|
| |
| B, N, C = x.shape |
|
|
| for blk in self.vit_blks: |
| x = blk(x) |
|
|
| return x |
|
|
| class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge(nn.Module): |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| fusion_ca_blk=None, |
| *args, |
| **kwargs) -> None: |
| super().__init__() |
|
|
| self.vit_blks = vit_blks |
|
|
| assert use_fusion_blk |
| assert len(vit_blks) == 2 |
|
|
| |
| nh = num_heads |
| dim = embed_dim |
| qkv_bias = True |
| attn_drop = proj_drop = 0.0 |
| qk_scale = None |
|
|
| if False: |
| for blk in self.vit_blks: |
| attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid_withinC( |
| dim, |
| num_heads=num_heads, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| attn_drop=attn_drop, |
| proj_drop=proj_drop) |
| blk.attn = self_cross_attn(blk.attn, attn_3d) |
|
|
| def forward(self, x): |
| """x: B N C, where N = H*W tokens. Just raw ViT forward pass |
| """ |
|
|
| |
| B, N, C = x.shape |
|
|
| for blk in self.vit_blks: |
| x = blk(x) |
|
|
| return x |
|
|
| class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge): |
| |
| def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: |
| super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) |
|
|
|
|
| def forward(self, x): |
| """x: B 3 N C, where N = H*W tokens |
| """ |
|
|
| |
|
|
| |
| B, group_size, N, C = x.shape |
| x = x.reshape(B, group_size*N, C) |
|
|
| for blk in self.vit_blks: |
| x = blk(x) |
|
|
| x = x.reshape(B, group_size, N, C) |
|
|
| return x |
|
|
| class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout(TriplaneFusionBlockv4_nested_init_from_dino_lite_merge): |
| |
| def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: |
| super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) |
|
|
|
|
| def forward(self, x): |
| """x: B 3 N C, where N = H*W tokens |
| """ |
|
|
| |
|
|
| |
| B, group_size, N, C = x.shape |
| x = x.reshape(B*group_size, N, C) |
| x = self.vit_blks[0](x) |
|
|
| x = x.reshape(B,group_size*N, C) |
| x = self.vit_blks[1](x) |
|
|
| x = x.reshape(B, group_size, N, C) |
|
|
| return x |
|
|
|
|
| class TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_add3DAttn(TriplaneFusionBlockv4_nested_init_from_dino): |
| |
| def __init__(self, vit_blks, num_heads, embed_dim, use_fusion_blk=True, fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, init_from_dino=True, *args, **kwargs) -> None: |
| super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, fusion_ca_blk, init_from_dino, *args, **kwargs) |
|
|
|
|
| def forward(self, x): |
| """x: B 3 N C, where N = H*W tokens |
| """ |
|
|
| B, group_size, N, C = x.shape |
| x = x.reshape(B, group_size*N, C) |
| x = self.vit_blks[0](x) |
|
|
| |
        x = x.reshape(B, group_size, N, C).reshape(B * group_size, N, C)
        x = self.vit_blks[1](x)
        return x.reshape(B, group_size, N, C)
|
|
|
|
| class TriplaneFusionBlockv5_ldm_addCA(nn.Module): |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, |
| *args, |
| **kwargs) -> None: |
| super().__init__() |
|
|
| self.num_branches = 3 |
| self.vit_blks = vit_blks |
|
|
| assert use_fusion_blk |
|
|
| assert len(vit_blks) == 2 |
|
|
| |
| |
| |
| self.norm_for_atten_3d = deepcopy(self.vit_blks[1].norm1) |
|
|
| |
| nh = num_heads |
| dim = embed_dim |
|
|
| mlp_ratio = 4 |
| qkv_bias = True |
| norm_layer = partial(nn.LayerNorm, eps=1e-6) |
| drop_path_rate = 0.3 |
| attn_drop = proj_drop = 0.0 |
| qk_scale = None |
|
|
| self.attn_3d = xformer_Conv3D_Aware_CrossAttention_xygrid( |
| dim, |
| num_heads=num_heads, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| attn_drop=attn_drop, |
| proj_drop=proj_drop) |
|
|
| def forward(self, x): |
| """x: B 3 N C, where N = H*W tokens |
| """ |
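        # Flow (descriptive sketch): first ViT block on the three flattened planes, a
        # residual 3D-aware cross-attention across planes (with its own copied pre-norm),
        # then the second ViT block; the two lambdas below switch between the
        # (B*3, N, C) and (B, 3, N, C) layouts.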
|
|
| |
|
|
| |
| B, group_size, N, C = x.shape |
| assert group_size == 3, 'triplane' |
|
|
| flatten_token = lambda x: x.reshape(B * group_size, N, C) |
| unflatten_token = lambda x: x.reshape(B, group_size, N, C) |
|
|
| x = flatten_token(x) |
| x = self.vit_blks[0](x) |
|
|
| x = unflatten_token(x) |
| x = self.attn_3d(self.norm_for_atten_3d(x)) + x |
|
|
| x = flatten_token(x) |
| x = self.vit_blks[1](x) |
|
|
| return unflatten_token(x) |
|
|
|
|
| class TriplaneFusionBlockv6_ldm_addCA_Init3DAttnfrom2D( |
| TriplaneFusionBlockv5_ldm_addCA): |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, |
| *args, |
| **kwargs) -> None: |
| super().__init__(vit_blks, num_heads, embed_dim, use_fusion_blk, |
| fusion_ca_blk, *args, **kwargs) |
|
|
| def forward(self, x): |
| """x: B 3 N C, where N = H*W tokens |
| """ |
|
|
| |
|
|
| |
| B, group_size, N, C = x.shape |
| assert group_size == 3, 'triplane' |
|
|
| flatten_token = lambda x: x.reshape(B * group_size, N, C) |
| unflatten_token = lambda x: x.reshape(B, group_size, N, C) |
|
|
| x = flatten_token(x) |
| x = self.vit_blks[0](x) |
|
|
| x = unflatten_token(x) |
| x = self.attn_3d(self.norm_for_atten_3d(x)) + x |
|
|
| x = flatten_token(x) |
| x = self.vit_blks[1](x) |
|
|
| return unflatten_token(x) |
|
|
|
|
| class TriplaneFusionBlockv5_ldm_add_dualCA(nn.Module): |
|
|
| def __init__(self, |
| vit_blks, |
| num_heads, |
| embed_dim, |
| use_fusion_blk=True, |
| fusion_ca_blk=Conv3DCrossAttentionBlockXformerMHANested, |
| *args, |
| **kwargs) -> None: |
| super().__init__() |
|
|
| self.num_branches = 3 |
| self.vit_blks = vit_blks |
|
|
| assert use_fusion_blk |
|
|
| assert len(vit_blks) == 2 |
|
|
| |
| |
| |
| self.norm_for_atten_3d_0 = deepcopy(self.vit_blks[0].norm1) |
| self.norm_for_atten_3d_1 = deepcopy(self.vit_blks[1].norm1) |
|
|
| |
| nh = num_heads |
| dim = embed_dim |
|
|
| mlp_ratio = 4 |
| qkv_bias = True |
| norm_layer = partial(nn.LayerNorm, eps=1e-6) |
| drop_path_rate = 0.3 |
| attn_drop = proj_drop = 0.0 |
| qk_scale = None |
|
|
| self.attn_3d_0 = xformer_Conv3D_Aware_CrossAttention_xygrid( |
| dim, |
| num_heads=num_heads, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| attn_drop=attn_drop, |
| proj_drop=proj_drop) |
|
|
| self.attn_3d_1 = deepcopy(self.attn_3d_0) |
|
|
| def forward(self, x): |
| """x: B 3 N C, where N = H*W tokens |
| """ |
|
|
| |
|
|
| |
| B, group_size, N, C = x.shape |
| assert group_size == 3, 'triplane' |
|
|
| flatten_token = lambda x: x.reshape(B * group_size, N, C) |
| unflatten_token = lambda x: x.reshape(B, group_size, N, C) |
|
|
| x = flatten_token(x) |
| x = self.vit_blks[0](x) |
|
|
| x = unflatten_token(x) |
| x = self.attn_3d_0(self.norm_for_atten_3d_0(x)) + x |
|
|
| x = flatten_token(x) |
| x = self.vit_blks[1](x) |
|
|
| x = unflatten_token(x) |
| x = self.attn_3d_1(self.norm_for_atten_3d_1(x)) + x |
|
|
| return unflatten_token(x) |
|
|
|
|
| def drop_path(x, drop_prob: float = 0., training: bool = False): |
| if drop_prob == 0. or not training: |
| return x |
| keep_prob = 1 - drop_prob |
| shape = (x.shape[0], ) + (1, ) * ( |
| x.ndim - 1) |
| random_tensor = keep_prob + torch.rand( |
| shape, dtype=x.dtype, device=x.device) |
| random_tensor.floor_() |
| output = x.div(keep_prob) * random_tensor |
| return output |
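
# Behaviour sketch (values assumed for illustration): with drop_prob=0.2 in training
# mode, each sample's residual branch is zeroed with probability 0.2 and the surviving
# branches are rescaled by 1 / 0.8, so the expected output stays unchanged, e.g.
#   out = drop_path(torch.ones(8, 197, 384), drop_prob=0.2, training=True)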
|
|
|
|
| class DropPath(nn.Module): |
| """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). |
| """ |
|
|
| def __init__(self, drop_prob=None): |
| super(DropPath, self).__init__() |
| self.drop_prob = drop_prob |
|
|
| def forward(self, x): |
| return drop_path(x, self.drop_prob, self.training) |
|
|
|
|
| class Mlp(nn.Module): |
|
|
| def __init__(self, |
| in_features, |
| hidden_features=None, |
| out_features=None, |
| act_layer=nn.GELU, |
| drop=0.): |
| super().__init__() |
| out_features = out_features or in_features |
| hidden_features = hidden_features or in_features |
| self.fc1 = nn.Linear(in_features, hidden_features) |
| self.act = act_layer() |
| self.fc2 = nn.Linear(hidden_features, out_features) |
| self.drop = nn.Dropout(drop) |
|
|
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.act(x) |
| x = self.drop(x) |
| x = self.fc2(x) |
| x = self.drop(x) |
| return x |
|
|
|
|
| class Block(nn.Module): |
|
|
| def __init__(self, |
| dim, |
| num_heads, |
| mlp_ratio=4., |
| qkv_bias=False, |
| qk_scale=None, |
| drop=0., |
| attn_drop=0., |
| drop_path=0., |
| act_layer=nn.GELU, |
| norm_layer=nn.LayerNorm): |
| super().__init__() |
| self.norm1 = norm_layer(dim) |
| self.attn = Attention(dim, |
| num_heads=num_heads, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| attn_drop=attn_drop, |
| proj_drop=drop) |
| self.drop_path = DropPath( |
| drop_path) if drop_path > 0. else nn.Identity() |
| self.norm2 = norm_layer(dim) |
| mlp_hidden_dim = int(dim * mlp_ratio) |
| self.mlp = Mlp(in_features=dim, |
| hidden_features=mlp_hidden_dim, |
| act_layer=act_layer, |
| drop=drop) |
|
|
    def forward(self, x, return_attention=False):
        if return_attention:
            _, attn = self.attn(self.norm1(x), return_attn=True)
            return attn
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
|
|
|
|
| class PatchEmbed(nn.Module): |
| """ Image to Patch Embedding |
| """ |
|
|
| def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): |
| super().__init__() |
| num_patches = (img_size // patch_size) * (img_size // patch_size) |
| self.img_size = img_size |
| self.patch_size = patch_size |
| self.num_patches = num_patches |
|
|
| self.proj = nn.Conv2d(in_chans, |
| embed_dim, |
| kernel_size=patch_size, |
| stride=patch_size) |
|
|
| def forward(self, x): |
| B, C, H, W = x.shape |
| x = self.proj(x).flatten(2).transpose(1, 2) |
| return x |
|
|
|
|
| class VisionTransformer(nn.Module): |
| """ Vision Transformer """ |
|
|
| def __init__(self, |
| img_size=[224], |
| patch_size=16, |
| in_chans=3, |
| num_classes=0, |
| embed_dim=768, |
| depth=12, |
| num_heads=12, |
| mlp_ratio=4., |
| qkv_bias=False, |
| qk_scale=None, |
| drop_rate=0., |
| attn_drop_rate=0., |
| drop_path_rate=0., |
| norm_layer='nn.LayerNorm', |
| patch_embedding=True, |
| cls_token=True, |
| pixel_unshuffle=False, |
| **kwargs): |
| super().__init__() |
| self.num_features = self.embed_dim = embed_dim |
| self.patch_size = patch_size |
|
|
| |
| norm_layer = partial(nn.LayerNorm, eps=1e-6) |
|
|
| if patch_embedding: |
| self.patch_embed = PatchEmbed(img_size=img_size[0], |
| patch_size=patch_size, |
| in_chans=in_chans, |
| embed_dim=embed_dim) |
| num_patches = self.patch_embed.num_patches |
| self.img_size = self.patch_embed.img_size |
| else: |
| self.patch_embed = None |
| self.img_size = img_size[0] |
| num_patches = (img_size[0] // patch_size) * (img_size[0] // |
| patch_size) |
|
|
| if cls_token: |
| self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
| self.pos_embed = nn.Parameter( |
| torch.zeros(1, num_patches + 1, embed_dim)) |
| else: |
| self.cls_token = None |
| self.pos_embed = nn.Parameter( |
| torch.zeros(1, num_patches, embed_dim)) |
|
|
| self.pos_drop = nn.Dropout(p=drop_rate) |
|
|
| dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth) |
| ] |
| self.blocks = nn.ModuleList([ |
| Block(dim=embed_dim, |
| num_heads=num_heads, |
| mlp_ratio=mlp_ratio, |
| qkv_bias=qkv_bias, |
| qk_scale=qk_scale, |
| drop=drop_rate, |
| attn_drop=attn_drop_rate, |
| drop_path=dpr[i], |
| norm_layer=norm_layer) for i in range(depth) |
| ]) |
| self.norm = norm_layer(embed_dim) |
|
|
| |
| self.head = nn.Linear( |
| embed_dim, num_classes) if num_classes > 0 else nn.Identity() |
|
|
| trunc_normal_(self.pos_embed, std=.02) |
| if cls_token: |
| trunc_normal_(self.cls_token, std=.02) |
| self.apply(self._init_weights) |
|
|
| |
| |
| |
| |
|
|
| def _init_weights(self, m): |
| if isinstance(m, nn.Linear): |
| trunc_normal_(m.weight, std=.02) |
| if isinstance(m, nn.Linear) and m.bias is not None: |
| nn.init.constant_(m.bias, 0) |
| elif isinstance(m, nn.LayerNorm): |
| nn.init.constant_(m.bias, 0) |
| nn.init.constant_(m.weight, 1.0) |
|
|
| def interpolate_pos_encoding(self, x, w, h): |
| npatch = x.shape[1] - 1 |
| N = self.pos_embed.shape[1] - 1 |
| if npatch == N and w == h: |
| return self.pos_embed |
| patch_pos_embed = self.pos_embed[:, 1:] |
| dim = x.shape[-1] |
| w0 = w // self.patch_size |
| h0 = h // self.patch_size |
| |
| |
| w0, h0 = w0 + 0.1, h0 + 0.1 |
|
|
| patch_pos_embed = nn.functional.interpolate( |
| patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), |
| dim).permute(0, 3, 1, 2), |
| scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), |
| mode='bicubic', |
| ) |
| assert int(w0) == patch_pos_embed.shape[-2] and int( |
| h0) == patch_pos_embed.shape[-1] |
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
|
| if self.cls_token is not None: |
| class_pos_embed = self.pos_embed[:, 0] |
| return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), |
| dim=1) |
| return patch_pos_embed |
|
|
| def prepare_tokens(self, x): |
| B, nc, w, h = x.shape |
| x = self.patch_embed(x) |
|
|
| |
| cls_tokens = self.cls_token.expand(B, -1, -1) |
| x = torch.cat((cls_tokens, x), dim=1) |
|
|
| |
| x = x + self.interpolate_pos_encoding(x, w, h) |
|
|
| return self.pos_drop(x) |
|
|
| def forward(self, x): |
| x = self.prepare_tokens(x) |
| for blk in self.blocks: |
| x = blk(x) |
| x = self.norm(x) |
| return x[:, 1:] |
| |
|
|
| def get_last_selfattention(self, x): |
| x = self.prepare_tokens(x) |
| for i, blk in enumerate(self.blocks): |
| if i < len(self.blocks) - 1: |
| x = blk(x) |
| else: |
| |
| return blk(x, return_attention=True) |
|
|
| def get_intermediate_layers(self, x, n=1): |
| x = self.prepare_tokens(x) |
| |
| output = [] |
| for i, blk in enumerate(self.blocks): |
| x = blk(x) |
| if len(self.blocks) - i <= n: |
| output.append(self.norm(x)) |
| return output |
|
|
|
|
| def vit_tiny(patch_size=16, **kwargs): |
| model = VisionTransformer(patch_size=patch_size, |
| embed_dim=192, |
| depth=12, |
| num_heads=3, |
| mlp_ratio=4, |
| qkv_bias=True, |
| norm_layer=partial(nn.LayerNorm, eps=1e-6), |
| **kwargs) |
| return model |
|
|
|
|
| def vit_small(patch_size=16, **kwargs): |
| model = VisionTransformer( |
| patch_size=patch_size, |
| embed_dim=384, |
| depth=12, |
| num_heads=6, |
| mlp_ratio=4, |
| qkv_bias=True, |
| norm_layer=partial(nn.LayerNorm, eps=1e-6), |
| **kwargs) |
| return model |
|
|
|
|
| def vit_base(patch_size=16, **kwargs): |
| model = VisionTransformer(patch_size=patch_size, |
| embed_dim=768, |
| depth=12, |
| num_heads=12, |
| mlp_ratio=4, |
| qkv_bias=True, |
| norm_layer=partial(nn.LayerNorm, eps=1e-6), |
| **kwargs) |
| return model |
|
|
|
|
| vits = vit_small |
| vitb = vit_base |
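

# Minimal smoke-test sketch (shapes assumed for illustration; run as a module, e.g.
# `python -m <package>.<this_module>`, since the file uses the relative `.utils` import):
if __name__ == "__main__":
    model = vit_small(patch_size=16)
    tokens = model(torch.randn(1, 3, 224, 224))
    print(tokens.shape)  # expected: torch.Size([1, 196, 384]) patch tokens

    conv3d = RodinRollOutConv3D_GroupConv(in_chans=3 * 8, out_chans=3 * 8)
    planes = conv3d(torch.randn(1, 3 * 8, 32, 32))
    print(planes.shape)  # expected: torch.Size([1, 24, 32, 32])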
|
|