| |
|
|
| |
|
|
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from timm.models.layers import trunc_normal_, DropPath |
| from models.util import LayerNorm, GRN |
| from collections import OrderedDict |
| import math |
|
|
|
|
class Block(nn.Module):
    """Residual ConvNeXtV2 block for 5-D feature maps.

    The block applies a depthwise 7x7x7 ``Conv3d``, moves the channel axis
    (axis 1) to the last position, runs norm -> linear (4x expansion) -> GELU
    -> GRN -> linear (projection back to ``dim``) there, restores the
    original axis order and adds the result to the block input.

    Args:
        dim (int): Number of input (and output) channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        hidden_dim = 4 * dim  # width of the expanded pointwise layer

        # groups=dim makes the convolution depthwise (one filter per channel);
        # padding=3 with kernel_size=7 keeps the spatial size unchanged.
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        # NOTE(review): LayerNorm comes from models.util; it is applied after
        # the permute in forward(), so presumably it defaults to channels-last
        # -- confirm against that module.
        self.norm = LayerNorm(dim, eps=1e-6)
        # The "pointwise convolutions" are implemented as linear layers acting
        # on the trailing (channel) axis.
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.grn = GRN(hidden_dim)
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        # Only instantiate DropPath when it would actually do something.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        residual = x
        # Move the channel axis (1) to the end for the norm/linear layers.
        out = self.dwconv(x).permute(0, 2, 3, 4, 1)
        out = self.pwconv1(self.norm(out))
        out = self.pwconv2(self.grn(self.act(out)))
        # Put the channel axis back at position 1.
        out = out.permute(0, 4, 1, 2, 3)
        return residual + self.drop_path(out)
|
|
|
|
class ConvNeXtV2(nn.Module):
    """ConvNeXt V2 feature extractor built from 3-D convolutions.

    The network consists of four downsampling layers (a stem followed by
    three norm + conv layers), each followed by a stage of ``Block`` modules.
    There is no classification head: the model returns a pooled,
    layer-normalized feature vector of size ``dims[-1]`` (also exposed as
    ``self.embed_dim``).

    Args:
        in_chans (int): Number of input channels. Default: 3
        depths (sequence of int): Number of blocks in each of the four
            stages. Default: (3, 3, 9, 3)
        dims (sequence of int): Feature dimension of each of the four
            stages. Default: (96, 192, 384, 768)
        drop_path_rate (float): Largest stochastic depth rate; per-block
            rates are spaced linearly from 0 to this value. Default: 0.
    """

    def __init__(
        self,
        in_chans=3,
        depths=(3, 3, 9, 3),
        dims=(96, 192, 384, 768),
        drop_path_rate=0.0,
    ):
        super().__init__()
        # Copy into fresh lists so the instance never aliases a caller-owned
        # sequence or a shared default; ``self.depths`` stays a list.
        depths = list(depths)
        dims = list(dims)
        self.depths = depths

        # Stem: non-overlapping 4x4x4 patches (kernel 4, stride 4) + norm.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            # The last transition (i == 2) uses kernel 1 / stride 1, so it
            # only changes the channel count and keeps the spatial size; the
            # first two halve each spatial dimension.
            stride = 1 if i == 2 else 2
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv3d(dims[i], dims[i + 1], kernel_size=stride, stride=stride),
            )
            self.downsample_layers.append(downsample_layer)

        # Four stages of residual blocks; drop-path rates grow linearly over
        # the total number of blocks.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0  # index of the first block of the current stage in dp_rates
        for i in range(4):
            stage = nn.Sequential(
                *[
                    Block(dim=dims[i], drop_path=dp_rates[cur + j])
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        # Final norm, applied to the pooled (N, dims[-1]) feature.
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)

        self.apply(self._init_weights)

        self.embed_dim = dims[-1]

    def _init_weights(self, m):
        """Truncated-normal init for conv/linear weights, zeros for biases."""
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            # Layers created with bias=False have ``bias is None``.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run all four stages.

        Returns:
            tuple: ``(pooled, hidden_states_out)`` where ``pooled`` is the
            last stage's output averaged over its three trailing (spatial)
            dimensions and passed through ``self.norm``, and
            ``hidden_states_out`` is the list of the four stage outputs.
        """
        hidden_states_out = []
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            hidden_states_out.append(x)
        return self.norm(x.mean([-3, -2, -1])), hidden_states_out

    def forward(self, x, ret_hids=False):
        """Return the pooled feature, plus the stage outputs if ``ret_hids``.

        Args:
            x: Input tensor fed to the stem ``Conv3d``.
            ret_hids (bool): When true, return ``(pooled, hidden_states_out)``
                instead of only ``pooled``.
        """
        x, hidden_states_out = self.forward_features(x)
        if ret_hids:
            return x, hidden_states_out
        return x
