| import argparse |
| from typing import Any, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| from timm.models.layers import trunc_normal_ |
| from copy import deepcopy |
| import os |
| import torch.backends.cudnn as cudnn |
|
|
| import models.vision_transformer as vits |
|
|
class vit(nn.Module):
    """Thin wrapper around a Vision Transformer backbone from
    ``models.vision_transformer``.

    Builds the requested backbone, optionally freezes all of its
    parameters, and optionally initializes it from a DINO/iBOT-style
    checkpoint file.
    """

    # Embedding width produced by each supported backbone key.
    # Keys must match factory names in ``models.vision_transformer``.
    _EMBED_DIMS = {
        "vit_small": 384,
        "vit_base": 768,
        "vit_large": 1024,
        "giant": 1536,
    }

    def __init__(self, model_size="base", freeze_transformer=True, pretrained_weights=None):
        """Construct the backbone wrapper.

        Args:
            model_size: Backbone key — one of "vit_small", "vit_base",
                "vit_large", "giant". (The default "base" is kept for
                signature compatibility but is not a valid key and will
                raise ValueError.)
            freeze_transformer: If True, disable gradients on every
                backbone parameter (feature-extraction mode).
            pretrained_weights: Optional path to a checkpoint file;
                silently skipped when the path does not exist, matching
                the original best-effort behavior.

        Raises:
            ValueError: If ``model_size`` is not a supported key.
        """
        # Fix: original called super(ibotvit, self).__init__() where
        # `ibotvit` is an undefined name (stale class name) — NameError.
        super().__init__()
        self.model_size = model_size
        self.freeze_transformer = freeze_transformer
        self.pretrained_weights = pretrained_weights

        # Fail fast with a clear message instead of leaving
        # self.embedding_size unset and crashing later on a KeyError.
        if model_size not in self._EMBED_DIMS:
            raise ValueError(
                "Unknown model_size %r; expected one of %s"
                % (model_size, sorted(self._EMBED_DIMS))
            )
        self.embedding_size = self._EMBED_DIMS[model_size]

        # The model is freshly constructed here, so the original
        # deepcopy was a pointless extra allocation — assign directly.
        self.transformer = vits.__dict__[model_size](patch_size=16)

        if self.freeze_transformer:
            for param in self.transformer.parameters():
                param.requires_grad = False

        # Best-effort load: a missing/None path is skipped, as before.
        if self.pretrained_weights and os.path.isfile(self.pretrained_weights):
            self._load_pretrained(self.pretrained_weights)

    def _load_pretrained(self, path):
        """Load backbone weights from a DINO/iBOT-style checkpoint at ``path``.

        Handles checkpoints that nest weights under a 'teacher' or
        'model' key and strips 'teacher.' / 'backbone.' key prefixes so
        they match the bare backbone's parameter names.
        """
        state_dict = torch.load(path, map_location="cpu")
        if 'teacher' in state_dict:
            state_dict = state_dict['teacher']
        elif 'model' in state_dict:
            state_dict = state_dict['model']

        # Strip wrapper prefixes (applied in sequence, as the original did).
        state_dict = {
            (k[len("teacher."):] if k.startswith("teacher.") else k): v
            for k, v in state_dict.items()
        }
        state_dict = {
            (k[len("backbone."):] if k.startswith("backbone.") else k): v
            for k, v in state_dict.items()
        }
        # strict=False: checkpoint may lack heads / have extra keys;
        # the returned message reports missing/unexpected keys.
        msg = self.transformer.load_state_dict(state_dict, strict=False)
        print(self.model_size, msg)

    def forward(self, x):
        """Run the backbone on ``x`` and return its output embedding."""
        return self.transformer(x)
|
|
|
|
|
|
def build_model(args):
    """Build a frozen ViT-Base feature extractor and move it to the GPU.

    Args:
        args: Namespace carrying ``pretrained_weights`` — a checkpoint
            path (or None) forwarded to the ``vit`` wrapper.

    Returns:
        The CUDA-resident ``vit`` module.
    """
    model = vit(
        "vit_base",
        freeze_transformer=True,
        pretrained_weights=args.pretrained_weights,
    )
    model.cuda()
    return model
|
|