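"""Extra timm ViT model registrations.

Each factory below mirrors a timm vision_transformer entrypoint, then swaps
timm's LayerScale modules for the DINOv2 implementation (see
_patch_layer_scale at the bottom of this file).
"""
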
from torch import nn

from timm.models import register_model
from timm.models.vision_transformer import (
    VisionTransformer,
    _create_vision_transformer as _timm_create_vision_transformer,
    Mlp,
    Block,
    LayerScale as TIMMLayerScale,
)

from . import dinov2_arch


@register_model
def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Tiny (ViT-Ti/14)
    """
    model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3, weight_init='skip')
    model = _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


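# Usage sketch (illustrative): importing this module runs the @register_model
# decorators, after which the variants become available through the timm factory:
#
#   import timm, torch
#   model = timm.create_model('vit_tiny_patch14_224', pretrained=False)
#   features = model.forward_features(torch.randn(1, 3, 224, 224))

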
@register_model
def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Small (ViT-S/14)
    """
    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, weight_init='skip')
    model = _create_vision_transformer('vit_small_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Base (ViT-B/14) from original paper (https://arxiv.org/abs/2010.11929).
    """
    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, weight_init='skip')
    model = _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
    """
    model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16, weight_init='skip')
    if pretrained:
        # There is no pretrained ViT-H/16 checkpoint, so adapt ViT-H/14 weights
        # instead; timm resamples the patch embedding on load when the patch
        # size differs.
        model = _create_vision_transformer('vit_huge_patch14_224', pretrained=True, **dict(model_args, **kwargs))
    else:
        model = _create_vision_transformer('vit_huge_patch16_224', pretrained=False, **dict(model_args, **kwargs))
    return model


@register_model
def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Huge model (ViT-H/16) with a LayerNorm inserted in each MLP block.
    """
    model = vit_huge_patch16_224(pretrained=pretrained, weight_init='skip', **kwargs)

    # timm's Mlp applies `norm` between fc1 and fc2 (nn.Identity by default);
    # swap in a LayerNorm over the hidden dimension.
    for m in model.modules():
        if isinstance(m, Mlp) and not isinstance(m.norm, nn.LayerNorm):
            m.norm = nn.LayerNorm(m.fc1.out_features)

    return model


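# Sketch of the effect (illustrative): for ViT-H the MLP hidden width is
# 1280 * 4 = 5120, so each block's MLP runs fc1 -> act -> LayerNorm(5120) -> fc2:
#
#   model = vit_huge_patch16_224_mlpnorm()
#   assert isinstance(model.blocks[0].mlp.norm, nn.LayerNorm)

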
@register_model
def vit_giant_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-Giant model (ViT-g/16) with the DINOv2 giant configuration
    (embed_dim 1536, depth 40, heads 24). No pretrained weights are available.
    """
    model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24, weight_init='skip')
    model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **dict(model_args, **kwargs))
    return model


@register_model
def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` (https://arxiv.org/abs/2106.04560).
    No pretrained weights are available.
    """
    model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6, weight_init='skip')
    model = _create_vision_transformer('vit_bigG_patch14', pretrained=False, **dict(model_args, **kwargs))
    return model


def _create_vision_transformer(*args, **kwargs):
    # Build the model with timm, then swap its LayerScale modules for the
    # DINOv2 implementation.
    model = _timm_create_vision_transformer(*args, **kwargs)
    _patch_layer_scale(model)
    return model


def _patch_layer_scale(model: VisionTransformer):
    def replace_ls(old_ls: TIMMLayerScale):
        new_ls = dinov2_arch.LayerScale(old_ls.gamma.shape[0], inplace=old_ls.inplace)
        new_ls.load_state_dict(old_ls.state_dict())
        return new_ls

    # Replace timm's LayerScale with the DINOv2 implementation in every block,
    # preserving the learned gamma values.
    for mod in model.modules():
        if isinstance(mod, Block):
            if isinstance(mod.ls1, TIMMLayerScale):
                mod.ls1 = replace_ls(mod.ls1)
            if isinstance(mod.ls2, TIMMLayerScale):
                mod.ls2 = replace_ls(mod.ls2)
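

# Sketch (illustrative): after patching, blocks built with init_values (e.g.
# vit_bigG_patch14_224) carry dinov2_arch.LayerScale with the original gammas:
#
#   model = vit_bigG_patch14_224()
#   assert all(isinstance(b.ls1, dinov2_arch.LayerScale) for b in model.blocks)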