import os
import logging
from collections import OrderedDict

import torch
from torch import nn
from einops import rearrange
from timm.models.layers import DropPath
from timm.models.registry import register_model

import torch.utils.checkpoint as checkpoint


logger = logging.getLogger(__name__)
|
|
def load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True):
    """
    Add/remove extra temporal embeddings as needed.
    https://arxiv.org/abs/2104.00650 shows that adding zero padding works.

    temp_embed_old: (1, num_frames_old, 1, d)
    temp_embed_new: (1, num_frames_new, 1, d)
    add_zero: bool, if True, pad with zeros; otherwise interpolate the trained embeddings.
    """
    num_frms_new = temp_embed_new.shape[1]
    num_frms_old = temp_embed_old.shape[1]
    logger.info(f"Load temporal_embeddings, lengths: {num_frms_old}-->{num_frms_new}")
    if num_frms_new > num_frms_old:
        if add_zero:
            # Reuse the trained embeddings for the first frames; the remaining positions stay zero.
            temp_embed_new[:, :num_frms_old] = temp_embed_old
        else:
            # interpolate_temporal_pos_embed is assumed to be provided elsewhere in the codebase.
            temp_embed_new = interpolate_temporal_pos_embed(temp_embed_old, num_frms_new)
    elif num_frms_new < num_frms_old:
        temp_embed_new = temp_embed_old[:, :num_frms_new]
    else:
        temp_embed_new = temp_embed_old
    return temp_embed_new
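

# Illustrative sketch (not part of the original module): it demonstrates the zero-padding
# behavior of load_temp_embed_with_mismatch when the target model uses more frames than the
# checkpoint. The helper name _example_temp_embed_resize is hypothetical.
def _example_temp_embed_resize():
    d = 8
    temp_embed_old = torch.randn(1, 4, 1, d)  # embeddings trained with 4 frames
    temp_embed_new = torch.zeros(1, 8, 1, d)  # target model expects 8 frames
    resized = load_temp_embed_with_mismatch(temp_embed_old, temp_embed_new, add_zero=True)
    assert resized.shape == (1, 8, 1, d)
    # The first 4 positions reuse the trained embeddings; the rest remain zero.
    assert torch.equal(resized[:, :4], temp_embed_old)
    assert torch.all(resized[:, 4:] == 0)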
|
|
|
|
MODEL_PATH = ''
_MODELS = {
    "ViT-L/14": os.path.join(MODEL_PATH, "ViCLIP-L_InternVid-FLT-10M.pth"),
    "ViT-B/16": os.path.join(MODEL_PATH, "ViCLIP-B-InternVid-FLT-10M.pth"),
}
|
|
|
|
class QuickGELU(nn.Module):
    def forward(self, x):
        # Sigmoid-based approximation of GELU, as used in CLIP.
        return x * torch.sigmoid(1.702 * x)
|
|
|
|
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model, n_head, drop_path=0., attn_mask=None, dropout=0.):
        super().__init__()

        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.attn = nn.MultiheadAttention(d_model, n_head, dropout=dropout)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("drop1", nn.Dropout(dropout)),
            ("c_proj", nn.Linear(d_model * 4, d_model)),
            ("drop2", nn.Dropout(dropout)),
        ]))
        self.ln_2 = nn.LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x):
        # Pre-LN residual block: attention and MLP branches, each with stochastic depth.
        x = x + self.drop_path1(self.attention(self.ln_1(x)))
        x = x + self.drop_path2(self.mlp(self.ln_2(x)))
        return x
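

# Illustrative sketch (not part of the original module): ResidualAttentionBlock expects
# sequence-first input (seq_len, batch, d_model), the default layout of nn.MultiheadAttention,
# and preserves that shape. The helper name _example_residual_attention_block is hypothetical.
def _example_residual_attention_block():
    blk = ResidualAttentionBlock(d_model=64, n_head=4)
    x = torch.randn(10, 2, 64)  # (seq_len, batch, d_model)
    assert blk(x).shape == (10, 2, 64)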
|
|
|
|
class Transformer(nn.Module):
    def __init__(self, width, layers, heads, drop_path=0., checkpoint_num=0, dropout=0.):
        super().__init__()
        # Stochastic depth rate grows linearly with block depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path, layers)]
        self.resblocks = nn.ModuleList()
        for idx in range(layers):
            self.resblocks.append(ResidualAttentionBlock(width, heads, drop_path=dpr[idx], dropout=dropout))
        self.checkpoint_num = checkpoint_num

    def forward(self, x):
        for idx, blk in enumerate(self.resblocks):
            # Use activation checkpointing for the first `checkpoint_num` blocks to save memory.
            if idx < self.checkpoint_num:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        return x
|
|
|
|
class VisionTransformer(nn.Module):
    def __init__(
        self, input_resolution, patch_size, width, layers, heads, output_dim=None,
        kernel_size=1, num_frames=8, drop_path=0, checkpoint_num=0, dropout=0.,
        temp_embed=True,
    ):
        super().__init__()
        self.output_dim = output_dim
        # 3D patch embedding: kernel_size frames along time, patch_size x patch_size in space.
        self.conv1 = nn.Conv3d(
            3, width,
            (kernel_size, patch_size, patch_size),
            (kernel_size, patch_size, patch_size),
            (0, 0, 0), bias=False
        )

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = nn.LayerNorm(width)
        if temp_embed:
            self.temporal_positional_embedding = nn.Parameter(torch.zeros(1, num_frames, width))

        self.transformer = Transformer(
            width, layers, heads, drop_path=drop_path, checkpoint_num=checkpoint_num,
            dropout=dropout)

        self.ln_post = nn.LayerNorm(width)
        if output_dim is not None:
            self.proj = nn.Parameter(torch.empty(width, output_dim))
        else:
            self.proj = None

        self.dropout = nn.Dropout(dropout)

    def get_num_layers(self):
        return len(self.transformer.resblocks)
|
|
    @torch.jit.ignore
    def no_weight_decay(self):
        return {'positional_embedding', 'class_embedding', 'temporal_positional_embedding'}

    def mask_tokens(self, inputs, masking_prob=0.0):
        B, L, _ = inputs.shape

        # Randomly mask a fixed number of tokens (Lm) per sample.
        Lm = int(masking_prob * L)
        masked_indices = torch.zeros(B, L, device=inputs.device)
        indices = torch.argsort(torch.rand_like(masked_indices), dim=-1)[:, :Lm]
        batch_indices = (
            torch.arange(masked_indices.shape[0], device=inputs.device).unsqueeze(-1).expand_as(indices)
        )
        masked_indices[batch_indices, indices] = 1

        masked_indices = masked_indices.bool()

        # Keep only the unmasked tokens; every sample keeps the same count, so the reshape is safe.
        return inputs[~masked_indices].reshape(B, -1, inputs.shape[-1])
|
|
    def forward(self, x, masking_prob=0.0):
        # x: (B, 3, T, H, W) video clip.
        x = self.conv1(x)  # (B, C, T', H', W')
        B, C, T, H, W = x.shape
        x = x.permute(0, 2, 3, 4, 1).reshape(B * T, H * W, C)

        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)

        # Keep one class token per video and add temporal positional embeddings to the patch tokens.
        cls_tokens = x[:B, :1, :]
        x = x[:, 1:]
        x = rearrange(x, '(b t) n m -> (b n) t m', b=B, t=T)
        if hasattr(self, 'temporal_positional_embedding'):
            if x.size(1) == 1:
                # For single-frame input, use the averaged temporal embedding.
                x = x + self.temporal_positional_embedding.mean(1)
            else:
                x = x + self.temporal_positional_embedding
        x = rearrange(x, '(b n) t m -> b (n t) m', b=B, t=T)

        if masking_prob > 0.0:
            x = self.mask_tokens(x, masking_prob)

        x = torch.cat((cls_tokens, x), dim=1)

        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # (L, B, D): sequence-first layout for nn.MultiheadAttention
        x = self.transformer(x)

        x = self.ln_post(x)

        if self.proj is not None:
            # Project the class token to the joint embedding space: (B, output_dim).
            x = self.dropout(x[0]) @ self.proj
        else:
            x = x.permute(1, 0, 2)

        return x
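

# Illustrative sketch (not part of the original module): a tiny VisionTransformer showing the
# input/output shape contract of forward(), with and without token masking. The hyperparameters
# and the helper name _example_vision_transformer_shapes are hypothetical.
def _example_vision_transformer_shapes():
    model = VisionTransformer(
        input_resolution=32, patch_size=16, width=64, layers=1, heads=2,
        output_dim=16, num_frames=4,
    )
    video = torch.randn(2, 3, 4, 32, 32)  # (B, C, T, H, W)
    # With output_dim set, forward returns the projected class token: (B, output_dim).
    assert model(video).shape == (2, 16)
    # Randomly dropping half of the patch tokens does not change the output shape.
    assert model(video, masking_prob=0.5).shape == (2, 16)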
|
|
|
|
def inflate_weight(weight_2d, time_dim, center=True):
    logger.info(f'Init center: {center}')
    if center:
        # Center initialization: place the 2D kernel at the middle temporal index, zeros elsewhere.
        weight_3d = torch.zeros(*weight_2d.shape)
        weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        middle_idx = time_dim // 2
        weight_3d[:, :, middle_idx, :, :] = weight_2d
    else:
        # Mean initialization: replicate the 2D kernel over time and rescale.
        weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
        weight_3d = weight_3d / time_dim
    return weight_3d
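

# Illustrative sketch (not part of the original module): inflating a 2D patch-embedding kernel
# to 3D with center initialization puts the 2D weights at the middle temporal index and zeros
# elsewhere. The helper name _example_inflate_weight is hypothetical.
def _example_inflate_weight():
    weight_2d = torch.randn(768, 3, 16, 16)            # (out_ch, in_ch, kH, kW)
    weight_3d = inflate_weight(weight_2d, time_dim=3, center=True)
    assert weight_3d.shape == (768, 3, 3, 16, 16)       # (out_ch, in_ch, kT, kH, kW)
    assert torch.equal(weight_3d[:, :, 1], weight_2d)   # middle temporal slice holds the 2D kernel
    assert torch.all(weight_3d[:, :, 0] == 0) and torch.all(weight_3d[:, :, 2] == 0)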
|
|
|
|
def load_state_dict(model, state_dict, input_resolution=224, patch_size=16, center=True):
    state_dict_3d = model.state_dict()
    for k in state_dict.keys():
        if k in state_dict_3d.keys() and state_dict[k].shape != state_dict_3d[k].shape:
            if len(state_dict_3d[k].shape) <= 2:
                logger.info(f'Ignore: {k}')
                continue
            logger.info(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
            time_dim = state_dict_3d[k].shape[2]
            state_dict[k] = inflate_weight(state_dict[k], time_dim, center=center)

    # Interpolate the spatial positional embedding if the patch grid size changed.
    pos_embed_checkpoint = state_dict['positional_embedding']
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = (input_resolution // patch_size) ** 2
    orig_size = int((pos_embed_checkpoint.shape[-2] - 1) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size != new_size:
        logger.info(f'Pos_emb from {orig_size} to {new_size}')
        extra_tokens = pos_embed_checkpoint[:1]
        pos_tokens = pos_embed_checkpoint[1:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(0, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
        state_dict['positional_embedding'] = new_pos_embed

    message = model.load_state_dict(state_dict, strict=False)
    logger.info(f"Load pretrained weights: {message}")
|
|
|
|
@register_model
def clip_joint_b16(
    pretrained=False, input_resolution=224, kernel_size=1,
    center=True, num_frames=8, drop_path=0., checkpoint_num=0,
    dropout=0.,
):
    model = VisionTransformer(
        input_resolution=input_resolution, patch_size=16,
        width=768, layers=12, heads=12, output_dim=512,
        kernel_size=kernel_size, num_frames=num_frames,
        drop_path=drop_path, checkpoint_num=checkpoint_num,
        dropout=dropout,
    )

    if pretrained:
        if isinstance(pretrained, str):
            model_name = pretrained
        else:
            model_name = "ViT-B/16"

        logger.info('load pretrained weights')
        state_dict = torch.load(_MODELS[model_name], map_location='cpu')
        load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=16, center=center)
    return model.eval()
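

# Illustrative sketch (not part of the original module): building the ViT-B/16 video encoder
# without pretrained weights and encoding a dummy 8-frame clip. The helper name
# _example_clip_joint_b16 is hypothetical.
def _example_clip_joint_b16():
    model = clip_joint_b16(pretrained=False, num_frames=8)
    video = torch.randn(1, 3, 8, 224, 224)  # (B, C, T, H, W)
    with torch.no_grad():
        features = model(video)
    assert features.shape == (1, 512)  # class token projected to the joint embedding space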
|
|
|
|
@register_model
def clip_joint_l14(
    pretrained=False, input_resolution=224, kernel_size=1,
    center=True, num_frames=8, drop_path=0., checkpoint_num=0,
    dropout=0.,
):
    model = VisionTransformer(
        input_resolution=input_resolution, patch_size=14,
        width=1024, layers=24, heads=16, output_dim=768,
        kernel_size=kernel_size, num_frames=num_frames,
        drop_path=drop_path, checkpoint_num=checkpoint_num,
        dropout=dropout,
    )

    if pretrained:
        if isinstance(pretrained, str):
            model_name = pretrained
        else:
            model_name = "ViT-L/14"
        logger.info('load pretrained weights')
        state_dict = torch.load(_MODELS[model_name], map_location='cpu')
        load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
    return model.eval()
|
|
|
|
@register_model
def clip_joint_l14_336(
    pretrained=True, input_resolution=336, kernel_size=1,
    center=True, num_frames=8, drop_path=0.
):
    # Not implemented: no "ViT-L/14_336" checkpoint is registered in _MODELS.
    raise NotImplementedError
    model = VisionTransformer(
        input_resolution=input_resolution, patch_size=14,
        width=1024, layers=24, heads=16, output_dim=768,
        kernel_size=kernel_size, num_frames=num_frames,
        drop_path=drop_path,
    )
    if pretrained:
        logger.info('load pretrained weights')
        state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu')
        load_state_dict(model, state_dict, input_resolution=input_resolution, patch_size=14, center=center)
    return model.eval()
|
|
|
|
def interpolate_pos_embed_vit(state_dict, new_model):
    # Resize the vision encoder's temporal positional embedding if the frame count changed.
    key = "vision_encoder.temporal_positional_embedding"
    if key in state_dict:
        vision_temp_embed_new = new_model.state_dict()[key]
        vision_temp_embed_new = vision_temp_embed_new.unsqueeze(2)  # (1, num_frames, 1, d)
        vision_temp_embed_old = state_dict[key]
        vision_temp_embed_old = vision_temp_embed_old.unsqueeze(2)

        state_dict[key] = load_temp_embed_with_mismatch(
            vision_temp_embed_old, vision_temp_embed_new, add_zero=False
        ).squeeze(2)

    # Resize the text encoder's positional embedding if the context length changed.
    key = "text_encoder.positional_embedding"
    if key in state_dict:
        text_temp_embed_new = new_model.state_dict()[key]
        text_temp_embed_new = text_temp_embed_new.unsqueeze(0).unsqueeze(2)  # (1, num_tokens, 1, d)
        text_temp_embed_old = state_dict[key]
        text_temp_embed_old = text_temp_embed_old.unsqueeze(0).unsqueeze(2)

        state_dict[key] = load_temp_embed_with_mismatch(
            text_temp_embed_old, text_temp_embed_new, add_zero=False
        ).squeeze(2).squeeze(0)
    return state_dict
|
|
|
|
if __name__ == '__main__':
    import time
    from fvcore.nn import FlopCountAnalysis
    from fvcore.nn import flop_count_table
    import numpy as np

    # Configure logging so the FLOP table below is actually printed.
    logging.basicConfig(level=logging.INFO)

    seed = 4217
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    num_frames = 8

    model = clip_joint_l14(pretrained=False)

    flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224))
    s = time.time()
    logger.info(flop_count_table(flops, max_depth=1))
    logger.info(time.time() - s)
|
|