# Source: ProFound / models/convnextv2.py
# Last commit by Anonymise: "add necessary module" (45461c9)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from models.util import LayerNorm, GRN
from collections import OrderedDict
import math
class Block(nn.Module):
    """Single ConvNeXtV2 residual block operating on 5-D (volumetric) tensors.

    The residual branch is:

    1. a 7x7x7 depthwise ``Conv3d`` (padding 3, stride 1, so the spatial size
       is unchanged);
    2. in a channels-last layout, ``LayerNorm`` -> ``Linear(dim, 4*dim)`` ->
       ``GELU`` -> ``GRN`` -> ``Linear(4*dim, dim)``.

    The branch output is added back onto the block input, optionally through
    stochastic depth.

    Args:
        dim (int): Number of input (and output) channels.
        drop_path (float): Stochastic depth rate. ``DropPath`` is used only
            when this is > 0; otherwise the branch goes through
            ``nn.Identity``. Default: 0.0
    """

    def __init__(self, dim, drop_path=0.0):
        super().__init__()
        hidden = 4 * dim  # expansion width of the pointwise MLP
        # Submodules are registered in this exact order. Their attribute names
        # are the state-dict keys that pretrained checkpoints are remapped onto.
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        # NOTE(review): relies on models.util.LayerNorm defaulting to a
        # channels-last input, which is what forward() feeds it. Confirm this
        # in models/util.py.
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1x1) convolutions, expressed as Linear layers acting on
        # the channels-last view of the feature map.
        self.pwconv1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.grn = GRN(hidden)
        self.pwconv2 = nn.Linear(hidden, dim)
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Return ``x + drop_path(branch(x))``; the output has the shape of ``x``.

        ``x`` is channels-first: (N, C, H, W, D).
        """
        shortcut = x
        # The depthwise conv runs channels-first. Then move channels last so
        # the norm, Linear and GRN layers act on the channel axis.
        y = self.dwconv(x).permute(0, 2, 3, 4, 1)  # (N, C, H, W, D) -> (N, H, W, D, C)
        for layer in (self.norm, self.pwconv1, self.act, self.grn, self.pwconv2):
            y = layer(y)
        y = y.permute(0, 4, 1, 2, 3)  # (N, H, W, D, C) -> (N, C, H, W, D)
        return shortcut + self.drop_path(y)
class ConvNeXtV2(nn.Module):
    """ConvNeXt V2 encoder for volumetric (5-D) inputs.

    The model has no classification head. ``forward`` returns a globally
    average-pooled, layer-normalised feature vector, and optionally the
    per-stage feature maps.

    Args:
        in_chans (int): Number of input image channels. Default: 3
        depths (sequence of int): Number of blocks in each of the 4 stages.
            Default: [3, 3, 9, 3]
        dims (sequence of int): Channel width of each of the 4 stages.
            Default: [96, 192, 384, 768]
        drop_path_rate (float): Maximum stochastic depth rate. Per-block rates
            increase linearly from 0 to this value across all blocks.
            Default: 0.
    """

    def __init__(
        self,
        in_chans=3,
        depths=None,
        dims=None,
        drop_path_rate=0.0,
    ):
        super().__init__()
        # None sentinels instead of list defaults, so instances never share one
        # module-level default list (self.depths exposes it to callers).
        if depths is None:
            depths = [3, 3, 9, 3]
        if dims is None:
            dims = [96, 192, 384, 768]
        self.depths = depths

        # Stem plus 3 intermediate downsampling layers.
        self.downsample_layers = nn.ModuleList()
        # Stem: non-overlapping 4x4x4 conv with stride 4 (spatial size / 4).
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            # The last transition (stage 3 -> 4) uses stride 1. It is a 1x1x1
            # conv that only changes the channel width. The others halve the
            # spatial size.
            stride = 1 if i == 2 else 2
            self.downsample_layers.append(
                nn.Sequential(
                    LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                    nn.Conv3d(dims[i], dims[i + 1], kernel_size=stride, stride=stride),
                )
            )

        # 4 feature stages, each a sequence of residual Blocks.
        self.stages = nn.ModuleList()
        # Stochastic-depth rates increase linearly over all blocks of all stages.
        dp_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
        cur = 0
        for i in range(4):
            blocks = [
                Block(dim=dims[i], drop_path=dp_rates[cur + j])
                for j in range(depths[i])
            ]
            self.stages.append(nn.Sequential(*blocks))
            cur += depths[i]

        # Final norm applied to the pooled (N, C) features. It is sized by
        # dims[-1], which equals the last stage width when dims has 4 entries.
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.apply(self._init_weights)
        self.embed_dim = dims[-1]

    def _init_weights(self, m):
        """Init Conv3d/Linear weights with truncated normal (std 0.02) and biases with zeros."""
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            # Guard: a layer built with bias=False has m.bias == None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run every downsampling layer and stage, then pool.

        Args:
            x: Channels-first input, (N, in_chans, H, W, D).

        Returns:
            tuple: ``(pooled, hidden_states_out)``.

            - ``pooled`` is the final-stage map averaged over its last three
              axes and passed through ``self.norm``.
            - ``hidden_states_out`` is the list of the 4 stage outputs,
              before pooling.
        """
        hidden_states_out = []
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = stage(downsample(x))
            hidden_states_out.append(x)
        # Global average pooling over the 3 spatial axes: (N, C, H, W, D) -> (N, C)
        return self.norm(x.mean([-3, -2, -1])), hidden_states_out

    def forward(self, x, ret_hids=False):
        """Return the pooled features; also return the stage outputs when ``ret_hids`` is true."""
        x, hidden_states_out = self.forward_features(x)
        if ret_hids:
            return x, hidden_states_out
        return x
