| import torch |
| import torch.nn as nn |
|
|
| from .attention import ( |
| single_head_full_attention, |
| single_head_full_attention_1d, |
| single_head_split_window_attention, |
| single_head_split_window_attention_1d, |
| ) |
| from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d |
|
|
|
|
class TransformerLayer(nn.Module):
    """Single attention layer with an optional feed-forward network (FFN).

    ``source`` attends to ``target``: self-attention when they match,
    cross-attention otherwise. The attention message is projected and
    layer-normalized. Unless ``no_ffn`` is set, it is then passed through an
    MLP together with ``source``. The result is added to ``source`` as a
    residual.

    Args:
        d_model: Channel dimension of the input tokens (last dim of
            ``source`` / ``target``).
        nhead: Number of attention heads. The windowed ("swin" with more
            than one split) and hybrid attention modes only support ``1``
            and raise ``NotImplementedError`` otherwise.
        no_ffn: If True, skip the MLP and its LayerNorm (``norm2``).
        ffn_dim_expansion: Hidden-width multiplier of the MLP.
    """

    def __init__(
        self,
        d_model=128,
        nhead=1,
        no_ffn=False,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.dim = d_model
        self.nhead = nhead
        self.no_ffn = no_ffn

        # Bias-free query / key / value projections.
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)

        # Output projection applied to the attention message.
        self.merge = nn.Linear(d_model, d_model, bias=False)

        self.norm1 = nn.LayerNorm(d_model)

        if not self.no_ffn:
            # The MLP consumes [source, message] concatenated on the last dim,
            # hence the doubled input width.
            in_channels = d_model * 2
            self.mlp = nn.Sequential(
                nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False),
                nn.GELU(),
                nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False),
            )

            self.norm2 = nn.LayerNorm(d_model)

    @staticmethod
    def _is_self_attention(source, target):
        """Return True when ``target`` equals ``source`` (within 1e-6).

        The identity check comes first. Self-attention is typically invoked
        with the very same tensor twice, so this avoids materializing a full
        elementwise difference and converting it to a Python bool.
        """
        if source is target:
            return True
        return bool((source - target).abs().max() < 1e-6)

    @staticmethod
    def _attend_2d(query, key, value, num_splits, with_shift, height, width, attn_mask):
        """Single-head 2-D attention.

        Uses split-window attention when ``num_splits > 1`` and global
        attention otherwise.
        """
        if num_splits > 1:
            return single_head_split_window_attention(
                query,
                key,
                value,
                num_splits=num_splits,
                with_shift=with_shift,
                h=height,
                w=width,
                attn_mask=attn_mask,
            )
        return single_head_full_attention(query, key, value)

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        shifted_window_attn_mask_1d=None,
        attn_type="swin",
        with_shift=False,
        attn_num_splits=None,
    ):
        """Attend ``source`` to ``target`` and return ``source + message``.

        Args:
            source: Query tokens. Last dim must be ``d_model``; presumably
                ``(batch, height * width, d_model)`` as built by the caller.
                TODO confirm.
            target: Key / value tokens, same layout as ``source``.
            height, width: Token-grid size forwarded to the windowed and 1-D
                attention helpers.
            shifted_window_attn_mask: Mask for shifted 2-D window attention.
            shifted_window_attn_mask_1d: Mask for shifted 1-D window
                attention. Required when cross-attending with
                ``"self_swin2d_cross_swin1d"`` and more than one split.
            attn_type: ``"swin"``, ``"self_swin2d_cross_1d"``,
                ``"self_swin2d_cross_swin1d"``; anything else selects global
                attention.
            with_shift: Whether windowed attention uses shifted windows.
            attn_num_splits: Number of window splits. ``None`` is treated as
                1 (no window split).

        Returns:
            Tensor with the same shape as ``source``.

        Raises:
            NotImplementedError: ``nhead > 1`` with a windowed or hybrid
                attention mode.
            ValueError: the 1-D shifted-window mask is required but missing.
        """
        # None (the default) means "no split"; comparing None > 1 would
        # otherwise raise TypeError.
        num_splits = 1 if attn_num_splits is None else attn_num_splits

        hybrid = attn_type in ("self_swin2d_cross_1d", "self_swin2d_cross_swin1d")
        windowed_swin = attn_type == "swin" and num_splits > 1
        if self.nhead > 1 and (hybrid or windowed_swin):
            raise NotImplementedError(
                f"attn_type={attn_type!r} only supports nhead=1, got nhead={self.nhead}"
            )

        query = self.q_proj(source)
        key = self.k_proj(target)
        value = self.v_proj(target)

        if attn_type == "swin":
            # Falls back to global attention when num_splits <= 1.
            message = self._attend_2d(
                query, key, value, num_splits, with_shift, height, width, shifted_window_attn_mask
            )
        elif hybrid:
            if self._is_self_attention(source, target):
                # Self-attention: 2-D (windowed when split).
                message = self._attend_2d(
                    query, key, value, num_splits, with_shift, height, width, shifted_window_attn_mask
                )
            elif attn_type == "self_swin2d_cross_swin1d" and num_splits > 1:
                if shifted_window_attn_mask_1d is None:
                    raise ValueError(
                        "shifted_window_attn_mask_1d is required for "
                        "'self_swin2d_cross_swin1d' cross-attention with attn_num_splits > 1"
                    )
                message = single_head_split_window_attention_1d(
                    query,
                    key,
                    value,
                    num_splits=num_splits,
                    with_shift=with_shift,
                    h=height,
                    w=width,
                    attn_mask=shifted_window_attn_mask_1d,
                )
            else:
                # Cross-attention: global 1-D attention.
                message = single_head_full_attention_1d(
                    query,
                    key,
                    value,
                    h=height,
                    w=width,
                )
        else:
            # NOTE(review): nhead is not enforced on this path; it is always
            # single-head regardless of self.nhead.
            message = single_head_full_attention(query, key, value)

        message = self.norm1(self.merge(message))

        if not self.no_ffn:
            message = self.norm2(self.mlp(torch.cat([source, message], dim=-1)))

        return source + message