|
|
|
|
|
|
def convnextv2_atto(**kwargs):
    """Build the "atto" ConvNeXtV2: depths (2, 2, 6, 2), dims (40, 80, 160, 320).

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(
        depths=[2, 2, 6, 2],
        dims=[40, 80, 160, 320],
        **kwargs,
    )
|
|
|
|
def convnextv2_femto(**kwargs):
    """Build the "femto" ConvNeXtV2: depths (2, 2, 6, 2), dims (48, 96, 192, 384).

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(
        depths=[2, 2, 6, 2],
        dims=[48, 96, 192, 384],
        **kwargs,
    )
|
|
|
|
def convnext_pico(**kwargs):
    """Build the "pico" ConvNeXtV2: depths (2, 2, 6, 2), dims (64, 128, 256, 512).

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)


# Alias that follows the ``convnextv2_*`` naming used by the other factory
# functions in this file; the original name is kept for existing callers.
convnextv2_pico = convnext_pico
|
|
|
|
def convnextv2_nano(**kwargs):
    """Build the "nano" ConvNeXtV2: depths (2, 2, 8, 2), dims (80, 160, 320, 640).

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(
        depths=[2, 2, 8, 2],
        dims=[80, 160, 320, 640],
        **kwargs,
    )
|
|
|
|
def convnextv2_tiny(**kwargs):
    """Build the "tiny" ConvNeXtV2: depths (3, 3, 9, 3), dims (96, 192, 384, 768).

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
        **kwargs,
    )
|
|
|
|
def convnextv2_base(**kwargs):
    """Build the "base" ConvNeXtV2: depths (3, 3, 27, 3), dims (128, 256, 512, 1024).

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(
        depths=[3, 3, 27, 3],
        dims=[128, 256, 512, 1024],
        **kwargs,
    )
|
|
|
|
def convnextv2_large(**kwargs):
    """Build the "large" ConvNeXtV2: depths (3, 3, 27, 3), dims (192, 384, 768, 1536).

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(
        depths=[3, 3, 27, 3],
        dims=[192, 384, 768, 1536],
        **kwargs,
    )
|
|
|
|
def convnextv2_huge(**kwargs):
    """Build the "huge" ConvNeXtV2: depths (3, 3, 27, 3), dims (352, 704, 1408, 2816).

    Extra keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(
        depths=[3, 3, 27, 3],
        dims=[352, 704, 1408, 2816],
        **kwargs,
    )
