'''
Code Reference:
Adapted from https://github.com/GT-RIPL/CODA-Prompt
'''

import os
import timm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath
from timm.models.helpers import named_apply, adapt_input_conv
from .prompt import L2P, CodaPrompt, DualPrompt
from .transformer import MultiHeadAttention_LoRA, VisionTransformer, VisionTransformer_CL_LoRA


def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
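    """Resize a checkpoint's patch position embeddings to the target model's grid.

    Extra tokens (e.g. the [CLS] token) are kept unchanged; the patch-grid tokens
    are reshaped to 2D and bicubically interpolated to the new grid size.
    """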
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = visual_encoder.patch_embed.num_patches
    num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches

    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)

    if orig_size != new_size:
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]

        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2))

        return new_pos_embed
    else:
        return pos_embed_checkpoint


class ViTZoo(nn.Module):
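    """ViT-B/16 backbone with optional prompt-based continual-learning modules
    (L2P, DualPrompt, or CODA-Prompt), attached via `create_prompt`."""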
    def __init__(self, pretrained=False, model_name='vit_base_patch16_224', attn_layer='MultiHeadAttention', **kwargs):
        super(ViTZoo, self).__init__()

        self.task_id = None
        self.feat_dim = 768

        self.feat = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12,
                                      num_heads=12, ckpt_layer=0,
                                      drop_path_rate=0, attn_layer=attn_layer,
                                      **kwargs
                                      )

        if pretrained:
            print(f'Using pretrained model : {model_name}')

            # Prefer a locally cached checkpoint if one exists; otherwise download via timm.
            ckpt_path = '/home/lvqiexuan/.cache/torch/hub/checkpoints/vit_base_patch16_224.augreg2_in21k_ft_in1k.pt'
            if model_name == 'vit_base_patch16_224.augreg2_in21k_ft_in1k' and os.path.exists(ckpt_path):
                load_dict = torch.load(ckpt_path)
            else:
                load_dict = timm.create_model(model_name, pretrained=pretrained).state_dict()

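            # Map timm checkpoint key names onto this repo's VisionTransformer
            # naming (norm1/norm2 -> ln_1/ln_2, blocks.* -> transformer.blocks.*).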
            key_mapping = {
                ".norm1.": ".ln_1.",
                ".norm2.": ".ln_2.",
                "blocks.": "transformer.blocks."
            }

            modified_load_dict = {}
            for key, value in load_dict.items():
                new_key = key
                for old_key, mapped_key in key_mapping.items():
                    if old_key in new_key:
                        new_key = new_key.replace(old_key, mapped_key)
                modified_load_dict[new_key] = value

            # strict=False: the key sets of the timm checkpoint and this backbone differ.
            self.feat.load_state_dict(modified_load_dict, strict=False)

        self.prompt = None
        self.prompt_flag = ''

    def create_prompt(self, prompt_flag, **kwargs):
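        """Attach a prompt module to the backbone: 'l2p', 'dual' (DualPrompt), or 'coda' (CODA-Prompt)."""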
        self.prompt_flag = prompt_flag

        if self.prompt_flag == 'l2p':
            self.prompt = L2P(**kwargs)
        elif self.prompt_flag == 'dual':
            self.prompt = DualPrompt(768, **kwargs)
        elif self.prompt_flag == 'coda':
            self.prompt = CodaPrompt(768, **kwargs)

    def forward(self, image, text=None, pen=False, train=False, **kwargs):
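        """Forward pass.

        L2P: a no-grad pass first extracts `cls_features` as the prompt query, then the
        prompted pass returns the output together with the key-query similarity `reduce_sim`.
        Other prompts: a no-grad pass provides the [CLS] query `q`, and the prompted pass
        returns the [CLS] feature plus a prompt loss (returned only when `train=True`).
        No prompt: a plain backbone forward returning the flattened [CLS] feature.
        """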
        if self.prompt_flag == 'l2p':
            with torch.no_grad():
                self.eval()
                cls_features = self.feat(image, prompt_flag=self.prompt_flag)

            if train:
                self.train()

            out, reduce_sim = self.feat(
                x=image,
                prompt=self.prompt,
                cls_features=cls_features,
                prompt_flag=self.prompt_flag
            )

            return out, reduce_sim

        elif self.prompt is not None:
            with torch.no_grad():
                q, _ = self.feat(image)
                q = q[:, 0, :]

            out, prompt_loss = self.feat(image, prompt=self.prompt, q=q, train=train, task_id=self.task_id)
            out = out[:, 0, :]
        else:
            out, _ = self.feat(image, **kwargs)
            if len(out.shape) == 3:
                out = out[:, 0, :]

        out = out.view(out.size(0), -1)

        if self.prompt is not None and train:
            return out, prompt_loss
        else:
            return out


class ViT_in21k_adapter(nn.Module):
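    """ViT-B/16 (ImageNet-21k) backbone wrapped with parallel FFN adapters from the PETL adapter module."""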
    def __init__(self, pretrained=False, **kwargs):
        super(ViT_in21k_adapter, self).__init__()

        self.task_id = None
        self.feat_dim = 768
        self.prompt = None

        if pretrained:
            print("Using pretrained model")
            from core.model.backbone.petl import vision_transformer_adapter
            from easydict import EasyDict

            tuning_config = EasyDict(
                # FFN adapter settings
                ffn_adapt=True,
                ffn_option="parallel",
                ffn_adapter_layernorm_option="none",
                ffn_adapter_init_option="lora",
                ffn_adapter_scalar="0.1",
                ffn_num=64,
                d_model=768,
                # visual prompt tuning (disabled)
                vpt_on=False,
                vpt_num=0,
            )

            zoo_model = vision_transformer_adapter.vit_base_patch16_224_in21k_adapter(
                num_classes=0, global_pool=False, drop_path_rate=0.0, tuning_config=tuning_config)
            zoo_model.out_dim = 768
            zoo_model.eval()

            self.feat = zoo_model
        else:
            # the adapter-based backbone is only built from pretrained weights here
            raise ValueError('ViT_in21k_adapter requires pretrained=True')

    def create_prompt(self, prompt_flag, **kwargs):
        self.prompt_flag = prompt_flag

        if self.prompt_flag == 'l2p':
            self.prompt = L2P(768, **kwargs)
        elif self.prompt_flag == 'dual':
            self.prompt = DualPrompt(768, **kwargs)
        elif self.prompt_flag == 'coda':
            self.prompt = CodaPrompt(768, **kwargs)

    def forward(self, x, pen=False, train=False):
        if self.prompt is not None:
            with torch.no_grad():
                q, _ = self.feat(x)
                q = q[:, 0, :]
            out, prompt_loss = self.feat(x, prompt=self.prompt, q=q, train=train, task_id=self.task_id)
            out = out[:, 0, :]
        else:
            out = self.feat(x)

        out = out.view(out.size(0), -1)

        if self.prompt is not None and train:
            return out, prompt_loss
        else:
            return out


class ViT_CL_LoRA(nn.Module):
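    """ViT-B/16 backbone built on `VisionTransformer_CL_LoRA`, which maintains a list of LoRA adapters for continual learning."""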
    def __init__(self, pretrained=False, model_name='vit_base_patch16_224', attn_layer='MultiHeadAttention', **kwargs):
        super().__init__()

        self.task_id = None
        self.feat_dim = 768

        self.feat = VisionTransformer_CL_LoRA(img_size=224, patch_size=16, embed_dim=768, depth=12,
                                              num_heads=12, ckpt_layer=0,
                                              drop_path_rate=0, attn_layer=attn_layer,
                                              **kwargs
                                              )

        if pretrained:
            print(f'Using pretrained model : {model_name}')

            # Prefer a locally cached checkpoint if one exists; otherwise download via timm.
            ckpt_path = '/home/lvqiexuan/.cache/torch/hub/checkpoints/vit_base_patch16_224.augreg2_in21k_ft_in1k.pt'
            if model_name == 'vit_base_patch16_224.augreg2_in21k_ft_in1k' and os.path.exists(ckpt_path):
                load_dict = torch.load(ckpt_path)
            else:
                load_dict = timm.create_model(model_name, pretrained=pretrained).state_dict()

            # Same key remapping as in ViTZoo: timm names -> this repo's VisionTransformer names.
            key_mapping = {
                ".norm1.": ".ln_1.",
                ".norm2.": ".ln_2.",
                "blocks.": "transformer.blocks."
            }

            modified_load_dict = {}
            for key, value in load_dict.items():
                new_key = key
                for old_key, mapped_key in key_mapping.items():
                    if old_key in new_key:
                        new_key = new_key.replace(old_key, mapped_key)
                modified_load_dict[new_key] = value

            self.feat.load_state_dict(modified_load_dict, strict=False)

        self.prompt = None
        self.prompt_flag = ''

    def forward(self, image, test, text=None, pen=False, train=False, **kwargs):
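        """Same routing as `ViTZoo.forward`; the extra `test` flag is passed to the backbone on the prompt-free path."""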
        if self.prompt_flag == 'l2p':
            with torch.no_grad():
                self.eval()
                cls_features = self.feat(image, prompt_flag=self.prompt_flag)

            if train:
                self.train()

            out, reduce_sim = self.feat(
                x=image,
                prompt=self.prompt,
                cls_features=cls_features,
                prompt_flag=self.prompt_flag
            )

            return out, reduce_sim

        elif self.prompt is not None:
            with torch.no_grad():
                q, _ = self.feat(image)
                q = q[:, 0, :]

            out, prompt_loss = self.feat(image, prompt=self.prompt, q=q, train=train, task_id=self.task_id)
            out = out[:, 0, :]
        else:
            out, _ = self.feat(image, test, **kwargs)
            if len(out.shape) == 3:
                out = out[:, 0, :]

        out = out.view(out.size(0), -1)

        if self.prompt is not None and train:
            return out, prompt_loss
        else:
            return out

    def forward_proto(self, x, adapt_index):
        return self.feat.forward_proto(x, adapt_index)

    def forward_general_cls(self, x, t_idx):
        return self.feat.forward_general_cls(x, t_idx)

    def add_adapter_to_list(self):
        self.feat.add_adapter_to_list()


def vit_pt_imnet(pretrained=False, **kwargs):
    return ViTZoo(pretrained, **kwargs)


def vit_pt_imnet_in21k_adapter(pretrained=False, **kwargs):
    return ViT_in21k_adapter(pretrained, **kwargs)


def vit_cl_lora(pretrained=False, **kwargs):
    return ViT_CL_LoRA(pretrained, **kwargs)
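
# Usage sketch (assumed call pattern; the prompt keyword arguments are defined in
# .prompt and are only illustrative here):
#
#   model = vit_pt_imnet(pretrained=True, model_name='vit_base_patch16_224')
#   model.create_prompt('coda', **coda_kwargs)       # kwargs as expected by CodaPrompt
#   model.task_id = 0                                # current task index for prompted forwards
#   feats, prompt_loss = model(images, train=True)
#   feats = model(images, train=False)               # features only at inference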