# jsflow/REG/utils.py
import os
from torchvision.datasets.utils import download_url
import torch
import torchvision.models as torchvision_models
import timm
from models import mocov3_vit
import math
import warnings
# code from SiT repository
pretrained_models = {'last.pt'}
def download_model(model_name):
"""
Downloads a pre-trained SiT model from the web.
"""
assert model_name in pretrained_models
local_path = f'pretrained_models/{model_name}'
if not os.path.isfile(local_path):
os.makedirs('pretrained_models', exist_ok=True)
        # NOTE: the URL is hard-coded to the 'last.pt' checkpoint.
        web_path = 'https://www.dl.dropboxusercontent.com/scl/fi/cxedbs4da5ugjq5wg3zrg/last.pt?rlkey=8otgrdkno0nd89po3dpwngwcc&st=apcc645o&dl=0'
download_url(web_path, 'pretrained_models', filename=model_name)
model = torch.load(local_path, map_location=lambda storage, loc: storage)
return model
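# Usage sketch (an illustration, assuming network access and a writable
# ./pretrained_models directory; the return value is whatever torch.load yields):
#   ckpt = download_model('last.pt')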
def fix_mocov3_state_dict(state_dict):
for k in list(state_dict.keys()):
# retain only base_encoder up to before the embedding layer
if k.startswith('module.base_encoder'):
# fix naming bug in checkpoint
new_k = k[len("module.base_encoder."):]
if "blocks.13.norm13" in new_k:
new_k = new_k.replace("norm13", "norm1")
if "blocks.13.mlp.fc13" in k:
new_k = new_k.replace("fc13", "fc1")
if "blocks.14.norm14" in k:
new_k = new_k.replace("norm14", "norm2")
if "blocks.14.mlp.fc14" in k:
new_k = new_k.replace("fc14", "fc2")
# remove prefix
if 'head' not in new_k and new_k.split('.')[0] != 'fc':
state_dict[new_k] = state_dict[k]
# delete renamed or unused k
del state_dict[k]
    if 'pos_embed' in state_dict:
        # resample the positional embedding to a 16x16 token grid (256px inputs, patch size 16)
        state_dict['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
            state_dict['pos_embed'], [16, 16],
        )
return state_dict
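# Illustration of the remapping above on hypothetical checkpoint keys:
#   'module.base_encoder.blocks.13.norm13.weight' -> 'blocks.13.norm1.weight'
#   'module.base_encoder.head.weight'             -> dropped ('head' filtered out)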
@torch.no_grad()
def load_encoders(enc_type, device, resolution=256):
assert (resolution == 256) or (resolution == 512)
enc_names = enc_type.split(',')
encoders, architectures, encoder_types = [], [], []
for enc_name in enc_names:
encoder_type, architecture, model_config = enc_name.split('-')
# Currently, we only support 512x512 experiments with DINOv2 encoders.
if resolution == 512:
if encoder_type != 'dinov2':
raise NotImplementedError(
"Currently, we only support 512x512 experiments with DINOv2 encoders."
)
architectures.append(architecture)
encoder_types.append(encoder_type)
if encoder_type == 'mocov3':
if architecture == 'vit':
                if model_config == 's':
                    encoder = mocov3_vit.vit_small()
                elif model_config == 'b':
                    encoder = mocov3_vit.vit_base()
                elif model_config == 'l':
                    encoder = mocov3_vit.vit_large()
                else:
                    raise ValueError(f"Unknown MoCo v3 ViT config: {model_config}")
                ckpt = torch.load(f'./ckpts/mocov3_vit{model_config}.pth')
state_dict = fix_mocov3_state_dict(ckpt['state_dict'])
del encoder.head
encoder.load_state_dict(state_dict, strict=True)
encoder.head = torch.nn.Identity()
elif architecture == 'resnet':
raise NotImplementedError()
encoder = encoder.to(device)
encoder.eval()
elif 'dinov2' in encoder_type:
if 'reg' in encoder_type:
try:
encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
f'dinov2_vit{model_config}14_reg', source='local')
                except Exception:
encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14_reg')
else:
try:
encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
f'dinov2_vit{model_config}14', source='local')
                except Exception:
encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14')
print(f"Now you are using the {enc_name} as the aligning model")
del encoder.head
            # target token grid: 16x16 at 256px inputs, 32x32 at 512px
            patch_resolution = 16 * (resolution // 256)
            encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
                encoder.pos_embed.data, [patch_resolution, patch_resolution],
            )
encoder.head = torch.nn.Identity()
encoder = encoder.to(device)
encoder.eval()
        elif encoder_type == 'dinov1':
            from models import dinov1
            # NOTE: only ViT-B is instantiated here regardless of model_config,
            # although the checkpoint path does use model_config.
            encoder = dinov1.vit_base()
            ckpt = torch.load(f'./ckpts/dinov1_vit{model_config}.pth')
            if 'pos_embed' in ckpt:
ckpt['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
ckpt['pos_embed'], [16, 16],
)
del encoder.head
encoder.head = torch.nn.Identity()
encoder.load_state_dict(ckpt, strict=True)
encoder = encoder.to(device)
encoder.forward_features = encoder.forward
encoder.eval()
elif encoder_type == 'clip':
import clip
from models.clip_vit import UpdatedVisionTransformer
encoder_ = clip.load(f"ViT-{model_config}/14", device='cpu')[0].visual
encoder = UpdatedVisionTransformer(encoder_).to(device)
encoder.embed_dim = encoder.model.transformer.width
encoder.forward_features = encoder.forward
encoder.eval()
elif encoder_type == 'mae':
from models.mae_vit import vit_large_patch16
kwargs = dict(img_size=256)
encoder = vit_large_patch16(**kwargs).to(device)
with open(f"ckpts/mae_vit{model_config}.pth", "rb") as f:
state_dict = torch.load(f)
if 'pos_embed' in state_dict["model"].keys():
state_dict["model"]['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
state_dict["model"]['pos_embed'], [16, 16],
)
            encoder.load_state_dict(state_dict["model"])
            encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
                encoder.pos_embed.data, [16, 16],
            )
            encoder.eval()  # freeze for feature extraction, matching the other branches
elif encoder_type == 'jepa':
from models.jepa import vit_huge
kwargs = dict(img_size=[224, 224], patch_size=14)
encoder = vit_huge(**kwargs).to(device)
with open(f"ckpts/ijepa_vit{model_config}.pth", "rb") as f:
state_dict = torch.load(f, map_location=device)
            new_state_dict = dict()
            for key, value in state_dict['encoder'].items():
                # strip the leading 'module.' prefix left by DistributedDataParallel
                new_state_dict[key[7:]] = value
            encoder.load_state_dict(new_state_dict)
            encoder.forward_features = encoder.forward
            encoder.eval()  # freeze for feature extraction, matching the other branches
encoders.append(encoder)
return encoders, encoder_types, architectures
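# Usage sketch (assumes the relevant checkpoints / torch.hub caches are in place;
# 'images' is a hypothetical (N, 3, 256, 256) float tensor on the same device):
#   encoders, encoder_types, architectures = load_encoders('dinov2-vit-b', 'cuda')
#   feats = encoders[0].forward_features(images)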
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
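# Example: in-place truncated-normal init; values are clipped to [a, b] = [-2, 2]:
#   w = torch.empty(768, 768)
#   trunc_normal_(w, std=0.02)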
def load_legacy_checkpoints(state_dict, encoder_depth):
    """Remap legacy keys: 'decoder_blocks.i.*' -> 'blocks.{i + encoder_depth}.*',
    so decoder blocks are indexed after the encoder blocks."""
    new_state_dict = dict()
    for key, value in state_dict.items():
        if 'decoder_blocks' in key:
            parts = key.split('.')
new_idx = int(parts[1]) + encoder_depth
parts[0] = 'blocks'
parts[1] = str(new_idx)
new_key = '.'.join(parts)
new_state_dict[new_key] = value
else:
new_state_dict[key] = value
return new_state_dict
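if __name__ == '__main__':
    # Minimal smoke test for load_legacy_checkpoints (a sketch with a
    # hypothetical key; not part of the original pipeline).
    legacy = {'decoder_blocks.0.attn.qkv.weight': torch.zeros(1)}
    remapped = load_legacy_checkpoints(legacy, encoder_depth=12)
    assert 'blocks.12.attn.qkv.weight' in remapped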