|
|
|
|
def remap_checkpoint_keys(ckpt):
    """Rename and reshape the entries of ``ckpt["model"]`` for a Conv3d model.

    The input checkpoint is left untouched; a new ``OrderedDict`` is returned.

    Processing steps, in order:

    * Entries whose key contains ``"decoder"``, ``"mask_token"``, ``"proj"``
      or ``"pred"`` are dropped (a message is printed for each).
    * For keys starting with ``"encoder"`` the first dotted component is
      stripped.
    * Keys ending in ``"kernel"`` are renamed to ``<parent>.weight`` and the
      tensor is reshaped into a 5-D convolution weight, with
      ``ks = round(kv ** (1/3))`` and the three kernel axes reversed:

      - 3-D ``(kv, in_dim, out_dim)`` -> ``(out_dim, in_dim, ks, ks, ks)``
      - 2-D ``(kv, dim)`` -> ``(dim, 1, ks, ks, ks)``, except for
        ``"downsample_layers.3.1.weight"``, which becomes
        ``(dim, kv, 1, 1, 1)``
      - kernels of any other rank are skipped and do not appear in the result
    * For other keys containing ``"ln"`` or ``"linear"`` the second-to-last
      dotted component is removed.
    * Finally, every ``*bias`` entry that is not 1-D is flattened, and every
      other entry whose key contains ``"grn"`` gets three leading singleton
      dimensions.

    Args:
        ckpt (dict): Checkpoint with the parameter mapping under ``"model"``.

    Returns:
        OrderedDict: The remapped parameter mapping.
    """
    dropped_markers = ("decoder", "mask_token", "proj", "pred")

    # Filter into a fresh mapping instead of deleting from the caller's dict.
    kept = OrderedDict()
    for k, v in ckpt["model"].items():
        if any(marker in k for marker in dropped_markers):
            print(f"Removing key {k} from pretrained checkpoint")
            continue
        kept[k] = v

    new_ckpt = OrderedDict()
    for k, v in kept.items():
        if k.startswith("encoder"):
            k = ".".join(k.split(".")[1:])

        if k.endswith("kernel"):
            new_k = ".".join(k.split(".")[:-1]) + ".weight"
            if len(v.shape) == 3:
                # Flattened kernel with separate input/output channel axes.
                kv, in_dim, out_dim = v.shape
                # round() guards against floating-point error in the cube
                # root (e.g. 27 ** (1 / 3) may not be exactly 3.0).
                ks = int(round(kv ** (1 / 3)))
                new_ckpt[new_k] = (
                    v.permute(2, 1, 0)
                    .reshape(out_dim, in_dim, ks, ks, ks)
                    .permute(0, 1, 4, 3, 2)
                )
            elif len(v.shape) == 2:
                kv, dim = v.shape
                if new_k == "downsample_layers.3.1.weight":
                    # Special case: the first axis is kept as the second
                    # weight axis and the kernel is 1x1x1.
                    new_ckpt[new_k] = (
                        v.permute(1, 0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                    )
                else:
                    # One input channel per output channel (depthwise layout).
                    ks = int(round(kv ** (1 / 3)))
                    new_ckpt[new_k] = (
                        v.permute(1, 0)
                        .reshape(dim, 1, ks, ks, ks)
                        .permute(0, 1, 4, 3, 2)
                    )
            # Kernels are fully handled above (or intentionally skipped).
            continue

        if "ln" in k or "linear" in k:
            parts = k.split(".")
            parts.pop(-2)
            new_k = ".".join(parts)
        else:
            new_k = k
        new_ckpt[new_k] = v

    # Only existing keys are reassigned, so mutating during iteration is safe.
    for k, v in new_ckpt.items():
        if k.endswith("bias") and len(v.shape) != 1:
            new_ckpt[k] = v.reshape(-1)
        elif "grn" in k:
            new_ckpt[k] = v.unsqueeze(0).unsqueeze(1).unsqueeze(0)
    return new_ckpt
|
|
|
|
def load_state_dict(
    model, state_dict, prefix="", ignore_missing="relative_position_index"
):
    """Load ``state_dict`` into ``model`` and print a report instead of raising.

    Every module in the tree is loaded through its own
    ``_load_from_state_dict`` call. Missing keys, unexpected keys and error
    messages are collected and printed at the end; nothing is raised by this
    function for them and nothing is returned.

    Args:
        model: Root module to load into.
        state_dict: Mapping of parameter names to tensors. It is shallow-copied
            before use.
        prefix (str): Key prefix for the root module. Default: ""
        ignore_missing (str): ``"|"``-separated substrings; a missing key
            containing any of them is reported under "Ignored weights"
            rather than as a regular missing key. Empty substrings are
            skipped, so ``""`` (or ``None``) ignores nothing.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # Work on a copy and re-attach ``_metadata``, which the copy does not
    # necessarily carry over.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=""):
        # Metadata is keyed by the module path without the trailing dot.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            True,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model, prefix=prefix)

    # An empty pattern is a substring of every key and would silently mark
    # all missing keys as ignored, so drop empty patterns up front.
    ignore_patterns = [p for p in (ignore_missing or "").split("|") if p]

    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        if any(pattern in key for pattern in ignore_patterns):
            ignore_missing_keys.append(key)
        else:
            warn_missing_keys.append(key)

    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print(
            "Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys
            )
        )
    if len(unexpected_keys) > 0:
        print(
            "Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys
            )
        )
    if len(ignore_missing_keys) > 0:
        print(
            "Ignored weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, ignore_missing_keys
            )
        )
    if len(error_msgs) > 0:
        print("\n".join(error_msgs))
|
|
|
|
| |
| |
| |
| |
|
|