| from typing import Dict, Tuple |
| from PIL import Image |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from torchvision import transforms |
| from transformers import PreTrainedModel |
|
|
| from .dino import vit_small |
| from .unimatch import UniMatch |
| from .configuration_doduo import DoduoConfig |
|
|
class DoduoModel(PreTrainedModel):
    """HuggingFace wrapper around :class:`CorrSegFlowNet` for dense flow estimation."""

    config_class = DoduoConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = CorrSegFlowNet(
            dino_corr_mask_ratio=config.dino_corr_mask_ratio
        )

    def forward(self, frame_src, frame_dst):
        """Compute flow from ``frame_src`` to ``frame_dst``.

        Args:
            frame_src: source frame, either a ``PIL.Image.Image`` or a
                preprocessed tensor (presumably (1, 3, H, W) — produced by
                ``process_frame``).
            frame_dst: destination frame, same accepted types/shape.

        Returns:
            The flow tensor produced by the underlying ``CorrSegFlowNet``.
        """
        # Convert each frame independently: the original only checked
        # frame_src, so a mixed tensor/PIL pair would crash downstream.
        if isinstance(frame_src, Image.Image):
            frame_src = self.model.process_frame(frame_src)
        if isinstance(frame_dst, Image.Image):
            frame_dst = self.model.process_frame(frame_dst)
        assert frame_src.shape == frame_dst.shape
        return self.model(frame_src, frame_dst)
| |
class CorrSegFlowNet(nn.Module):
    """Flow network combining UniMatch matching with a frozen-DINO correlation mask.

    A frozen DINO ViT-S/8 supplies patch features whose low-affinity pairs are
    masked out of the UniMatch correlation volume.
    """

    def __init__(self, dino_corr_mask_ratio: float = 0.1):
        super().__init__()
        self.dino_corr_mask_ratio = dino_corr_mask_ratio
        self.unimatch = UniMatch(bilinear_upsample=True)
        # The DINO backbone is frozen: it only provides features for masking.
        self.dino = vit_small(patch_size=8, num_classes=0)
        for param in self.dino.parameters():
            param.requires_grad = False

        self.transform = transforms.Compose(
            [
                # Keep only the first 3 channels (drops alpha when present).
                lambda img: transforms.ToTensor()(img)[:3],
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def process_frame(self, frame):
        """Convert a PIL image into a normalized (1, 3, H, W) tensor on the model's device."""
        target_device = next(self.parameters()).device
        tensor = self.transform(frame)
        return tensor.unsqueeze(0).to(target_device)

    def forward(self, frame_src, frame_dst):
        """Return flow from ``frame_src`` to ``frame_dst`` using the DINO-masked correlation."""
        corr_mask = get_dino_corr_mask(
            self.dino,
            frame_src,
            frame_dst,
            mask_ratio=self.dino_corr_mask_ratio,
        )

        outputs = self.unimatch(
            frame_src,
            frame_dst,
            return_feature=True,
            bidirectional=False,
            cycle_consistency=False,
            corr_mask=corr_mask,
        )
        # unimatch returns (flow, flow_low, correlation, feature0, feature1);
        # only the full-resolution flow is exposed.
        return outputs[0]
|
|
@torch.no_grad()
def extract_dino_feature(model, frame, return_h_w=False):
    """Extract per-patch DINO features for a batch of frames.

    Args:
        model: DINO ViT exposing ``get_intermediate_layers`` and
            ``patch_embed.patch_size``.
        frame: image batch of shape (B, C, H, W).
        return_h_w: when True, also return the patch-grid height and width.

    Returns:
        Tensor of shape (B, num_patches, dim), optionally followed by (h, w).
    """
    batch = frame.shape[0]
    tokens = model.get_intermediate_layers(frame, n=1)[0]
    # Drop the leading CLS token; keep only the patch tokens.
    tokens = tokens[:, 1:, :]
    patch = model.patch_embed.patch_size
    h = int(frame.shape[2] / patch)
    w = int(frame.shape[3] / patch)
    feat = tokens.reshape(batch, -1, tokens.shape[-1])
    return (feat, h, w) if return_h_w else feat
|
|
@torch.no_grad()
def get_dino_corr_mask(model, frame_src, frame_dst, mask_ratio):
    """Build a boolean mask of weak DINO correlations between two frames.

    Args:
        model: frozen DINO ViT used for feature extraction.
        frame_src: source image batch (B, C, H, W).
        frame_dst: destination image batch, same shape as ``frame_src``.
        mask_ratio: fraction of lowest-affinity entries to mask out;
            values <= 0 disable masking entirely.

    Returns:
        Bool tensor of shape (B, N, M), True where the patch affinity falls
        below the ``mask_ratio`` quantile along the last dim, or ``None``
        when masking is disabled.
    """
    src_feat, h, w = extract_dino_feature(model, frame_src, return_h_w=True)
    dst_feat = extract_dino_feature(model, frame_dst)

    # Cosine similarity between every source patch and every destination patch.
    src_unit = F.normalize(src_feat, dim=2, p=2)
    dst_unit = F.normalize(dst_feat, dim=2, p=2)
    affinity = torch.einsum("bnc,bmc->bnm", [src_unit, dst_unit])

    if mask_ratio <= 0:
        # No masking requested.
        return None
    # torch.quantile has no float16 support; promote before thresholding.
    if affinity.dtype == torch.float16:
        affinity = affinity.float()
    threshold = torch.quantile(affinity, mask_ratio, 2, keepdim=True)
    # Mask everything below the per-row quantile threshold.
    return affinity < threshold