| """
|
| Loads model.
|
| Code adapted from LOST: https://github.com/valeoai/LOST
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
|
|
| from torchvision.models.resnet import resnet50
|
| from torchvision.models.vgg import vgg16
|
|
|
| import dino.vision_transformer as vits
|
|
|
|
|
# Reference checkpoints, as (architecture name, patch size, path relative to
# _CHECKPOINT_HOST) triples.
_CHECKPOINT_HOST = "https://dl.fbaipublicfiles.com/"
_REFERENCE_CHECKPOINTS = (
    ("moco_vit_small", 16, "moco-v3/vit-s-300ep/vit-s-300ep.pth.tar"),
    ("moco_vit_base", 16, "moco-v3/vit-b-300ep/vit-b-300ep.pth.tar"),
    ("mae_vit_base", 16, "mae/visualize/mae_visualize_vit_base.pth"),
    ("vit_small", 16, "dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"),
    ("vit_small", 8, "dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"),
    ("vit_base", 16, "dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"),
    ("vit_base", 8, "dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"),
)


def _moco_backbone_state_dict(checkpoint):
    """Extract the backbone weights from a MoCo checkpoint, in place.

    Keys starting with ``module.base_encoder`` (except those under
    ``module.base_encoder.head``) are re-inserted with the
    ``module.base_encoder.`` prefix removed; every original key is deleted.
    """
    weights = checkpoint['state_dict']
    prefix = "module.base_encoder."
    for key in list(weights.keys()):
        is_backbone = key.startswith('module.base_encoder')
        is_head = key.startswith('module.base_encoder.head')
        if is_backbone and not is_head:
            weights[key[len(prefix):]] = weights[key]
        # Drop the original entry whether it was renamed or is unused.
        del weights[key]
    return weights


def _mae_encoder_state_dict(checkpoint):
    """Extract the encoder weights from a MAE checkpoint, in place.

    Entries whose key starts with ``decoder`` or ``mask_token`` are removed;
    all other entries are kept under their original names.
    """
    weights = checkpoint['model']
    for key in list(weights.keys()):
        if key.startswith(('decoder', 'mask_token')):
            del weights[key]
    return weights


def get_model(arch, patch_size, device):
    """Build a frozen backbone and load its reference weights when available.

    The model factory is looked up by name in ``vits``. Architectures whose
    name contains ``"moco"`` or ``"mae"`` are built with ``num_classes=0``;
    other names containing ``"vit"`` are additionally given ``patch_size``.
    If ``(arch, patch_size)`` has an entry in ``_REFERENCE_CHECKPOINTS`` the
    checkpoint is downloaded and loaded with ``strict=True``; otherwise the
    model keeps its initial weights.

    Args:
        arch: Architecture name, e.g. ``"vit_small"`` or ``"moco_vit_base"``.
        patch_size: Patch size used to select the checkpoint (and passed to
            the factory for plain ``vit`` architectures).
        device: Target passed to ``model.to``.

    Returns:
        The model in eval mode, with ``requires_grad`` disabled on every
        parameter, moved to ``device``.

    Raises:
        NotImplementedError: If ``arch`` contains none of ``"moco"``,
            ``"mae"`` or ``"vit"`` (this includes ``"resnet50"``).
        KeyError: If ``vits`` defines no factory named ``arch``.
    """
    if "moco" in arch or "mae" in arch:
        factory_kwargs = {"num_classes": 0}
    elif "vit" in arch:
        factory_kwargs = {"patch_size": patch_size, "num_classes": 0}
    else:
        raise NotImplementedError(
            "Unsupported architecture {!r}: the name must contain "
            "'moco', 'mae' or 'vit'.".format(arch)
        )

    # Compare with == (rather than hashing) so numeric patch sizes such as
    # 16.0 select the same checkpoint as 16.
    url = None
    for name, size, path in _REFERENCE_CHECKPOINTS:
        if arch == name and patch_size == size:
            url = path
            break

    model = vits.__dict__[arch](**factory_kwargs)

    # The backbone is never trained through this loader.
    for param in model.parameters():
        param.requires_grad = False

    if url is None:
        print(
            "There is no reference weights available for this model => We use random weights."
        )
    else:
        # The wording mentions DINO even for MoCo/MAE checkpoints; it is kept
        # unchanged so existing log output stays stable.
        print(
            "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
        )
        # map_location="cpu" lets checkpoints serialized with CUDA tensors
        # load on hosts without a GPU; the model is moved to `device` below.
        checkpoint = torch.hub.load_state_dict_from_url(
            url=_CHECKPOINT_HOST + url, map_location="cpu"
        )
        if "moco" in arch:
            checkpoint = _moco_backbone_state_dict(checkpoint)
        elif "mae" in arch:
            checkpoint = _mae_encoder_state_dict(checkpoint)

        msg = model.load_state_dict(checkpoint, strict=True)
        print(
            "Pretrained weights found at {} and loaded with msg: {}".format(
                url, msg
            )
        )

    model.eval()
    model.to(device)
    return model
|
|
|