# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from models.util import LayerNorm, GRN
from collections import OrderedDict
import math


class Block(nn.Module):
    """ConvNeXtV2 block operating on 5-D (volumetric) feature maps.

    The block applies a 7x7x7 depthwise ``Conv3d``, moves channels to the last
    axis, then runs LayerNorm -> Linear(dim, 4*dim) -> GELU -> GRN ->
    Linear(4*dim, dim), moves channels back and adds the result to the input
    through an optional stochastic-depth (``DropPath``) branch.

    Args:
        dim (int): Number of input (and output) channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        # Depthwise conv: groups == channels; padding=3 keeps spatial size.
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        # NOTE(review): project LayerNorm is used here without data_format,
        # presumably defaulting to channels-last -- confirm in models.util.
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise/1x1 convs, implemented with linear layers.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        """Return ``x + drop_path(f(x))`` where ``f`` is the block body."""
        residual = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 4, 1)  # (N, C, H, W, D) -> (N, H, W, D, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        x = x.permute(0, 4, 1, 2, 3)  # (N, H, W, D, C) -> (N, C, H, W, D)
        return residual + self.drop_path(x)


class ConvNeXtV2(nn.Module):
    """ConvNeXt V2 feature extractor built from ``Conv3d`` layers.

    The network has a stem (4x4x4 conv, stride 4) followed by three further
    downsampling layers; the first two use stride 2 and the last uses stride 1
    (a 1x1x1 conv that only changes the channel count). Each downsampling
    layer is followed by a stage of ``Block`` modules. There is no
    classification head: ``forward`` returns the normalized, globally
    average-pooled features of the last stage.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (sequence of int): Number of blocks at each of the 4 stages.
            Default: (3, 3, 9, 3)
        dims (sequence of int): Feature dimension at each of the 4 stages.
            Default: (96, 192, 384, 768)
        drop_path_rate (float): Maximum stochastic depth rate; per-block
            rates increase linearly from 0 to this value. Default: 0.

    Raises:
        ValueError: If ``depths`` or ``dims`` does not have exactly 4 entries.
    """

    def __init__(
        self,
        in_chans=3,
        depths=(3, 3, 9, 3),
        dims=(96, 192, 384, 768),
        drop_path_rate=0.0,
    ):
        super().__init__()
        # Immutable defaults + local copies: avoids the shared-mutable-default
        # pitfall while still exposing ``self.depths`` as a list.
        depths = list(depths)
        dims = list(dims)
        if len(depths) != 4 or len(dims) != 4:
            raise ValueError(
                "ConvNeXtV2 expects exactly 4 stages, got "
                f"len(depths)={len(depths)} and len(dims)={len(dims)}"
            )
        self.depths = depths

        # Stem and 3 intermediate downsampling conv layers.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            # The last transition keeps the spatial resolution (1x1x1 conv).
            stride = 1 if i == 2 else 2
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv3d(dims[i], dims[i + 1], kernel_size=stride, stride=stride),
            )
            self.downsample_layers.append(downsample_layer)

        # 4 feature resolution stages, each consisting of multiple residual blocks.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[
                    Block(dim=dims[i], drop_path=dp_rates[cur + j])
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.apply(self._init_weights)
        self.embed_dim = dims[-1]

    def _init_weights(self, m):
        """Truncated-normal init for conv/linear weights, zeros for biases."""
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            # Layers created with bias=False expose ``bias`` as None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run all stages and return pooled features plus per-stage outputs.

        Returns:
            tuple: ``(pooled, hidden_states_out)`` where ``pooled`` is the
            final-norm of the last stage output averaged over its last three
            axes, and ``hidden_states_out`` is the list of the 4 stage outputs.
        """
        hidden_states_out = []
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = downsample(x)
            x = stage(x)
            hidden_states_out.append(x)
        # Global average pooling, (N, C, H, W, D) -> (N, C).
        pooled = self.norm(x.mean([-3, -2, -1]))
        return pooled, hidden_states_out

    def forward(self, x, ret_hids=False):
        """Return pooled features; also the stage outputs if ``ret_hids``."""
        x, hidden_states_out = self.forward_features(x)
        if ret_hids:
            return x, hidden_states_out
        return x


def convnextv2_atto(**kwargs):
    """Build ConvNeXtV2 with depths [2, 2, 6, 2] and dims [40, 80, 160, 320]."""
    model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
    return model


def convnextv2_femto(**kwargs):
    """Build ConvNeXtV2 with depths [2, 2, 6, 2] and dims [48, 96, 192, 384]."""
    model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
    return model


def convnextv2_pico(**kwargs):
    """Build ConvNeXtV2 with depths [2, 2, 6, 2] and dims [64, 128, 256, 512]."""
    model = ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)
    return model


# Backward-compatible alias: the factory was originally published under this
# name, which did not follow the ``convnextv2_*`` convention of its siblings.
convnext_pico = convnextv2_pico


def convnextv2_nano(**kwargs):
    """Build ConvNeXtV2 with depths [2, 2, 8, 2] and dims [80, 160, 320, 640]."""
    model = ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
    return model


def convnextv2_tiny(**kwargs):
    """Build ConvNeXtV2 with depths [3, 3, 9, 3] and dims [96, 192, 384, 768]."""
    model = ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    return model


def convnextv2_base(**kwargs):
    """Build ConvNeXtV2 with depths [3, 3, 27, 3] and dims [128, 256, 512, 1024]."""
    model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    return model


def convnextv2_large(**kwargs):
    """Build ConvNeXtV2 with depths [3, 3, 27, 3] and dims [192, 384, 768, 1536]."""
    model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    return model


