| import os |
| from torchvision.datasets.utils import download_url |
| import torch |
| import torchvision.models as torchvision_models |
| import timm |
| from models import mocov3_vit |
| import math |
| import warnings |
|
|
|
|
| |
# Checkpoint filenames that download_model() knows how to fetch; currently
# only the single SiT checkpoint 'last.pt'.
pretrained_models = {'last.pt'}
|
|
def download_model(model_name):
    """
    Download a pre-trained SiT checkpoint (if not already cached) and load it.

    Args:
        model_name: Checkpoint filename; must be a member of `pretrained_models`.

    Returns:
        The object deserialized from the checkpoint, mapped onto CPU storage.

    Raises:
        ValueError: If `model_name` is not a known pretrained checkpoint.
    """
    # Validate explicitly instead of `assert`, which is stripped under `-O`.
    if model_name not in pretrained_models:
        raise ValueError(f'unknown pretrained model: {model_name!r}')
    local_path = f'pretrained_models/{model_name}'
    if not os.path.isfile(local_path):
        os.makedirs('pretrained_models', exist_ok=True)
        # NOTE: this URL is specific to 'last.pt', currently the only entry
        # in `pretrained_models`; extend the mapping if more models are added.
        web_path = 'https://www.dl.dropboxusercontent.com/scl/fi/cxedbs4da5ugjq5wg3zrg/last.pt?rlkey=8otgrdkno0nd89po3dpwngwcc&st=apcc645o&dl=0'
        download_url(web_path, 'pretrained_models', filename=model_name)
    # SECURITY: torch.load unpickles the downloaded file — only load
    # checkpoints from trusted sources.
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model
|
|
def fix_mocov3_state_dict(state_dict):
    """
    Convert a MoCo v3 checkpoint state dict into one loadable by a plain ViT.

    Strips the 'module.base_encoder.' prefix, repairs misnamed sub-module keys
    present in the released checkpoints (e.g. 'blocks.13.norm13' ->
    'blocks.13.norm1'), drops projection-head / 'fc' keys, deletes every
    original key (including non-base-encoder ones), and resamples a surviving
    'pos_embed' to a 16x16 token grid. The dict is modified in place and also
    returned.
    """
    for k in list(state_dict.keys()):
        if k.startswith('module.base_encoder'):
            new_k = k[len("module.base_encoder."):]
            # The released checkpoints contain misnamed sub-modules in blocks
            # 13/14; map them back to standard ViT names. All guards test
            # `new_k` (the original code mixed `k` and `new_k`; equivalent
            # since `new_k` is a suffix of `k`, but consistent is clearer).
            if "blocks.13.norm13" in new_k:
                new_k = new_k.replace("norm13", "norm1")
            if "blocks.13.mlp.fc13" in new_k:
                new_k = new_k.replace("fc13", "fc1")
            if "blocks.14.norm14" in new_k:
                new_k = new_k.replace("norm14", "norm2")
            if "blocks.14.mlp.fc14" in new_k:
                new_k = new_k.replace("fc14", "fc2")
            # Keep only backbone weights: drop the projection head and fc.
            if 'head' not in new_k and new_k.split('.')[0] != 'fc':
                state_dict[new_k] = state_dict[k]
        # Remove the original key whether it was renamed or simply unused.
        del state_dict[k]
    if 'pos_embed' in state_dict.keys():
        # Resample the absolute positional embedding to a 16x16 grid.
        state_dict['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
            state_dict['pos_embed'], [16, 16],
        )
    return state_dict
