| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .backbone import CNNEncoder |
| from .geometry import coords_grid |
| from .matching import ( |
| global_correlation_softmax_prototype, |
| local_correlation_softmax_prototype, |
| ) |
| from .transformer import FeatureTransformer |
| from .utils import feature_add_position |
|
|
|
|
class UniMatch(nn.Module):
    """Dense matching network.

    Extracts CNN features from both images with a shared backbone, enhances
    them with a self-/cross-attention Transformer, and turns a
    correlation-softmax over feature similarities into a dense flow field,
    which is upsampled to input resolution either bilinearly or with a
    learned convex-combination ("mask") upsampler.
    """

    def __init__(
        self,
        num_scales=1,
        feature_channels=128,
        upsample_factor=8,
        num_head=1,
        ffn_dim_expansion=4,
        num_transformer_layers=6,
        bilinear_upsample=False,
        corr_fn="global",
    ):
        """Build the backbone, the transformer and (optionally) the upsampler.

        Args:
            num_scales: number of backbone output scales; ``forward`` only
                supports 1 (asserted there).
            feature_channels: channel width of backbone/transformer features.
            upsample_factor: ratio between input and feature resolution.
            num_head: number of attention heads in the transformer.
            ffn_dim_expansion: FFN expansion ratio inside the transformer.
            num_transformer_layers: number of transformer layers.
            bilinear_upsample: if True, flow is upsampled bilinearly and the
                learned upsampler is not built.
            corr_fn: which correlation-softmax prototype to use,
                "global" or "local".

        Raises:
            NotImplementedError: if ``corr_fn`` is neither "global" nor "local".
        """
        super().__init__()

        self.feature_channels = feature_channels
        self.num_scales = num_scales
        self.upsample_factor = upsample_factor
        self.bilinear_upsample = bilinear_upsample

        if corr_fn == "global":
            self.corr_fn = global_correlation_softmax_prototype
        elif corr_fn == "local":
            self.corr_fn = local_correlation_softmax_prototype
        else:
            raise NotImplementedError(f"Correlation function {corr_fn} not implemented")

        # CNN backbone, shared by both images (they are batched together).
        self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)

        # Transformer for feature enhancement via self-/cross-attention.
        self.transformer = FeatureTransformer(
            num_layers=num_transformer_layers,
            d_model=feature_channels,
            nhead=num_head,
            ffn_dim_expansion=ffn_dim_expansion,
        )

        # RAFT-style convex upsampling head: predicts 9 combination weights
        # for every cell of the upsampled grid. Only needed when bilinear
        # upsampling is disabled.
        if not bilinear_upsample:
            self.upsampler = nn.Sequential(
                nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, upsample_factor**2 * 9, 1, 1, 0),
            )

    def extract_feature(self, img0, img1):
        """Run the backbone once on both images and split the results.

        Returns two lists (one per image) of per-scale feature maps. The
        backbone's scale ordering is reversed before splitting.
        """
        # Batch both images so the backbone runs a single forward pass.
        concat = torch.cat((img0, img1), dim=0)
        features = self.backbone(concat)

        # Reverse the scale ordering returned by the backbone.
        features = features[::-1]

        feature0, feature1 = [], []
        for feature in features:
            chunks = torch.chunk(feature, 2, 0)
            feature0.append(chunks[0])
            feature1.append(chunks[1])

        return feature0, feature1

    def correlate_feature(self, feature0, feature1, attn_splits=2, attn_type="swin"):
        """Return the scaled all-pairs correlation volume, shape (b, h*w, h*w).

        Features are position-encoded and transformer-enhanced first; the
        correlation is the dot product between every feature pair, scaled by
        1/sqrt(c) as in attention.
        """
        feature0, feature1 = feature_add_position(
            feature0, feature1, attn_splits, self.feature_channels
        )
        feature0, feature1 = self.transformer(
            feature0,
            feature1,
            attn_type=attn_type,
            attn_num_splits=attn_splits,
        )
        b, c, h, w = feature0.shape
        feature0 = feature0.view(b, c, -1).permute(0, 2, 1)
        feature1 = feature1.view(b, c, -1)
        correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (
            c**0.5
        )
        correlation = correlation.view(b, h * w, h * w)
        return correlation

    def forward(
        self,
        img0,
        img1,
        attn_type="swin",
        attn_splits=2,
        return_feature=False,
        bidirectional=False,
        cycle_consistency=False,
        corr_mask=None,
    ):
        """Predict dense correspondences from img0 to img1.

        Args:
            img0, img1: input image batches.
            attn_type: attention variant passed to the transformer.
            attn_splits: number of attention splits (window partitioning).
            return_feature: also return the enhanced features.
            bidirectional: predict both matching directions (stacked on batch).
            cycle_consistency: duplicate the batch with swapped images so both
                directions are computed in a single pass.
            corr_mask: optional mask forwarded to the correlation function.

        Returns:
            (flow_up, flow, correlation) — and additionally
            (feature0, feature1) when ``return_feature`` is True.
        """
        feature0_list, feature1_list = self.extract_feature(img0, img1)
        assert self.num_scales == 1  # multi-scale inference not implemented
        scale_idx = 0
        feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

        if cycle_consistency:
            # Stack both matching directions along the batch dimension.
            feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat(
                (feature1, feature0), dim=0
            )

        feature0, feature1 = feature_add_position(
            feature0, feature1, attn_splits, self.feature_channels
        )

        feature0, feature1 = self.transformer(
            feature0,
            feature1,
            attn_type=attn_type,
            attn_num_splits=attn_splits,
        )
        b, c, h, w = feature0.shape

        # Identity pixel-coordinate grid at feature resolution.
        flow_coords = coords_grid(b, h, w).to(feature0.device)

        query_results, correlation = self.corr_fn(
            feature0, feature1, flow_coords, pred_bidir_flow=bidirectional, corr_mask=corr_mask
        )
        if bidirectional:
            # Both directions come back stacked along the batch dimension.
            flow_coords = torch.cat((flow_coords, flow_coords), dim=0)
            up_feature = torch.cat((feature0, feature1), dim=0)
        else:
            up_feature = feature0
        # Matched coordinates minus the identity grid gives the displacement.
        flow = query_results - flow_coords
        # BUGFIX: pass the configured factor explicitly — the bilinear path
        # previously used the hard-coded default of 8 regardless of the
        # model's actual upsample_factor.
        flow_up = self.upsample_flow(
            flow,
            up_feature,
            bilinear=self.bilinear_upsample,
            upsample_factor=self.upsample_factor,
        )
        if return_feature:
            return flow_up, flow, correlation, feature0, feature1
        else:
            return flow_up, flow, correlation

    def forward_features(
        self,
        img0,
        img1,
        attn_type="swin",
        attn_splits=2,
    ):
        """Return transformer-enhanced features for both images (no matching)."""
        feature0_list, feature1_list = self.extract_feature(img0, img1)
        assert self.num_scales == 1
        scale_idx = 0
        feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]

        feature0, feature1 = feature_add_position(
            feature0, feature1, attn_splits, self.feature_channels
        )

        feature0, feature1 = self.transformer(
            feature0,
            feature1,
            attn_type=attn_type,
            attn_num_splits=attn_splits,
        )
        return feature0, feature1

    def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, is_depth=False):
        """Upsample flow (or depth) by ``upsample_factor``.

        Bilinear path: interpolates and rescales flow vectors by the factor
        (depth values are resolution-independent, so they are not rescaled).
        Mask path: uses the learned convex-combination upsampler; note it
        always upsamples by ``self.upsample_factor``, since the upsampler head
        is sized for that factor — the ``upsample_factor`` argument only
        affects the bilinear path.
        """
        if bilinear:
            multiplier = 1 if is_depth else upsample_factor
            up_flow = (
                F.interpolate(
                    flow, scale_factor=upsample_factor, mode="bilinear", align_corners=False
                )
                * multiplier
            )
        else:
            concat = torch.cat((flow, feature), dim=1)
            mask = self.upsampler(concat)
            up_flow = upsample_flow_with_mask(
                flow, mask, upsample_factor=self.upsample_factor, is_depth=is_depth
            )
        return up_flow