def convnextv2_huge(**kwargs):
    """Build ConvNeXtV2 with depths [3, 3, 27, 3] and dims [352, 704, 1408, 2816]."""
    model = ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
    return model


def remap_checkpoint_keys(ckpt):
    """Translate a pretrained checkpoint into this module's state-dict layout.

    Reads ``ckpt["model"]`` and returns a new ``OrderedDict``; the input
    checkpoint is left untouched.

    The mapping performed is:
      * keys containing ``decoder``, ``mask_token``, ``proj`` or ``pred`` are
        dropped (a message is printed for each);
      * a leading ``encoder`` component is stripped from the key;
      * keys ending in ``kernel`` become ``<prefix>.weight`` and the tensor is
        reshaped into a 5-D conv weight, assuming a cubic kernel;
      * for keys containing ``ln`` or ``linear`` the second-to-last dotted
        component is removed;
      * afterwards, non-1-D ``bias`` tensors are flattened and tensors whose
        key contains ``grn`` get three leading singleton axes.

    Args:
        ckpt (dict): Checkpoint with a ``"model"`` entry mapping parameter
            names to tensors.

    Returns:
        OrderedDict: Remapped parameter name -> tensor.
    """
    new_ckpt = OrderedDict()
    dropped_markers = ("decoder", "mask_token", "proj", "pred")

    for k, v in ckpt["model"].items():
        # Skip instead of deleting so the caller's dict is not mutated.
        if any(marker in k for marker in dropped_markers):
            print(f"Removing key {k} from pretrained checkpoint")
            continue

        if k.startswith("encoder"):
            k = ".".join(k.split(".")[1:])  # remove encoder in the name

        if k.endswith("kernel"):
            k = ".".join(k.split(".")[:-1])  # remove kernel in the name
            new_k = k + ".weight"
            if len(v.shape) == 3:  # reshape standard convolution
                kv, in_dim, out_dim = v.shape
                # Calculate kernel size assuming cubic kernel.
                ks = int(round(kv ** (1 / 3)))
                new_ckpt[new_k] = (
                    v.permute(2, 1, 0)
                    .reshape(out_dim, in_dim, ks, ks, ks)
                    .permute(0, 1, 4, 3, 2)
                )
            elif len(v.shape) == 2:  # reshape depthwise convolution
                kv, dim = v.shape
                if new_k == "downsample_layers.3.1.weight":
                    # The last downsampling conv has kernel size 1 in this
                    # model, so its 2-D tensor only needs three trailing
                    # singleton kernel axes.
                    new_ckpt[new_k] = (
                        v.permute(1, 0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                    )
                else:
                    ks = int(round(kv ** (1 / 3)))
                    new_ckpt[new_k] = (
                        v.permute(1, 0)
                        .reshape(dim, 1, ks, ks, ks)
                        .permute(0, 1, 4, 3, 2)
                    )
            # Kernel tensors of any other rank are not copied over.
            continue
        elif "ln" in k or "linear" in k:
            parts = k.split(".")
            parts.pop(-2)  # remove ln and linear in the name
            new_k = ".".join(parts)
        else:
            new_k = k
        new_ckpt[new_k] = v

    # Reshape grn affine parameters and biases. Iterate over a snapshot since
    # entries are reassigned inside the loop.
    for k, v in list(new_ckpt.items()):
        if k.endswith("bias") and len(v.shape) != 1:
            new_ckpt[k] = v.reshape(-1)
        elif "grn" in k:
            new_ckpt[k] = v.unsqueeze(0).unsqueeze(1).unsqueeze(0)
    return new_ckpt


def load_state_dict(
    model, state_dict, prefix="", ignore_missing="relative_position_index"
):
    """Load ``state_dict`` into ``model`` non-strictly and report mismatches.

    Walks the module tree calling ``_load_from_state_dict`` on every module,
    then prints missing keys, unexpected keys, ignored missing keys and any
    error messages. Nothing is raised for mismatches and nothing is returned.

    Args:
        model (nn.Module): Module to load weights into.
        state_dict (dict): Parameter name -> tensor mapping. It is copied
            before loading, so the caller's mapping is not modified.
        prefix (str): Prefix prepended to parameter names when looking them
            up in ``state_dict``.
        ignore_missing (str): ``|``-separated substrings; a missing key that
            contains any of them is reported as ignored rather than as
            missing. Empty entries are skipped so that an empty string does
            not match every key.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []

    # Copy state_dict so _load_from_state_dict can modify it.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            True,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model, prefix=prefix)

    # "" is a substring of every key, so empty patterns would silently
    # suppress all missing-key warnings; drop them.
    ignore_patterns = [p for p in (ignore_missing or "").split("|") if p]

    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        if any(pattern in key for pattern in ignore_patterns):
            ignore_missing_keys.append(key)
        else:
            warn_missing_keys.append(key)
    missing_keys = warn_missing_keys

    if len(missing_keys) > 0:
        print(
            "Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys
            )
        )
    if len(unexpected_keys) > 0:
        print(
            "Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys
            )
        )
    if len(ignore_missing_keys) > 0:
        print(
            "Ignored weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, ignore_missing_keys
            )
        )
    if len(error_msgs) > 0:
        print("\n".join(error_msgs))


# Example usage (kept commented out; requires a CUDA device):
# if __name__ == '__main__':
#     model = convnextv2_base().cuda()
#     x = torch.rand(1, 3, 256, 256, 32).cuda()
#     print(model(x).shape)