| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .utils import merge_splits, merge_splits_1d, split_feature, split_feature_1d |
|
|
|
|
def single_head_full_attention(q, k, v):
    """Plain single-head scaled dot-product attention.

    Args:
        q, k, v: [B, L, C] tensors (L = sequence length, C = channels).

    Returns:
        [B, L, C] attention output.
    """
    assert q.dim() == k.dim() == v.dim() == 3

    channels = q.size(2)
    # [B, L, L] similarity logits, scaled by sqrt(C)
    attn_weights = torch.softmax(
        torch.matmul(q, k.transpose(1, 2)) / channels**0.5, dim=2
    )
    return torch.matmul(attn_weights, v)
|
|
|
|
def single_head_full_attention_1d(
    q,
    k,
    v,
    h=None,
    w=None,
):
    """Single-head attention restricted to each row of an H x W feature map.

    Args:
        q, k, v: [B, H*W, C] tensors.
        h, w: spatial size of the flattened feature map (required).

    Returns:
        [B, H*W, C] attention output; tokens attend only within their own row.
    """
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()
    scale = c**0.5

    # Recover the 2-D layout so attention runs independently per row.
    q_rows = q.view(b, h, w, c)
    k_rows = k.view(b, h, w, c)
    v_rows = v.view(b, h, w, c)

    # [B, H, W, W] row-wise similarity logits
    logits = torch.matmul(q_rows, k_rows.transpose(-2, -1)) / scale
    weights = torch.softmax(logits, dim=-1)

    return torch.matmul(weights, v_rows).view(b, -1, c)
|
|
|
|
def single_head_split_window_attention(
    q,
    k,
    v,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
):
    """Swin-style (optionally shifted) window attention on an H x W feature map.

    Args:
        q, k, v: [B, H*W, C] tensors.
        num_splits: number of windows along each spatial axis.
        with_shift: if True, cyclically shift windows by half their size
            before attending (requires ``attn_mask``).
        h, w: spatial size of the flattened feature map (required).
        attn_mask: additive mask for the attention logits, used only when
            ``with_shift`` is set, to block attention across pre-shift
            window boundaries.

    Returns:
        [B, H*W, C] attention output.
    """
    assert q.dim() == k.dim() == v.dim() == 3
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()
    b_windows = b * num_splits * num_splits

    win_h = h // num_splits
    win_w = w // num_splits

    q = q.view(b, h, w, c)
    k = k.view(b, h, w, c)
    v = v.view(b, h, w, c)

    scale = c**0.5

    if with_shift:
        assert attn_mask is not None
        shift_h, shift_w = win_h // 2, win_w // 2
        # Cyclic shift so neighbouring windows exchange context.
        q = torch.roll(q, shifts=(-shift_h, -shift_w), dims=(1, 2))
        k = torch.roll(k, shifts=(-shift_h, -shift_w), dims=(1, 2))
        v = torch.roll(v, shifts=(-shift_h, -shift_w), dims=(1, 2))

    # Partition into per-window batches.
    q = split_feature(q, num_splits=num_splits, channel_last=True)
    k = split_feature(k, num_splits=num_splits, channel_last=True)
    v = split_feature(v, num_splits=num_splits, channel_last=True)

    logits = (
        torch.matmul(q.view(b_windows, -1, c), k.view(b_windows, -1, c).permute(0, 2, 1))
        / scale
    )

    if with_shift:
        # Mask out attention between tokens that belonged to different
        # windows before the cyclic shift.
        logits = logits + attn_mask.repeat(b, 1, 1)

    weights = torch.softmax(logits, dim=-1)
    out = torch.matmul(weights, v.view(b_windows, -1, c))

    out = merge_splits(
        out.view(b_windows, win_h, win_w, c),
        num_splits=num_splits,
        channel_last=True,
    )

    if with_shift:
        # Undo the cyclic shift.
        out = torch.roll(out, shifts=(shift_h, shift_w), dims=(1, 2))

    return out.view(b, -1, c)
|
|
|
|
def single_head_split_window_attention_1d(
    q,
    k,
    v,
    relative_position_bias=None,
    num_splits=1,
    with_shift=False,
    h=None,
    w=None,
    attn_mask=None,
):
    """Row-wise (optionally shifted) window attention along the W axis only.

    Args:
        q, k, v: [B, H*W, C] tensors.
        relative_position_bias: unused here; kept for interface compatibility.
        num_splits: number of windows along the W axis.
        with_shift: if True, cyclically shift windows by half their width
            before attending (requires ``attn_mask``).
        h, w: spatial size of the flattened feature map (required).
        attn_mask: additive mask for the attention logits, used only when
            ``with_shift`` is set.

    Returns:
        [B, H*W, C] attention output.
    """
    assert h is not None and w is not None
    assert q.size(1) == h * w

    b, _, c = q.size()
    b_windows = b * num_splits * h
    win_w = w // num_splits

    # Fold H into the batch dimension so attention is per-row, along W only.
    q = q.view(b * h, w, c)
    k = k.view(b * h, w, c)
    v = v.view(b * h, w, c)

    scale = c**0.5

    if with_shift:
        assert attn_mask is not None
        shift_w = win_w // 2
        # Cyclic shift along W so neighbouring windows exchange context.
        q = torch.roll(q, shifts=-shift_w, dims=1)
        k = torch.roll(k, shifts=-shift_w, dims=1)
        v = torch.roll(v, shifts=-shift_w, dims=1)

    # Partition each row into windows.
    q = split_feature_1d(q, num_splits=num_splits)
    k = split_feature_1d(k, num_splits=num_splits)
    v = split_feature_1d(v, num_splits=num_splits)

    logits = (
        torch.matmul(q.view(b_windows, -1, c), k.view(b_windows, -1, c).permute(0, 2, 1))
        / scale
    )

    if with_shift:
        # Mask attention across pre-shift window boundaries.
        logits = logits + attn_mask.repeat(b * h, 1, 1)

    weights = torch.softmax(logits, dim=-1)
    out = torch.matmul(weights, v.view(b_windows, -1, c))

    # Re-assemble windows; result is 4-D here, with W on dim 2.
    out = merge_splits_1d(out, h, num_splits=num_splits)

    if with_shift:
        # Undo the cyclic shift along W.
        out = torch.roll(out, shifts=shift_w, dims=2)

    return out.view(b, -1, c)
