# rakib72642's picture
# ready init project
# 08ec965
"""
Loads model.
Code adapted from LOST: https://github.com/valeoai/LOST
"""
import torch
import torch.nn as nn
from torchvision.models.resnet import resnet50
from torchvision.models.vgg import vgg16
import dino.vision_transformer as vits
#import moco.vits as vits_moco
# Host prefix prepended to every reference checkpoint path listed below.
_CHECKPOINT_HOST = "https://dl.fbaipublicfiles.com/"

# (arch, patch_size, checkpoint path relative to _CHECKPOINT_HOST).
_REFERENCE_CHECKPOINTS = (
    ("moco_vit_small", 16, "moco-v3/vit-s-300ep/vit-s-300ep.pth.tar"),
    ("moco_vit_base", 16, "moco-v3/vit-b-300ep/vit-b-300ep.pth.tar"),
    ("mae_vit_base", 16, "mae/visualize/mae_visualize_vit_base.pth"),
    ("vit_small", 16, "dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"),
    ("vit_small", 8, "dino/dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"),
    ("vit_base", 16, "dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"),
    ("vit_base", 8, "dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"),
)


def _backbone_family(arch):
    """Return "moco", "mae" or "vit" depending on which substring `arch` contains."""
    # Order matters: "moco_vit_*" and "mae_vit_*" also contain "vit".
    for family in ("moco", "mae", "vit"):
        if family in arch:
            return family
    raise NotImplementedError("Unsupported architecture: {!r}".format(arch))


def _reference_checkpoint(arch, patch_size):
    """Return the relative checkpoint path for (arch, patch_size), or None."""
    for known_arch, known_patch, path in _REFERENCE_CHECKPOINTS:
        if arch == known_arch and patch_size == known_patch:
            return path
    return None


def _moco_encoder_weights(checkpoint):
    """Keep the MoCo base-encoder weights (minus its head) with the prefix removed."""
    prefix = "module.base_encoder."
    return {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith("module.base_encoder")
        and not key.startswith("module.base_encoder.head")
    }


def _mae_encoder_weights(checkpoint):
    """Drop the MAE entries whose key starts with "decoder" or "mask_token"."""
    return {
        key: value
        for key, value in checkpoint["model"].items()
        if not key.startswith(("decoder", "mask_token"))
    }


def get_model(arch, patch_size, device):
    """Build a frozen backbone in eval mode, with reference weights when known.

    The constructor is looked up by name in ``dino.vision_transformer``
    (imported as ``vits``). If a reference checkpoint is registered for the
    exact ``(arch, patch_size)`` pair it is downloaded and loaded with
    ``strict=True``; otherwise the model keeps its initial weights.

    Args:
        arch: Constructor name; must contain "moco", "mae" or "vit".
        patch_size: Patch size. Forwarded to the constructor only for plain
            "vit" architectures; always used to select the checkpoint.
        device: Target passed to ``model.to``.

    Returns:
        The model, with ``requires_grad`` disabled on all parameters, switched
        to eval mode and moved to ``device``.

    Raises:
        NotImplementedError: If ``arch`` contains none of "moco", "mae", "vit".
        KeyError: If ``arch`` is not a name defined in ``vits``.
    """
    # Reject unsupported names before touching the model registry.
    family = _backbone_family(arch)

    build = vits.__dict__[arch]
    if family == "vit":
        model = build(patch_size=patch_size, num_classes=0)
    else:
        model = build(num_classes=0)

    for param in model.parameters():
        param.requires_grad = False

    url = _reference_checkpoint(arch, patch_size)
    if url is None:
        print(
            "There is no reference weights available for this model => We use random weights."
        )
    else:
        print(
            "Since no pretrained weights have been provided, we load the reference pretrained DINO weights."
        )
        # Deserialize on CPU so loading also works on hosts without CUDA;
        # placement on the requested device happens in model.to(device) below.
        state_dict = torch.hub.load_state_dict_from_url(
            url=_CHECKPOINT_HOST + url, map_location="cpu"
        )
        if family == "moco":
            state_dict = _moco_encoder_weights(state_dict)
        elif family == "mae":
            state_dict = _mae_encoder_weights(state_dict)
        msg = model.load_state_dict(state_dict, strict=True)
        print(
            "Pretrained weights found at {} and loaded with msg: {}".format(
                url, msg
            )
        )

    model.eval()
    model.to(device)
    return model