def convnextv2_atto(**kwargs):
    """Build ConvNeXtV2-Atto: depths (2, 2, 6, 2), widths (40, 80, 160, 320).

    Remaining keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[40, 80, 160, 320], **kwargs)
def convnextv2_femto(**kwargs):
    """Build ConvNeXtV2-Femto: depths (2, 2, 6, 2), widths (48, 96, 192, 384).

    Remaining keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[48, 96, 192, 384], **kwargs)
def convnext_pico(**kwargs):
    """Build ConvNeXtV2-Pico: depths (2, 2, 6, 2), widths (64, 128, 256, 512).

    Remaining keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[2, 2, 6, 2], dims=[64, 128, 256, 512], **kwargs)


# Alias following the ``convnextv2_*`` naming used by every other factory in
# this module. ``convnext_pico`` is kept so existing callers keep working.
convnextv2_pico = convnext_pico
def convnextv2_nano(**kwargs):
    """Build ConvNeXtV2-Nano: depths (2, 2, 8, 2), widths (80, 160, 320, 640).

    Remaining keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[2, 2, 8, 2], dims=[80, 160, 320, 640], **kwargs)
def convnextv2_tiny(**kwargs):
    """Build ConvNeXtV2-Tiny: depths (3, 3, 9, 3), widths (96, 192, 384, 768).

    Remaining keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
def convnextv2_base(**kwargs):
    """Build ConvNeXtV2-Base: depths (3, 3, 27, 3), widths (128, 256, 512, 1024).

    Remaining keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