|
|
@torch.no_grad()
def load_encoders(enc_type, device, resolution=256):
    """
    Build and return frozen pretrained vision encoders used as alignment targets.

    Args:
        enc_type: Comma-separated encoder specs, each of the form
            '<encoder_type>-<architecture>-<model_config>',
            e.g. 'dinov2-vit-b' or 'mocov3-vit-l'.
        device: Device onto which each encoder is moved.
        resolution: Input image resolution; 256, or 512 (DINOv2 only).

    Returns:
        Tuple (encoders, encoder_types, architectures) — parallel lists with
        one entry per spec in `enc_type`.
    """
    assert (resolution == 256) or (resolution == 512)

    enc_names = enc_type.split(',')
    encoders, architectures, encoder_types = [], [], []
    for enc_name in enc_names:
        encoder_type, architecture, model_config = enc_name.split('-')

        # 512x512 experiments are only supported with DINOv2 encoders.
        if resolution == 512:
            if encoder_type != 'dinov2':
                raise NotImplementedError(
                    "Currently, we only support 512x512 experiments with DINOv2 encoders."
                )

        architectures.append(architecture)
        encoder_types.append(encoder_type)
        if encoder_type == 'mocov3':
            if architecture == 'vit':
                if model_config == 's':
                    encoder = mocov3_vit.vit_small()
                elif model_config == 'b':
                    encoder = mocov3_vit.vit_base()
                elif model_config == 'l':
                    encoder = mocov3_vit.vit_large()
                ckpt = torch.load(f'./ckpts/mocov3_vit{model_config}.pth')
                state_dict = fix_mocov3_state_dict(ckpt['state_dict'])
                # Drop the projection head so strict=True matches the fixed
                # state dict, then restore it as an Identity.
                del encoder.head
                encoder.load_state_dict(state_dict, strict=True)
                encoder.head = torch.nn.Identity()
            elif architecture == 'resnet':
                raise NotImplementedError()

            encoder = encoder.to(device)
            encoder.eval()

        elif 'dinov2' in encoder_type:
            # Prefer a local torch.hub cache; fall back to downloading from
            # GitHub. `except Exception` (not bare `except:`) so Ctrl-C and
            # SystemExit still propagate.
            if 'reg' in encoder_type:
                try:
                    encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
                                             f'dinov2_vit{model_config}14_reg', source='local')
                except Exception:
                    encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14_reg')
            else:
                try:
                    encoder = torch.hub.load('your_path/.cache/torch/hub/facebookresearch_dinov2_main',
                                             f'dinov2_vit{model_config}14', source='local')
                except Exception:
                    encoder = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{model_config}14')

            print(f"Now you are using the {enc_name} as the aligning model")
            del encoder.head
            # Resample the positional embedding to a patch grid matching the
            # requested resolution (16 tokens per side at 256, 32 at 512).
            patch_resolution = 16 * (resolution // 256)
            encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
                encoder.pos_embed.data, [patch_resolution, patch_resolution],
            )
            encoder.head = torch.nn.Identity()
            encoder = encoder.to(device)
            encoder.eval()

        elif 'dinov1' == encoder_type:
            from models import dinov1
            encoder = dinov1.vit_base()
            ckpt = torch.load(f'./ckpts/dinov1_vit{model_config}.pth')
            if 'pos_embed' in ckpt.keys():
                # Resample the checkpoint's positional embedding to 16x16.
                ckpt['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
                    ckpt['pos_embed'], [16, 16],
                )
            del encoder.head
            encoder.head = torch.nn.Identity()
            encoder.load_state_dict(ckpt, strict=True)
            encoder = encoder.to(device)
            encoder.forward_features = encoder.forward
            encoder.eval()

        elif encoder_type == 'clip':
            import clip
            from models.clip_vit import UpdatedVisionTransformer
            encoder_ = clip.load(f"ViT-{model_config}/14", device='cpu')[0].visual
            encoder = UpdatedVisionTransformer(encoder_).to(device)
            # Expose the transformer width as `embed_dim` for downstream use.
            encoder.embed_dim = encoder.model.transformer.width
            encoder.forward_features = encoder.forward
            encoder.eval()

        elif encoder_type == 'mae':
            from models.mae_vit import vit_large_patch16
            kwargs = dict(img_size=256)
            encoder = vit_large_patch16(**kwargs).to(device)
            with open(f"ckpts/mae_vit{model_config}.pth", "rb") as f:
                state_dict = torch.load(f)
            if 'pos_embed' in state_dict["model"].keys():
                state_dict["model"]['pos_embed'] = timm.layers.pos_embed.resample_abs_pos_embed(
                    state_dict["model"]['pos_embed'], [16, 16],
                )
            encoder.load_state_dict(state_dict["model"])

            encoder.pos_embed.data = timm.layers.pos_embed.resample_abs_pos_embed(
                encoder.pos_embed.data, [16, 16],
            )

        elif encoder_type == 'jepa':
            from models.jepa import vit_huge
            kwargs = dict(img_size=[224, 224], patch_size=14)
            encoder = vit_huge(**kwargs).to(device)
            with open(f"ckpts/ijepa_vit{model_config}.pth", "rb") as f:
                state_dict = torch.load(f, map_location=device)
            new_state_dict = dict()
            # Strip a 7-character key prefix (presumably 'module.' from DDP)
            # — TODO(review): confirm against the checkpoint format.
            for key, value in state_dict['encoder'].items():
                new_state_dict[key[7:]] = value
            encoder.load_state_dict(new_state_dict)
            encoder.forward_features = encoder.forward

        encoders.append(encoder)

    return encoders, encoder_types, architectures
|
|
|
|
| def _no_grad_trunc_normal_(tensor, mean, std, a, b): |
| |
| |
| def norm_cdf(x): |
| |
| return (1. + math.erf(x / math.sqrt(2.))) / 2. |
|
|
| if (mean < a - 2 * std) or (mean > b + 2 * std): |
| warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
| "The distribution of values may be incorrect.", |
| stacklevel=2) |
|
|
| with torch.no_grad(): |
| |
| |
| |
| l = norm_cdf((a - mean) / std) |
| u = norm_cdf((b - mean) / std) |
|
|
| |
| |
| tensor.uniform_(2 * l - 1, 2 * u - 1) |
|
|
| |
| |
| tensor.erfinv_() |
|
|
| |
| tensor.mul_(std * math.sqrt(2.)) |
| tensor.add_(mean) |
|
|
| |
| tensor.clamp_(min=a, max=b) |
| return tensor |
|
|
|
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Initialize `tensor` in place from a normal distribution N(mean, std)
    truncated to the interval [a, b]; returns the same tensor."""
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled
|
|
|
|
def load_legacy_checkpoints(state_dict, encoder_depth):
    """Remap legacy checkpoint keys so decoder blocks follow encoder blocks.

    Keys of the form 'decoder_blocks.<i>.<rest>' become
    'blocks.<i + encoder_depth>.<rest>'; every other key passes through
    unchanged. Returns a new dict; the input is not modified.
    """
    remapped = dict()
    for name, value in state_dict.items():
        if 'decoder_blocks' not in name:
            remapped[name] = value
            continue
        _, idx, *rest = name.split('.')
        shifted = ['blocks', str(int(idx) + encoder_depth), *rest]
        remapped['.'.join(shifted)] = value
    return remapped