import torch
import torch.nn as nn
import timm
from timm.models.vision_transformer import PatchEmbed
from functools import partial


class vit(timm.models.vision_transformer.VisionTransformer):
    """timm ``VisionTransformer`` with a frozen backbone and a trainable linear head.

    After construction every parameter has ``requires_grad = False`` except
    the parameters of ``self.head``, which is replaced by a fresh
    ``nn.Linear(embed_dim, num_classes)``.

    Args:
        global_pool: When true, the head input is the mean of the non-class
            tokens passed through ``self.fc_norm`` and the backbone's
            ``self.norm`` is removed. When false, the head input is the
            class token taken after ``self.norm``.
        **kwargs: Forwarded unchanged to the timm ``VisionTransformer``
            constructor (e.g. ``patch_size``, ``embed_dim``, ``depth``,
            ``num_heads``, ``num_classes``, ``norm_layer``).
    """

    def __init__(self, global_pool=False, **kwargs):
        # The configuration must reach the parent constructor; without it the
        # backbone is built with timm's defaults regardless of the requested
        # embed_dim / depth / num_heads, and its feature width can disagree
        # with the head created below.
        super().__init__(**kwargs)
        self.global_pool = global_pool

        # Prefer explicitly passed values; otherwise fall back to what the
        # parent constructor recorded, so omitted keys do not raise KeyError.
        embed_dim = kwargs.get('embed_dim', self.embed_dim)
        num_classes = kwargs.get('num_classes', self.num_classes)
        self.head = nn.Linear(embed_dim, num_classes, bias=True)

        if self.global_pool:
            norm_layer = kwargs.get('norm_layer') or partial(nn.LayerNorm, eps=1e-6)
            self.fc_norm = norm_layer(embed_dim)
            # The pooled path normalises with fc_norm, so the backbone's
            # final norm is unused there.
            del self.norm

        # Freeze everything, then re-enable gradients for the head only.
        # Note that fc_norm (when present) stays frozen as well.
        for param in self.parameters():
            param.requires_grad = False
        for param in self.head.parameters():
            param.requires_grad = True

    def forward_features(self, x):
        """Return one feature vector per sample for the classification head.

        Args:
            x: Input batch fed to ``self.patch_embed``; the first dimension
                is treated as the batch size.
                # presumably an image tensor (B, C, H, W) -- confirm with caller

        Returns:
            Tensor indexed ``[batch, feature]``: the normalised mean of the
            non-class tokens when ``self.global_pool`` is true, otherwise the
            class token after ``self.norm``.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        # Prepend one class token per sample along the token dimension.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            # Average every token except the class token at index 0.
            x = x[:, 1:, :].mean(dim=1)
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    def forward(self, x):
        """Run the backbone via ``forward_features`` and apply ``self.head``."""
        x = self.forward_features(x)
        x = self.head(x)
        return x


def vit_base_patch16(**kwargs):
    """Build a ``vit`` with patch 16, width 768, depth 12 and 12 heads.

    ``img_size`` defaults to 224 but may be overridden through ``kwargs``;
    all other ``kwargs`` (e.g. ``num_classes``, ``global_pool``) are passed
    through to ``vit``.
    """
    # setdefault keeps 224 as the default without raising a duplicate-keyword
    # TypeError when the caller supplies img_size, matching the other factories.
    kwargs.setdefault('img_size', 224)
    model = vit(
        patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_large_patch16(**kwargs):
    """Build a ``vit`` with patch 16, width 1024, depth 24 and 16 heads.

    Extra ``kwargs`` (e.g. ``num_classes``, ``global_pool``) are passed
    through to ``vit``.
    """
    model = vit(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_huge_patch14(**kwargs):
    """Build a ``vit`` with patch 14, width 1280, depth 32 and 16 heads.

    Extra ``kwargs`` (e.g. ``num_classes``, ``global_pool``) are passed
    through to ``vit``.
    """
    model = vit(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16,
        mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model