|
|
|
|
class TransformerBlock(nn.Module):
    """One refinement step: self-attention, then cross-attention plus FFN.

    The first layer has no feed-forward network and lets ``source`` attend
    to itself. The second lets the updated ``source`` attend to ``target``
    and then applies its feed-forward network.
    """

    def __init__(
        self,
        d_model=128,
        nhead=1,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        # Registration order (self_attn, then cross_attn_ffn) determines
        # parameter order and state_dict layout; keep it stable.
        self.self_attn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            no_ffn=True,
            ffn_dim_expansion=ffn_dim_expansion,
        )
        self.cross_attn_ffn = TransformerLayer(
            d_model=d_model,
            nhead=nhead,
            ffn_dim_expansion=ffn_dim_expansion,
        )

    def forward(
        self,
        source,
        target,
        height=None,
        width=None,
        shifted_window_attn_mask=None,
        shifted_window_attn_mask_1d=None,
        attn_type="swin",
        with_shift=False,
        attn_num_splits=None,
    ):
        """Run self-attention on ``source``, then cross-attend it to ``target``.

        All options are forwarded unchanged to both layers. The exception is
        ``shifted_window_attn_mask_1d``, which only the cross-attention layer
        receives.

        Returns:
            The refined ``source`` tokens produced by the cross-attention layer.
        """
        layer_options = dict(
            height=height,
            width=width,
            shifted_window_attn_mask=shifted_window_attn_mask,
            attn_type=attn_type,
            with_shift=with_shift,
            attn_num_splits=attn_num_splits,
        )

        # Self-attention: the same tensor is both query and key/value source.
        refined = self.self_attn(source, source, **layer_options)

        # Cross-attention + FFN against the other view.
        return self.cross_attn_ffn(
            refined,
            target,
            shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
            **layer_options,
        )