|
|
|
|
def upsample_flow_with_mask(flow, up_mask, upsample_factor, is_depth=False):
    """Convex upsampling (RAFT-style).

    Each cell of the fine grid becomes a learned convex combination of the
    3x3 coarse-grid neighbourhood around its parent cell, with the
    combination weights taken from ``up_mask`` (softmax-normalized here).

    Args:
        flow: coarse field of shape (b, channels, h, w).
        up_mask: predicted weights of shape (b, 9 * factor**2, h, w).
        upsample_factor: spatial upsampling ratio.
        is_depth: depth values are resolution-independent, so skip the
            flow-vector rescaling.

    Returns:
        Upsampled field of shape (b, channels, factor * h, factor * w).
    """
    b, num_channels, h, w = flow.shape

    # Normalize the predicted weights over the 3x3-neighbourhood dimension.
    weights = up_mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w)
    weights = torch.softmax(weights, dim=2)

    # Flow vectors are expressed in coarse-grid units, so scale them to the
    # fine grid; depth is left untouched.
    scale = 1 if is_depth else upsample_factor
    neighbours = F.unfold(scale * flow, [3, 3], padding=1)
    neighbours = neighbours.view(b, num_channels, 9, 1, 1, h, w)

    # Convex combination over the 9 neighbours, then interleave the
    # per-cell sub-grid with the spatial dimensions.
    upsampled = (weights * neighbours).sum(dim=2)
    upsampled = upsampled.permute(0, 1, 4, 2, 5, 3)
    return upsampled.reshape(b, num_channels, upsample_factor * h, upsample_factor * w)