| """ timm model adapter |
| |
| Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. |
| """ |
import logging
from collections import OrderedDict

import torch
import torch.nn as nn

try:
    import timm
    from timm.models.layers import Mlp, to_2tuple
    from timm.models.layers.attention_pool2d import RotAttentionPool2d
    from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
    from timm.models.layers.adaptive_avgmax_pool import SelectAdaptivePool2d

    from einops.layers.torch import Reduce
except ImportError:
    timm = None

from llava.open_clip.utils import freeze_batch_norm_2d


class HeadModule(nn.Module):
    """ Pooling / projection head applied on top of the timm trunk features. """

    def __init__(self, layers):
        super().__init__()
        self.pool = layers.get("pool", None)
        self.drop = layers.get("drop", None)
        self.proj = layers.get("proj", None)
        self.mlp = layers.get("mlp", None)

    def forward(self, x, pool=True):
        # optional pooling first, then either an MLP or dropout + linear
        # projection, depending on which layers were provided at construction
        if pool and self.pool is not None:
            x = self.pool(x)
        if self.mlp is not None:
            x = self.mlp(x)
        if self.proj is not None:
            assert self.drop is not None
            x = self.drop(x)
            x = self.proj(x)
        return x
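

# A minimal usage sketch for HeadModule (the feature and embed dims below are
# illustrative assumptions, not part of the module): for pool='avg' and
# proj='linear', TimmModel builds a head equivalent to
#
#   head = HeadModule(OrderedDict(
#       drop=nn.Dropout(0.),
#       proj=nn.Linear(768, 512, bias=False),
#   ))
#   out = head(torch.randn(2, 768))  # -> (2, 512); no pool layer, so pooling is skipped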


class TimmModel(nn.Module):
    """ timm model adapter
    # FIXME this adapter is a work in progress, may change in ways that break weight compat
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            pretrained=False):
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")

        self.image_size = to_2tuple(image_size)
        self.trunk = timm.create_model(model_name, pretrained=pretrained)
        feat_size = self.trunk.default_cfg.get('pool_size', None)
        feature_ndim = 1 if not feat_size else 2
        if pool in ('abs_attn', 'rot_attn', 'global_max'):
            # attention pooling requires a 2d (spatial) feature map; global max works either way
            assert pool == 'global_max' or feature_ndim == 2
            # remove both the classifier and the default pool; pooling is done in the head
            self.trunk.reset_classifier(0, global_pool='')
        else:
            # reset global pool if pool config set, otherwise leave as network default
            reset_kwargs = dict(global_pool=pool) if pool else {}
            self.trunk.reset_classifier(0, **reset_kwargs)
        prev_chs = self.trunk.num_features

        head_layers = OrderedDict()
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == "global_max":
            # max over the token dim; no extra projection is added, so the trunk
            # feature dim is expected to already match embed_dim
            head_layers['pool'] = Reduce('b n c -> b c', 'max')
            prev_chs = embed_dim
            proj = ""
        else:
            assert proj, 'projection layer needed if non-attention pooling is used.'

        # NOTE attention pooling already ends with a projection to embed_dim, so proj is
        # usually set to '' when such pooling is used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))

        self.head = HeadModule(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """ lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
            freeze_bn_stats (bool): also freeze BatchNorm running stats in the locked layers (default: False)
        """
        if not unlocked_groups:
            # lock the full trunk
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # NOTE: partial freeze requires a timm version with layer-group support
            try:
                # imported here; only needed for partial unlocking
                from timm.models.helpers import group_parameters, group_modules
            except ImportError:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)
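
    # Partial-unlock sketch (illustrative; `model` stands for a constructed TimmModel
    # instance and assumes the trunk exposes a timm group_matcher): freeze everything
    # except the last layer group.
    #
    #   model.lock(unlocked_groups=1, freeze_bn_stats=True)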
    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward(self, x, pool=True):
        x = self.trunk(x)
        x = self.head(x, pool=pool)
        return x
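

# Example usage (a sketch; the model name, embed_dim and tensor sizes below are
# illustrative assumptions and require `pip install timm`):
#
#   model = TimmModel('resnet50', embed_dim=512, pool='avg', proj='linear')
#   model.lock(freeze_bn_stats=True)          # freeze the trunk entirely
#   model.set_grad_checkpointing(True)        # warns and continues if unsupported
#   images = torch.randn(2, 3, 224, 224)
#   features = model(images)                  # -> shape (2, 512)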