|
|
|
|
class SelfAttnPropagation(nn.Module):
    """Flow propagation via self-attention on the feature map.

    Attention is computed among feature positions (query/key both derived
    from ``feature0``) and used to propagate the flow field (the value).
    Two modes: full global attention (``forward``) or attention restricted
    to a local square window around each pixel
    (``forward_local_window_attn``).
    """

    def __init__(
        self,
        in_channels,
        **kwargs,
    ):
        super().__init__()

        # Separate linear projections for query and key; the flow itself is
        # used as the value without projection.
        self.q_proj = nn.Linear(in_channels, in_channels)
        self.k_proj = nn.Linear(in_channels, in_channels)

        # Xavier-init all weight matrices (biases, dim 1, keep their default init).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        feature0,
        flow,
        local_window_attn=False,
        local_window_radius=1,
        **kwargs,
    ):
        """Propagate ``flow`` using attention over ``feature0``.

        Args:
            feature0: [B, C, H, W] feature map.
            flow: [B, 2, H, W] flow (or [B, 1, H, W], e.g. disparity).
            local_window_attn: if True, use local-window attention instead
                of full global attention.
            local_window_radius: radius of the local window (window side is
                ``2 * radius + 1``); only used when ``local_window_attn``.

        Returns:
            [B, flow_channels, H, W] propagated flow.
        """
        if local_window_attn:
            return self.forward_local_window_attn(
                feature0, flow, local_window_radius=local_window_radius
            )

        b, c, h, w = feature0.size()

        # Flatten spatial dims: [B, H*W, C]
        query = feature0.view(b, c, h * w).permute(0, 2, 1)

        # NOTE(review): the key is projected from the already q-projected
        # tensor (k_proj(q_proj(x))), not from feature0 directly. This
        # matches the code as written; confirm against the reference
        # implementation before "fixing", since pretrained weights depend
        # on it.
        query = self.q_proj(query)
        key = self.k_proj(query)

        # Flow acts as the attention value: [B, H*W, flow_channels]
        value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1)

        # [B, H*W, H*W] global attention, scaled by sqrt(C)
        scores = torch.matmul(query, key.permute(0, 2, 1)) / (c**0.5)
        prob = torch.softmax(scores, dim=-1)

        out = torch.matmul(prob, value)
        # Back to [B, flow_channels, H, W]
        out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2)

        return out

    def forward_local_window_attn(
        self,
        feature0,
        flow,
        local_window_radius=1,
    ):
        """Propagate ``flow`` with attention limited to a local window.

        Each pixel attends only to the (2r+1) x (2r+1) neighbourhood
        around it, gathered with ``F.unfold`` (zero padding at borders).

        Args:
            feature0: [B, C, H, W] feature map.
            flow: [B, 2, H, W] or [B, 1, H, W] value field.
            local_window_radius: window radius r > 0.

        Returns:
            [B, flow_channels, H, W] propagated flow.
        """
        assert flow.size(1) == 2 or flow.size(1) == 1
        assert local_window_radius > 0

        b, c, h, w = feature0.size()

        value_channel = flow.size(1)

        # Per-pixel query: [B*H*W, 1, C]
        feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1)).reshape(
            b * h * w, 1, c
        )

        kernel_size = 2 * local_window_radius + 1

        # Key projection applied channel-wise, then restored to [B, C, H, W]
        # so it can be unfolded spatially.
        feature0_proj = (
            self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1))
            .permute(0, 2, 1)
            .reshape(b, c, h, w)
        )

        # Gather each pixel's neighbourhood: [B, C*k*k, H*W]
        feature0_window = F.unfold(
            feature0_proj, kernel_size=kernel_size, padding=local_window_radius
        )

        # Per-pixel key window: [B*H*W, C, k*k]
        feature0_window = (
            feature0_window.view(b, c, kernel_size**2, h, w)
            .permute(0, 3, 4, 1, 2)
            .reshape(b * h * w, c, kernel_size**2)
        )

        # Same neighbourhood gathering for the flow values.
        flow_window = F.unfold(
            flow, kernel_size=kernel_size, padding=local_window_radius
        )

        # Per-pixel value window: [B*H*W, k*k, flow_channels]
        flow_window = (
            flow_window.view(b, value_channel, kernel_size**2, h, w)
            .permute(0, 3, 4, 2, 1)
            .reshape(b * h * w, kernel_size**2, value_channel)
        )

        # [B*H*W, 1, k*k] attention logits over the local window
        scores = torch.matmul(feature0_reshape, feature0_window) / (
            c**0.5
        )

        prob = torch.softmax(scores, dim=-1)

        # Weighted sum of neighbouring flow values, back to [B, fc, H, W]
        out = (
            torch.matmul(prob, flow_window)
            .view(b, h, w, value_channel)
            .permute(0, 3, 1, 2)
            .contiguous()
        )

        return out
|
|