|
|
|
|
class FeatureTransformer(nn.Module):
    """Stack of ``TransformerBlock`` layers that jointly refines two feature maps.

    Both maps are flattened to token sequences and stacked along the batch
    dimension. Each layer therefore refines map 0 against map 1 and map 1
    against map 0 in a single pass. The results are reshaped back to
    ``(b, c, h, w)`` at the end.
    """

    def __init__(
        self,
        num_layers=6,
        d_model=128,
        nhead=1,
        ffn_dim_expansion=4,
    ):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    d_model=d_model,
                    nhead=nhead,
                    ffn_dim_expansion=ffn_dim_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        # Xavier-uniform init for every parameter with more than one
        # dimension (weight matrices); 1-D parameters keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        feature1,
        attn_type="swin",
        attn_num_splits=None,
        **kwargs,
    ):
        """Refine ``feature0`` and ``feature1`` with self- and cross-attention.

        Args:
            feature0, feature1: Feature maps unpacked as ``(b, c, h, w)``.
                Both must have the same shape, with ``c == d_model``.
            attn_type: Attention mode forwarded to every layer. Values
                containing ``"swin"`` build a shifted 2-D window mask when
                more than one split is used. Values containing ``"swin1d"``
                also build a 1-D mask.
            attn_num_splits: Number of windows per spatial axis. ``None`` is
                treated as 1 (no window split).
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Tuple ``(feature0, feature1)`` of refined maps, each
            ``(b, c, h, w)`` and contiguous.

        Raises:
            ValueError: ``feature0`` and ``feature1`` differ in shape.
        """
        b, c, h, w = feature0.shape
        assert self.d_model == c
        if feature1.shape != feature0.shape:
            # Batch-stacking and the chunk(2) split below assume equal shapes;
            # mismatched batches would otherwise be silently mis-split.
            raise ValueError(
                "feature0 and feature1 must have the same shape, got "
                f"{tuple(feature0.shape)} and {tuple(feature1.shape)}"
            )

        # None (the default) means "no split"; comparing None > 1 would raise.
        num_splits = 1 if attn_num_splits is None else attn_num_splits
        use_shifted_windows = "swin" in attn_type and num_splits > 1

        # (b, c, h, w) -> (b, h*w, c)
        feature0 = feature0.flatten(-2).permute(0, 2, 1)
        feature1 = feature1.flatten(-2).permute(0, 2, 1)

        if use_shifted_windows:
            window_size_h = h // num_splits
            window_size_w = w // num_splits

            # Built once and shared by every layer; shift is half a window.
            shifted_window_attn_mask = generate_shift_window_attn_mask(
                input_resolution=(h, w),
                window_size_h=window_size_h,
                window_size_w=window_size_w,
                shift_size_h=window_size_h // 2,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )
        else:
            shifted_window_attn_mask = None

        # "swin1d" in attn_type implies "swin" in attn_type.
        if "swin1d" in attn_type and use_shifted_windows:
            window_size_w = w // num_splits

            shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d(
                input_w=w,
                window_size_w=window_size_w,
                shift_size_w=window_size_w // 2,
                device=feature0.device,
            )
        else:
            shifted_window_attn_mask_1d = None

        # concat0 holds the queries; concat1 is concat0 with its two halves
        # swapped, so each map's tokens line up with the other map's tokens.
        concat0 = torch.cat((feature0, feature1), dim=0)
        concat1 = torch.cat((feature1, feature0), dim=0)

        for i, layer in enumerate(self.layers):
            concat0 = layer(
                concat0,
                concat1,
                height=h,
                width=w,
                attn_type=attn_type,
                # Odd-indexed layers use shifted windows.
                with_shift=use_shifted_windows and i % 2 == 1,
                attn_num_splits=num_splits,
                shifted_window_attn_mask=shifted_window_attn_mask,
                shifted_window_attn_mask_1d=shifted_window_attn_mask_1d,
            )

            # Re-derive the swapped view from the freshly updated tokens.
            concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0)

        feature0, feature1 = concat0.chunk(chunks=2, dim=0)

        # (b, h*w, c) -> (b, c, h, w)
        feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()
        feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous()

        return feature0, feature1
|
|