def convnextv2_large(**kwargs):
    """Build ConvNeXtV2-Large: depths (3, 3, 27, 3), widths (192, 384, 768, 1536).

    Remaining keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
def convnextv2_huge(**kwargs):
    """Build ConvNeXtV2-Huge: depths (3, 3, 27, 3), widths (352, 704, 1408, 2816).

    Remaining keyword arguments are forwarded to ``ConvNeXtV2``.
    """
    return ConvNeXtV2(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], **kwargs)
def remap_checkpoint_keys(ckpt):
    """Remap a pretraining checkpoint's keys/tensors onto this dense 3-D ConvNeXtV2.

    The parameters are read from ``ckpt["model"]``. Any key containing
    "decoder", "mask_token", "proj" or "pred" is dropped, and a message is
    printed for each one. The remaining entries are rewritten as follows:

    * A leading "encoder" path component is stripped.
    * A key ending in "kernel" becomes "<layer>.weight", and its tensor is
      reshaped into a Conv3d weight:

      - a 3-D (kv, in, out) tensor -> (out, in, ks, ks, ks), where
        ks = round(kv ** (1/3));
      - a 2-D (kv, dim) tensor -> (dim, 1, ks, ks, ks) (depthwise);
      - the 2-D tensor mapped to "downsample_layers.3.1.weight" is instead
        treated as a 1x1x1 conv: (a, b) -> (b, a, 1, 1, 1);
      - a kernel of any other rank is not copied.
    * If a key contains "ln" or "linear", its second-to-last path component
      (the "ln"/"linear" wrapper) is removed, e.g. "x.ln.weight" -> "x.weight".
    * Biases that are not 1-D are flattened.
    * Keys containing "grn" get the shape ``v.unsqueeze(0).unsqueeze(1).unsqueeze(0)``,
      e.g. (1, C) -> (1, 1, 1, 1, C).

    Args:
        ckpt (dict): Checkpoint whose ``"model"`` entry maps parameter names to
            tensors. It is not modified.

    Returns:
        OrderedDict: The remapped state dict.
    """
    state = ckpt["model"]

    # Filter into a new dict instead of deleting from ``state``, so the
    # caller's checkpoint is not mutated as a side effect.
    kept = OrderedDict()
    for k, v in state.items():
        if "decoder" in k or "mask_token" in k or "proj" in k or "pred" in k:
            print(f"Removing key {k} from pretrained checkpoint")
        else:
            kept[k] = v

    new_ckpt = OrderedDict()
    for k, v in kept.items():
        if k.startswith("encoder"):
            k = ".".join(k.split(".")[1:])  # drop the leading "encoder" component
        if k.endswith("kernel"):
            # "<layer>.kernel" -> "<layer>.weight"
            new_k = ".".join(k.split(".")[:-1]) + ".weight"
            if len(v.shape) == 3:
                # Standard conv: (kv, in, out) -> (out, in, ks, ks, ks).
                kv, in_dim, out_dim = v.shape
                ks = int(round(kv ** (1 / 3)))  # assumes a cubic kernel
                new_ckpt[new_k] = (
                    v.permute(2, 1, 0)
                    .reshape(out_dim, in_dim, ks, ks, ks)
                    .permute(0, 1, 4, 3, 2)
                )
            elif len(v.shape) == 2:
                if new_k == "downsample_layers.3.1.weight":
                    # This downsampling layer is a 1x1x1 conv in the dense
                    # model, so the 2-D tensor is treated as an (in, out)
                    # matrix rather than (kv, dim).
                    new_ckpt[new_k] = (
                        v.permute(1, 0).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                    )
                else:
                    # Depthwise conv: (kv, dim) -> (dim, 1, ks, ks, ks).
                    kv, dim = v.shape
                    ks = int(round(kv ** (1 / 3)))  # assumes a cubic kernel
                    new_ckpt[new_k] = (
                        v.permute(1, 0)
                        .reshape(dim, 1, ks, ks, ks)
                        .permute(0, 1, 4, 3, 2)
                    )
            # Kernels of any other rank are intentionally not copied.
            continue
        if "ln" in k or "linear" in k:
            parts = k.split(".")
            parts.pop(-2)  # remove the "ln"/"linear" wrapper component
            new_k = ".".join(parts)
        else:
            new_k = k
        new_ckpt[new_k] = v

    # Flatten non-1-D biases and give GRN parameters their broadcast shape.
    # Iterate over a snapshot because values are replaced while looping.
    for k, v in list(new_ckpt.items()):
        if k.endswith("bias") and len(v.shape) != 1:
            new_ckpt[k] = v.reshape(-1)
        elif "grn" in k:
            new_ckpt[k] = v.unsqueeze(0).unsqueeze(1).unsqueeze(0)
    return new_ckpt
def load_state_dict(
    model, state_dict, prefix="", ignore_missing="relative_position_index"
):
    """Non-strictly load ``state_dict`` into ``model`` and print any mismatches.

    Parameters are copied module by module through each submodule's
    ``_load_from_state_dict``. Missing or unexpected keys do not raise.
    Instead, up to four reports are printed, each only when non-empty:

    * missing keys;
    * unexpected keys;
    * missing keys that were deliberately ignored;
    * error messages collected during loading.

    Args:
        model (nn.Module): Module to load into. It is modified in place.
        state_dict (Mapping[str, Tensor]): Parameters to load. Only a shallow
            copy is handed to the loaders.
        prefix (str): Key prefix under which ``model``'s parameters appear in
            ``state_dict``.
        ignore_missing (str | None): "|"-separated substrings. A missing key
            containing any of them is reported as "ignored" instead of
            missing. Empty patterns are skipped, so "" or None ignores nothing.
    """
    missing_keys = []
    unexpected_keys = []
    error_msgs = []
    # Give the loaders a shallow copy they may modify. Re-attach the
    # per-module version metadata that copy() does not carry over.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    def load(module, prefix=""):
        # Load this module's own parameters/buffers, then recurse into children.
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        module._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            True,  # strict: collect missing / unexpected keys into the lists
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model, prefix=prefix)

    # Split missing keys into "warn" and "ignored" buckets. Blank patterns are
    # dropped: "" is a substring of every key and would otherwise silently
    # ignore all missing weights.
    ignore_patterns = [p for p in (ignore_missing or "").split("|") if p]
    warn_missing_keys = []
    ignore_missing_keys = []
    for key in missing_keys:
        if any(pattern in key for pattern in ignore_patterns):
            ignore_missing_keys.append(key)
        else:
            warn_missing_keys.append(key)
    missing_keys = warn_missing_keys

    if missing_keys:
        print(
            "Weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, missing_keys
            )
        )
    if unexpected_keys:
        print(
            "Weights from pretrained model not used in {}: {}".format(
                model.__class__.__name__, unexpected_keys
            )
        )
    if ignore_missing_keys:
        print(
            "Ignored weights of {} not initialized from pretrained model: {}".format(
                model.__class__.__name__, ignore_missing_keys
            )
        )
    if error_msgs:
        print("\n".join(error_msgs))
# if __name__ == 'main':
# model = convnextv2_base().cuda()
# x = torch.rand(1,3,256,256,32).cuda()
# print(model(x).shape)