diff --git a/.gitattributes b/.gitattributes index c01c6cef9a7364e6ff3e8c4e65d21997f624a2e8..9b47bca9173253310814e86d86ba1bcae75364af 100644 --- a/.gitattributes +++ b/.gitattributes @@ -91,3 +91,5 @@ visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_B8_att visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_B9_attn_proj.pdf filter=lfs diff=lfs merge=lfs -text visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_all_layers.pdf filter=lfs diff=lfs merge=lfs -text visual-aids/vit-tiny-reluact-16-224/erf_vit_tiny_relu_16_224_w_pretrained_average.pdf filter=lfs diff=lfs merge=lfs -text +models/__pycache__/vit.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +models/__pycache__/vit.cpython-312.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/models/MAE_SDT.py b/models/MAE_SDT.py new file mode 100644 index 0000000000000000000000000000000000000000..2f877e5cef8b9249b51dbec70ed788b478264fc9 --- /dev/null +++ b/models/MAE_SDT.py @@ -0,0 +1,639 @@ +from functools import partial +import torch +import torch.nn as nn +import torchinfo +from timm.models.layers import to_2tuple, trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +from einops.layers.torch import Rearrange +import torch.nn.functional as F +from timm.models.vision_transformer import PatchEmbed, Block + +from spikingjelly.clock_driven import layer +import copy +from torchvision import transforms +import matplotlib.pyplot as plt + +import models.encoder as encoder +from .util.pos_embed import get_2d_sincos_pos_embed + +import torch + +#timestep +T=4 + + +class multispike(torch.autograd.Function): + @staticmethod + def forward(ctx, input, lens=T): + ctx.save_for_backward(input) + ctx.lens = lens + return torch.floor(torch.clamp(input, 0, lens) + 0.5) + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + grad_input = grad_output.clone() + temp1 = 0 < input + temp2 = input < ctx.lens + return grad_input * temp1.float() * temp2.float(), None + + +class Multispike(nn.Module): + def __init__(self, spike=multispike,norm=T): + super().__init__() + self.lens = norm + self.spike = spike + self.norm=norm + + def forward(self, inputs): + return self.spike.apply(inputs)/self.norm + + + + +def MS_conv_unit(in_channels, out_channels,kernel_size=1,padding=0,groups=1): + return nn.Sequential( + layer.SeqToANNContainer( + encoder.SparseConv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, groups=groups,bias=True), + encoder.SparseBatchNorm2d(out_channels) + ) + ) +class MS_ConvBlock(nn.Module): + def __init__(self, dim, + mlp_ratio=4.0): + super().__init__() + + self.neuron1 = Multispike() + self.conv1 = MS_conv_unit(dim, dim * mlp_ratio, 3, 1) + + self.neuron2 = Multispike() + self.conv2 = MS_conv_unit(dim*mlp_ratio, dim, 3, 1) + + + def forward(self, x, mask=None): + short_cut = x + x = self.neuron1(x) + x = self.conv1(x) + x = self.neuron2(x) + x = self.conv2(x) + x = x +short_cut + return x + +class MS_MLP(nn.Module): + def __init__( + self, in_features, hidden_features=None, out_features=None, drop=0.0, layer=0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1) + self.fc1_bn = nn.BatchNorm1d(hidden_features) + self.fc1_lif = Multispike() + + + self.fc2_conv = nn.Conv1d( + hidden_features, out_features, 
kernel_size=1, stride=1 + ) + self.fc2_bn = nn.BatchNorm1d(out_features) + self.fc2_lif = Multispike() + + self.c_hidden = hidden_features + self.c_output = out_features + + def forward(self, x): + T, B, C, N= x.shape + + x = self.fc1_lif(x) + x = self.fc1_conv(x.flatten(0, 1)) + x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous() + + x = self.fc2_lif(x) + x = self.fc2_conv(x.flatten(0, 1)) + x = self.fc2_bn(x).reshape(T, B, C, N).contiguous() + + return x + +class RepConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + bias=False, + ): + super().__init__() + # TODO in_channel-> 2*in_channel->in_channel + self.conv1 = nn.Sequential(nn.Conv1d(in_channel, int(in_channel*1.5), kernel_size=1, stride=1,bias=False), nn.BatchNorm1d(int(in_channel*1.5))) + self.conv2 = nn.Sequential(nn.Conv1d(int(in_channel*1.5), out_channel, kernel_size=1, stride=1,bias=False), nn.BatchNorm1d(out_channel)) + def forward(self, x): + return self.conv2(self.conv1(x)) +class RepConv2(nn.Module): + def __init__( + self, + in_channel, + out_channel, + bias=False, + ): + super().__init__() + # TODO in_channel-> 2*in_channel->in_channel + self.conv1 = nn.Sequential(nn.Conv1d(in_channel, int(in_channel*1.5), kernel_size=1, stride=1,bias=False), nn.BatchNorm1d(int(in_channel*1.5))) + self.conv2 = nn.Sequential(nn.Conv1d(int(in_channel*1.5), out_channel, kernel_size=1, stride=1,bias=False), nn.BatchNorm1d(out_channel)) + def forward(self, x): + return self.conv2(self.conv1(x)) + +class MS_Attention_Conv_qkv_id(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + self.dim = dim + self.num_heads = num_heads + self.scale = 0.125 + self.sr_ratio=sr_ratio + + self.head_lif = Multispike() + + # track 1: split convs + self.q_conv = nn.Sequential(RepConv(dim,dim), nn.BatchNorm1d(dim)) + self.k_conv = nn.Sequential(RepConv(dim,dim), nn.BatchNorm1d(dim)) + self.v_conv = nn.Sequential(RepConv(dim,dim*sr_ratio), nn.BatchNorm1d(dim*sr_ratio)) + + # track 2: merge (prefer) NOTE: need `chunk` in forward + # self.qkv_conv = nn.Sequential(RepConv(dim,dim * 3), nn.BatchNorm2d(dim * 3)) + + self.q_lif = Multispike() + + self.k_lif = Multispike() + + self.v_lif = Multispike() + + self.attn_lif = Multispike() + + self.proj_conv = nn.Sequential(RepConv(sr_ratio*dim,dim), nn.BatchNorm1d(dim)) + + def forward(self, x): + T, B, C, N = x.shape + + x = self.head_lif(x) + + x_for_qkv = x.flatten(0, 1) + q_conv_out = self.q_conv(x_for_qkv).reshape(T, B, C, N) + + q_conv_out = self.q_lif(q_conv_out) + + q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, + 4) + + k_conv_out = self.k_conv(x_for_qkv).reshape(T, B, C, N) + + k_conv_out = self.k_lif(k_conv_out) + + k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, + 4) + + v_conv_out = self.v_conv(x_for_qkv).reshape(T, B, self.sr_ratio*C, N) + + v_conv_out = self.v_lif(v_conv_out) + + v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, self.sr_ratio*C // self.num_heads).permute(0, 1, 3, 2, + 4) + + x = k.transpose(-2, -1) @ v + x = (q @ x) * self.scale + x = x.transpose(3, 4).reshape(T, B, self.sr_ratio*C, N) + x = self.attn_lif(x) + + x = self.proj_conv(x.flatten(0, 1)).reshape(T, B, C, N) + return x + + + + +class MS_DownSampling(nn.Module): + def __init__( + self, + 
in_channels=2, + embed_dims=256, + kernel_size=3, + stride=2, + padding=1, + first_layer=True, + + ): + super().__init__() + + self.encode_conv = encoder.SparseConv2d( + in_channels, + embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + self.encode_bn = encoder.SparseBatchNorm2d(embed_dims) + self.first_layer = first_layer + if not first_layer: + self.encode_spike = Multispike() + + def forward(self, x): + T, B, _, _, _ = x.shape + if hasattr(self, "encode_spike"): + x = self.encode_spike(x) + x = self.encode_conv(x.flatten(0, 1)) + _, _, H, W = x.shape + x = self.encode_bn(x).reshape(T, B, -1, H, W) + + return x + +class MS_Block(nn.Module): + def __init__( + self, + dim, + choice, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + sr_ratio=1,init_values=1e-6,finetune=False, + ): + super().__init__() + self.model=choice + if self.model=="base": + self.rep_conv=RepConv2(dim,dim) # if enabled, total param count is about 83M + self.lif = Multispike() + self.attn = MS_Attention_Conv_qkv_id( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) + self.finetune = finetune + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) + + if self.finetune: + self.layer_scale1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.layer_scale2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, x): + T, B, C, N = x.shape + if self.model=="base": + x= x + self.rep_conv(self.lif(x).flatten(0, 1)).reshape(T, B, C, N) + # TODO: need channel-wise layer scale, init as 1e-6 + if self.finetune: + x = x + self.drop_path(self.attn(x) * self.layer_scale1.unsqueeze(0).unsqueeze(0).unsqueeze(-1)) + x = x + self.drop_path(self.mlp(x) * self.layer_scale2.unsqueeze(0).unsqueeze(0).unsqueeze(-1)) + else: + x = x + self.attn(x) + x = x + self.mlp(x) + return x + +class Spikmae(nn.Module): + def __init__(self, T=1,choice=None, + img_size_h=224, + img_size_w=224, + patch_size=16, + embed_dim=[128, 256, 512], + num_heads=8, + mlp_ratios=4, + in_channels=3, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), #norm_layer=nn.LayerNorm + depths=8, + sr_ratios=1, + decoder_embed_dim=768, + decoder_depth=4, + decoder_num_heads=16, + mlp_ratio=4., + norm_pix_loss=False, nb_classes=1000): + super().__init__() + + self.num_classes = num_classes + self.depths = depths + self.T = 1 + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + + self.downsample1_1 = MS_DownSampling( + in_channels=in_channels, + embed_dims=embed_dim[0] // 2, + kernel_size=7, + stride=2, + padding=3, + first_layer=True, + ) + + self.ConvBlock1_1 = nn.ModuleList( + [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)] + ) + + self.downsample1_2 = MS_DownSampling( + in_channels=embed_dim[0] // 2, + embed_dims=embed_dim[0], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + + ) + + self.ConvBlock1_2 = nn.ModuleList( + [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)] + ) + + self.downsample2 = MS_DownSampling( + in_channels=embed_dim[0], + embed_dims=embed_dim[1], + kernel_size=3, + stride=2, +
padding=1, + first_layer=False, + + ) + + self.ConvBlock2_1 = nn.ModuleList( + [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)] + ) + + self.ConvBlock2_2 = nn.ModuleList( + [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)] + ) + + self.downsample3 = MS_DownSampling( + in_channels=embed_dim[1], + embed_dims=embed_dim[2], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + + ) + + self.block3 = nn.ModuleList( + [ + MS_Block( + dim=embed_dim[2], + choice=choice, + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + finetune=False, + ) + for j in range(depths) + ] + ) + + self.norm = nn.BatchNorm1d(embed_dim[-1]) + self.downsample_raito =16 + + num_patches = 196 + + self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim[-1],num_patches), requires_grad=False) + + ## MAE decoder vit + self.decoder_embed = nn.Linear(embed_dim[-1], decoder_embed_dim,bias=True) + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + # Try larned decoder + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False) + self.decoder_blocks = nn.ModuleList([ + Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=False, norm_layer=norm_layer) + for i in range(decoder_depth)]) + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_channels,bias=True) # decoder to patch + self.initialize_weights() + + def initialize_weights(self): + num_patches=196 + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[1], int(num_patches ** .5), + cls_token=False) + + self.pos_embed.data.copy_(torch.from_numpy(pos_embed.transpose(1,0)).float().unsqueeze(0)) + + decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], + int(num_patches** .5), cls_token=False) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + + torch.nn.init.normal_(self.mask_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. 
+ x: [N, L, D], sequence + """ + num_patches=196 + T, N, _, _, _ = x.shape # batch, length, dim + L = num_patches + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + # active is inverse mask + active = torch.ones([N, L], device=x.device) + active[:, len_keep:] = 0 + active = torch.gather(active, dim=1, index=ids_restore) + + return ids_keep, active, ids_restore + + def forward_encoder(self, x , mask_ratio=1.0): + x = (x.unsqueeze(0)).repeat(self.T, 1, 1, 1, 1) + # step1. Mask + ids_keep, active, ids_restore = self.random_masking(x , mask_ratio) + B,N=active.shape + active_b1ff=active.reshape(B,1,14,14) + + + encoder._cur_active = active_b1ff + active_hw = active_b1ff.repeat_interleave(self.downsample_raito, 2).repeat_interleave(self.downsample_raito, 3) + active_hw = active_hw.unsqueeze(0) + masked_bchw = x * active_hw + x = masked_bchw + x = self.downsample1_1(x) + for blk in self.ConvBlock1_1: + x = blk(x) + x = self.downsample1_2(x) + for blk in self.ConvBlock1_2: + x = blk(x) + + x = self.downsample2(x) + for blk in self.ConvBlock2_1: + x = blk(x) + for blk in self.ConvBlock2_2: + x = blk(x) + + x = self.downsample3(x) + x = x.flatten(3) + for blk in self.block3: + x = blk(x) + + x = x.mean(0) + x = self.norm(x).transpose(-1, -2).contiguous() + return x, active,ids_restore,active_hw + + def forward_decoder(self, x, ids_restore): + # embed tokens + B, N, C = x.shape + x = self.decoder_embed(x) # B, N, C + # append mask tokens to sequence + # ids_restore#1,196 + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) + x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + x = x_ +# + # add pos embed + x = x + self.decoder_pos_embed + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + x = self.decoder_pred(x) + + return x + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = 16 + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = 16 + h = w = int(x.shape[1] ** .5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + def forward_loss(self, imgs, pred, mask): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + + inp, rec = self.patchify(imgs), pred # inp and rec: (B, L = f*f, N = C*downsample_raito**2) + mean = inp.mean(dim=-1, keepdim=True) + var = (inp.var(dim=-1, keepdim=True) + 1e-6) ** .5 + inp = (inp - mean) / var + l2_loss = ((rec - inp) ** 2).mean(dim=2, 
keepdim=False) # (B, L, C) ==mean==> (B, L) + non_active = mask.logical_not().int().view(mask.shape[0], -1) # (B, 1, f, f) => (B, L) + recon_loss = l2_loss.mul_(non_active).sum() / (non_active.sum() + 1e-8) # loss only on masked (non-active) patches + return recon_loss,mean,var + + def forward(self, imgs, mask_ratio=0.5,vis=False): + + latent, active, ids_restore,active_hw = self.forward_encoder(imgs, mask_ratio) + rec = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] + recon_loss,mean,var = self.forward_loss(imgs, rec, active) + if vis: + masked_bchw = imgs * active_hw.flatten(0,1) + rec_bchw = self.unpatchify(rec * var + mean) + rec_or_inp = torch.where(active_hw.flatten(0,1).bool(), imgs, rec_bchw) + return imgs, masked_bchw, rec_or_inp + else: + return recon_loss + + +def spikmae_12_512(**kwargs): + model = Spikmae( + T=1, + choice="base", + img_size_h=224, + img_size_w=224, + patch_size=16, + embed_dim=[128,256,512], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=12, + sr_ratios=1, decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=4, + **kwargs) + return model +def spikmae_12_768(**kwargs): + model = Spikmae( + T=1, + choice="large", + img_size_h=224, + img_size_w=224, + patch_size=16, + embed_dim=[192,384,768], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=12, + sr_ratios=1, decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=4, + **kwargs) + return model + + + + +if __name__ == "__main__": + model = spikmae_12_768() + x=torch.randn(1,3,224,224) + loss = model(x,mask_ratio=0.50) + print('loss',loss) + torchinfo.summary(model, (1, 3, 224, 224)) + print(f"number of params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/__pycache__/MAE_SDT.cpython-311.pyc b/models/__pycache__/MAE_SDT.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a77507879027885a99d3b5bec5e078e02ca727f Binary files /dev/null and b/models/__pycache__/MAE_SDT.cpython-311.pyc differ diff --git a/models/__pycache__/MAE_SDT.cpython-312.pyc b/models/__pycache__/MAE_SDT.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a3d99cb44eddf6ecab7358eed4a5f036dcf183b Binary files /dev/null and b/models/__pycache__/MAE_SDT.cpython-312.pyc differ diff --git a/models/__pycache__/__init__.cpython-311.pyc b/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8107133a6016ecb24d213aed82e725539b1926b1 Binary files /dev/null and b/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/models/__pycache__/__init__.cpython-312.pyc b/models/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f962b81335d9764f5c854c1bd303b4d63957b296 Binary files /dev/null and b/models/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/__pycache__/__init__.cpython-39.pyc b/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60f9a0418bf0d8d8a7545ca1a8756eb7a066cbf7 Binary files /dev/null and b/models/__pycache__/__init__.cpython-39.pyc differ diff --git 
a/models/__pycache__/encoder.cpython-311.pyc b/models/__pycache__/encoder.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..519dd2ce687aaa64cb66a56e7fcf40d158ca9235 Binary files /dev/null and b/models/__pycache__/encoder.cpython-311.pyc differ diff --git a/models/__pycache__/encoder.cpython-312.pyc b/models/__pycache__/encoder.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a062cdf68511f42adee99d1c18fafdd2094224a6 Binary files /dev/null and b/models/__pycache__/encoder.cpython-312.pyc differ diff --git a/models/__pycache__/metaformer.cpython-311.pyc b/models/__pycache__/metaformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98ba06c203ad5c9b3680965f0d719b1ce84f489e Binary files /dev/null and b/models/__pycache__/metaformer.cpython-311.pyc differ diff --git a/models/__pycache__/metaformer.cpython-312.pyc b/models/__pycache__/metaformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7feb64a2211ce82a10162353d80b64303bf08edc Binary files /dev/null and b/models/__pycache__/metaformer.cpython-312.pyc differ diff --git a/models/__pycache__/neuron.cpython-311.pyc b/models/__pycache__/neuron.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cc4c76dc919a215b7eb7e6a84c334cad9e19833 Binary files /dev/null and b/models/__pycache__/neuron.cpython-311.pyc differ diff --git a/models/__pycache__/neuron.cpython-312.pyc b/models/__pycache__/neuron.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f1e8c951329176f4ac45799ea4ea9c2f98bf00e Binary files /dev/null and b/models/__pycache__/neuron.cpython-312.pyc differ diff --git a/models/__pycache__/qk_model_v1_1003.cpython-311.pyc b/models/__pycache__/qk_model_v1_1003.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da5ef946d5def611409023c7df80e79d4d75c31b Binary files /dev/null and b/models/__pycache__/qk_model_v1_1003.cpython-311.pyc differ diff --git a/models/__pycache__/qkformer.cpython-311.pyc b/models/__pycache__/qkformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af2e8ada457fac366a7e6d2c5ea76c2f99390553 Binary files /dev/null and b/models/__pycache__/qkformer.cpython-311.pyc differ diff --git a/models/__pycache__/qkformer.cpython-312.pyc b/models/__pycache__/qkformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..330ab677b6c167f1dcd62687afdffdb028d2595c Binary files /dev/null and b/models/__pycache__/qkformer.cpython-312.pyc differ diff --git a/models/__pycache__/sd_former_v1.cpython-311.pyc b/models/__pycache__/sd_former_v1.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e717fb6294282341f99c6a17e4a119cf873288bc Binary files /dev/null and b/models/__pycache__/sd_former_v1.cpython-311.pyc differ diff --git a/models/__pycache__/sd_former_v1.cpython-312.pyc b/models/__pycache__/sd_former_v1.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7dbcbd43c7484af067672c133d07b8e9b82cb8f Binary files /dev/null and b/models/__pycache__/sd_former_v1.cpython-312.pyc differ diff --git a/models/__pycache__/sdtv3.cpython-311.pyc b/models/__pycache__/sdtv3.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a36ce6dfc746fbe187e4e2317d95c519ec29990 Binary files /dev/null and b/models/__pycache__/sdtv3.cpython-311.pyc differ diff --git 
a/models/__pycache__/sdtv3.cpython-312.pyc b/models/__pycache__/sdtv3.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21c4187130a7b0f93486503e2c71de66ffc728f0 Binary files /dev/null and b/models/__pycache__/sdtv3.cpython-312.pyc differ diff --git a/models/__pycache__/sdtv3.cpython-39.pyc b/models/__pycache__/sdtv3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17e2163546b4a2c6f67ed5da3cc19df824ef034b Binary files /dev/null and b/models/__pycache__/sdtv3.cpython-39.pyc differ diff --git a/models/__pycache__/sdtv3_large.cpython-311.pyc b/models/__pycache__/sdtv3_large.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24d0ba5dd5670b5aa98d140076e2da5c7d34f2d5 Binary files /dev/null and b/models/__pycache__/sdtv3_large.cpython-311.pyc differ diff --git a/models/__pycache__/sdtv3_large.cpython-312.pyc b/models/__pycache__/sdtv3_large.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94462c9f303b69923ebe73956cc1945c77b2817c Binary files /dev/null and b/models/__pycache__/sdtv3_large.cpython-312.pyc differ diff --git a/models/__pycache__/spikformer.cpython-311.pyc b/models/__pycache__/spikformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c3398fa0a54bbbc4075776f8561f47f62d1cc23 Binary files /dev/null and b/models/__pycache__/spikformer.cpython-311.pyc differ diff --git a/models/__pycache__/spikformer.cpython-312.pyc b/models/__pycache__/spikformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..017753abf161cbc1f77299f8d963f5af533b1d38 Binary files /dev/null and b/models/__pycache__/spikformer.cpython-312.pyc differ diff --git a/models/__pycache__/vit.cpython-311.pyc b/models/__pycache__/vit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a0e98a71086bf50b8fd60dc6111ed1a9b42bed1 --- /dev/null +++ b/models/__pycache__/vit.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c206e15daa2f79c2abc87acf17b7e6263bb292fe86a0581f10e58b94da50c3d5 +size 204918 diff --git a/models/__pycache__/vit.cpython-312.pyc b/models/__pycache__/vit.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a0300673e1c814385db8d5226d0336775a88d2a --- /dev/null +++ b/models/__pycache__/vit.cpython-312.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:541616a16f1f3839624aff7ca6c0d0f168227ee19a642340a85e71e77d6ea63d +size 183274 diff --git a/models/encoder.py b/models/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d62a8194317109413342a3b8f37a7653dd5a877d --- /dev/null +++ b/models/encoder.py @@ -0,0 +1,158 @@ +# Copyright (c) ByteDance, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from timm.models.layers import DropPath + + +_cur_active: torch.Tensor = None # B1ff +# todo: try to use `gather` for speed? 
+def _get_active_ex_or_ii(H, W, returning_active_ex=True): + h_repeat, w_repeat = H // _cur_active.shape[-2], W // _cur_active.shape[-1] + active_ex = _cur_active.repeat_interleave(h_repeat, dim=2).repeat_interleave(w_repeat, dim=3) + return active_ex if returning_active_ex else active_ex.squeeze(1).nonzero(as_tuple=True) # ii: bi, hi, wi + + +def sp_conv_forward(self, x: torch.Tensor): + x = super(type(self), self).forward(x) + x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True) # (BCHW) *= (B1HW), mask the output of conv + return x + + +def sp_bn_forward(self, x: torch.Tensor): + ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False) + + bhwc = x.permute(0, 2, 3, 1) + nc = bhwc[ii] # select the features on non-masked positions to form a flatten feature `nc` + nc = super(type(self), self).forward(nc) # use BN1d to normalize this flatten feature `nc` + + bchw = torch.zeros_like(bhwc) + bchw[ii] = nc + bchw = bchw.permute(0, 3, 1, 2) + return bchw + + +class SparseConv2d(nn.Conv2d): + forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details + + +class SparseMaxPooling(nn.MaxPool2d): + forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details + + +class SparseAvgPooling(nn.AvgPool2d): + forward = sp_conv_forward # hack: override the forward function; see `sp_conv_forward` above for more details + + +class SparseBatchNorm2d(nn.BatchNorm1d): + forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details + + +class SparseSyncBatchNorm2d(nn.SyncBatchNorm): + forward = sp_bn_forward # hack: override the forward function; see `sp_bn_forward` above for more details + + +class SparseConvNeXtLayerNorm(nn.LayerNorm): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). 
+ """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last", sparse=True): + if data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + super().__init__(normalized_shape, eps, elementwise_affine=True) + self.data_format = data_format + self.sparse = sparse + + def forward(self, x): + if x.ndim == 4: # BHWC or BCHW + if self.data_format == "channels_last": # BHWC + if self.sparse: + ii = _get_active_ex_or_ii(H=x.shape[1], W=x.shape[2], returning_active_ex=False) + nc = x[ii] + nc = super(SparseConvNeXtLayerNorm, self).forward(nc) + + x = torch.zeros_like(x) + x[ii] = nc + return x + else: + return super(SparseConvNeXtLayerNorm, self).forward(x) + else: # channels_first, BCHW + if self.sparse: + ii = _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=False) + bhwc = x.permute(0, 2, 3, 1) + nc = bhwc[ii] + nc = super(SparseConvNeXtLayerNorm, self).forward(nc) + + x = torch.zeros_like(bhwc) + x[ii] = nc + return x.permute(0, 3, 1, 2) + else: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + else: # BLC or BC + if self.sparse: + raise NotImplementedError + else: + return super(SparseConvNeXtLayerNorm, self).forward(x) + + def __repr__(self): + return super(SparseConvNeXtLayerNorm, self).__repr__()[:-1] + f', ch={self.data_format.split("_")[-1]}, sp={self.sparse})' + + +class SparseConvNeXtBlock(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, sparse=True, ks=7): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=ks, padding=ks//2, groups=dim) # depthwise conv + self.norm = SparseConvNeXtLayerNorm(dim, eps=1e-6, sparse=sparse) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path: nn.Module = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.sparse = sparse + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) # GELU(0) == (0), so there is no need to mask x (no need to `x *= _get_active_ex_or_ii`) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + if self.sparse: + x *= _get_active_ex_or_ii(H=x.shape[2], W=x.shape[3], returning_active_ex=True) + + x = input + self.drop_path(x) + return x + + def __repr__(self): + return super(SparseConvNeXtBlock, self).__repr__()[:-1] + f', sp={self.sparse})' + + + diff --git a/models/metaformer.py b/models/metaformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f076d6bcc29761a6cf0c2a21a5951616f448560c --- /dev/null +++ b/models/metaformer.py @@ -0,0 +1,1538 @@ +# Copyright 2022 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +MetaFormer baselines including IdentityFormer, RandFormer, PoolFormerV2, +ConvFormer and CAFormer. +Some implementations are modified from timm (https://github.com/rwightman/pytorch-image-models). +""" +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.layers import trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers.helpers import to_2tuple + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': 1.0, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'identityformer_s12': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s12.pth'), + 'identityformer_s24': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s24.pth'), + 'identityformer_s36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_s36.pth'), + 'identityformer_m36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m36.pth'), + 'identityformer_m48': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/identityformer/identityformer_m48.pth'), + + + 'randformer_s12': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s12.pth'), + 'randformer_s24': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s24.pth'), + 'randformer_s36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_s36.pth'), + 'randformer_m36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m36.pth'), + 'randformer_m48': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/randformer/randformer_m48.pth'), + + 'poolformerv2_s12': _cfg( + 
url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s12.pth'), + 'poolformerv2_s24': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s24.pth'), + 'poolformerv2_s36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_s36.pth'), + 'poolformerv2_m36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m36.pth'), + 'poolformerv2_m48': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/poolformerv2/poolformerv2_m48.pth'), + + + + 'convformer_s18': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18.pth'), + 'convformer_s18_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384.pth', + input_size=(3, 384, 384)), + 'convformer_s18_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21ft1k.pth'), + 'convformer_s18_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_s18_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s18_in21k.pth', + num_classes=21841), + + 'convformer_s36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36.pth'), + 'convformer_s36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384.pth', + input_size=(3, 384, 384)), + 'convformer_s36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21ft1k.pth'), + 'convformer_s36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_s36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_s36_in21k.pth', + num_classes=21841), + + 'convformer_m36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36.pth'), + 'convformer_m36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384.pth', + input_size=(3, 384, 384)), + 'convformer_m36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21ft1k.pth'), + 'convformer_m36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_m36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_m36_in21k.pth', + num_classes=21841), + + 'convformer_b36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36.pth'), + 'convformer_b36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384.pth', + input_size=(3, 384, 384)), + 'convformer_b36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21ft1k.pth'), + 'convformer_b36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'convformer_b36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/convformer/convformer_b36_in21k.pth', + num_classes=21841), + + + 'caformer_s18': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18.pth'), + 'caformer_s18_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384.pth', + 
input_size=(3, 384, 384)), + 'caformer_s18_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21ft1k.pth'), + 'caformer_s18_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_s18_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s18_in21k.pth', + num_classes=21841), + + 'caformer_s36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36.pth'), + 'caformer_s36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384.pth', + input_size=(3, 384, 384)), + 'caformer_s36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21ft1k.pth'), + 'caformer_s36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_s36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_s36_in21k.pth', + num_classes=21841), + + 'caformer_m36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36.pth'), + 'caformer_m36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384.pth', + input_size=(3, 384, 384)), + 'caformer_m36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21ft1k.pth'), + 'caformer_m36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_m36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_m36_in21k.pth', + num_classes=21841), + + 'caformer_b36': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36.pth'), + 'caformer_b36_384': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384.pth', + input_size=(3, 384, 384)), + 'caformer_b36_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21ft1k.pth'), + 'caformer_b36_384_in21ft1k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_384_in21ft1k.pth', + input_size=(3, 384, 384)), + 'caformer_b36_in21k': _cfg( + url='https://huggingface.co/sail/dl/resolve/main/caformer/caformer_b36_in21k.pth', + num_classes=21841), +} + + +class Downsampling(nn.Module): + """ + Downsampling implemented by a layer of convolution. + """ + def __init__(self, in_channels, out_channels, + kernel_size, stride=1, padding=0, + pre_norm=None, post_norm=None, pre_permute=False): + super().__init__() + self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() + self.pre_permute = pre_permute + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, + stride=stride, padding=padding) + self.post_norm = post_norm(out_channels) if post_norm else nn.Identity() + + def forward(self, x): + x = self.pre_norm(x) + if self.pre_permute: + # if take [B, H, W, C] as input, permute it to [B, C, H, W] + x = x.permute(0, 3, 1, 2) + x = self.conv(x) + x = x.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C] + x = self.post_norm(x) + return x + + +class Scale(nn.Module): + """ + Scale vector by element multiplications. 
+ """ + def __init__(self, dim, init_value=1.0, trainable=True): + super().__init__() + self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable) + + def forward(self, x): + return x * self.scale + + +class SquaredReLU(nn.Module): + """ + Squared ReLU: https://arxiv.org/abs/2109.08668 + """ + def __init__(self, inplace=False): + super().__init__() + self.relu = nn.ReLU(inplace=inplace) + def forward(self, x): + return torch.square(self.relu(x)) + + +class StarReLU(nn.Module): + """ + StarReLU: s * relu(x) ** 2 + b + """ + def __init__(self, scale_value=1.0, bias_value=0.0, + scale_learnable=True, bias_learnable=True, + mode=None, inplace=False): + super().__init__() + self.inplace = inplace + self.relu = nn.ReLU(inplace=inplace) + self.scale = nn.Parameter(scale_value * torch.ones(1), + requires_grad=scale_learnable) + self.bias = nn.Parameter(bias_value * torch.ones(1), + requires_grad=bias_learnable) + def forward(self, x): + return self.scale * self.relu(x)**2 + self.bias + + +class Attention(nn.Module): + """ + Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762. + Modified from timm. + """ + def __init__(self, dim, head_dim=32, num_heads=None, qkv_bias=False, + attn_drop=0., proj_drop=0., proj_bias=False, **kwargs): + super().__init__() + + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + + self.num_heads = num_heads if num_heads else dim // head_dim + if self.num_heads == 0: + self.num_heads = 1 + + self.attention_dim = self.num_heads * self.head_dim + + self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + + def forward(self, x): + B, H, W, C = x.shape + N = H * W + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.attention_dim) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class RandomMixing(nn.Module): + def __init__(self, num_tokens=196, **kwargs): + super().__init__() + self.random_matrix = nn.parameter.Parameter( + data=torch.softmax(torch.rand(num_tokens, num_tokens), dim=-1), + requires_grad=False) + def forward(self, x): + B, H, W, C = x.shape + x = x.reshape(B, H*W, C) + x = torch.einsum('mn, bnc -> bmc', self.random_matrix, x) + x = x.reshape(B, H, W, C) + return x + + +class LayerNormGeneral(nn.Module): + r""" General LayerNorm for different situations. + + Args: + affine_shape (int, list or tuple): The shape of affine weight and bias. + Usually the affine_shape=C, but in some implementation, like torch.nn.LayerNorm, + the affine_shape is the same as normalized_dim by default. + To adapt to different situations, we offer this argument here. + normalized_dim (tuple or list): Which dims to compute mean and variance. + scale (bool): Flag indicates whether to use scale or not. + bias (bool): Flag indicates whether to use scale or not. + + We give several examples to show how to specify the arguments. 
+ + LayerNorm (https://arxiv.org/abs/1607.06450): + For input shape of (B, *, C) like (B, N, C) or (B, H, W, C), + affine_shape=C, normalized_dim=(-1, ), scale=True, bias=True; + For input shape of (B, C, H, W), + affine_shape=(C, 1, 1), normalized_dim=(1, ), scale=True, bias=True. + + Modified LayerNorm (https://arxiv.org/abs/2111.11418) + that is identical to partial(torch.nn.GroupNorm, num_groups=1): + For input shape of (B, N, C), + affine_shape=C, normalized_dim=(1, 2), scale=True, bias=True; + For input shape of (B, H, W, C), + affine_shape=C, normalized_dim=(1, 2, 3), scale=True, bias=True; + For input shape of (B, C, H, W), + affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3), scale=True, bias=True. + + For the several MetaFormer baselines, + IdentityFormer, RandFormer and PoolFormerV2 utilize Modified LayerNorm without bias (bias=False); + ConvFormer and CAFormer utilize LayerNorm without bias (bias=False). + """ + def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True, + bias=True, eps=1e-5): + super().__init__() + self.normalized_dim = normalized_dim + self.use_scale = scale + self.use_bias = bias + self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None + self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None + self.eps = eps + + def forward(self, x): + c = x - x.mean(self.normalized_dim, keepdim=True) + s = c.pow(2).mean(self.normalized_dim, keepdim=True) + x = c / torch.sqrt(s + self.eps) + if self.use_scale: + x = x * self.weight + if self.use_bias: + x = x + self.bias + return x + + +class LayerNormWithoutBias(nn.Module): + """ + Equal to partial(LayerNormGeneral, bias=False) but faster, + because it directly utilizes the optimized F.layer_norm + """ + def __init__(self, normalized_shape, eps=1e-5, **kwargs): + super().__init__() + self.eps = eps + self.bias = None + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.normalized_shape = normalized_shape + def forward(self, x): + return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps) + + +class SepConv(nn.Module): + r""" + Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.
+ """ + def __init__(self, dim, expansion_ratio=2, + act1_layer=StarReLU, act2_layer=nn.Identity, + bias=False, kernel_size=7, padding=3, + **kwargs, ): + super().__init__() + med_channels = int(expansion_ratio * dim) + self.pwconv1 = nn.Linear(dim, med_channels, bias=bias) + self.act1 = act1_layer() + self.dwconv = nn.Conv2d( + med_channels, med_channels, kernel_size=kernel_size, + padding=padding, groups=med_channels, bias=bias) # depthwise conv + self.act2 = act2_layer() + self.pwconv2 = nn.Linear(med_channels, dim, bias=bias) + + def forward(self, x): + x = self.pwconv1(x) + x = self.act1(x) + x = x.permute(0, 3, 1, 2) + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) + x = self.act2(x) + x = self.pwconv2(x) + return x + + +class Pooling(nn.Module): + """ + Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418 + Modfiled for [B, H, W, C] input + """ + def __init__(self, pool_size=3, **kwargs): + super().__init__() + self.pool = nn.AvgPool2d( + pool_size, stride=1, padding=pool_size//2, count_include_pad=False) + + def forward(self, x): + y = x.permute(0, 3, 1, 2) + y = self.pool(y) + y = y.permute(0, 2, 3, 1) + return y - x + + +class Mlp(nn.Module): + """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks. + Mostly copied from timm. + """ + def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs): + super().__init__() + in_features = dim + out_features = out_features or in_features + hidden_features = int(mlp_ratio * in_features) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class MlpHead(nn.Module): + """ MLP classification head + """ + def __init__(self, dim, num_classes=1000, mlp_ratio=4, act_layer=SquaredReLU, + norm_layer=nn.LayerNorm, head_dropout=0., bias=True): + super().__init__() + hidden_features = int(mlp_ratio * dim) + self.fc1 = nn.Linear(dim, hidden_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) + self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) + self.head_dropout = nn.Dropout(head_dropout) + + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.head_dropout(x) + x = self.fc2(x) + return x + + +class MetaFormerBlock(nn.Module): + """ + Implementation of one MetaFormer block. + """ + def __init__(self, dim, + token_mixer=nn.Identity, mlp=Mlp, + norm_layer=nn.LayerNorm, + drop=0., drop_path=0., + layer_scale_init_value=None, res_scale_init_value=None + ): + + super().__init__() + + self.norm1 = norm_layer(dim) + self.token_mixer = token_mixer(dim=dim, drop=drop) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.layer_scale1 = Scale(dim=dim, init_value=layer_scale_init_value) \ + if layer_scale_init_value else nn.Identity() + self.res_scale1 = Scale(dim=dim, init_value=res_scale_init_value) \ + if res_scale_init_value else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp(dim=dim, drop=drop) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) \ + if layer_scale_init_value else nn.Identity() + self.res_scale2 = Scale(dim=dim, init_value=res_scale_init_value) \ + if res_scale_init_value else nn.Identity() + + def forward(self, x): + x = self.res_scale1(x) + \ + self.layer_scale1( + self.drop_path1( + self.token_mixer(self.norm1(x)) + ) + ) + x = self.res_scale2(x) + \ + self.layer_scale2( + self.drop_path2( + self.mlp(self.norm2(x)) + ) + ) + return x + + +r""" +downsampling (stem) for the first stage is a layer of conv with k7, s4 and p2 +downsamplings for the last 3 stages is a layer of conv with k3, s2 and p1 +DOWNSAMPLE_LAYERS_FOUR_STAGES format: [Downsampling, Downsampling, Downsampling, Downsampling] +use `partial` to specify some arguments +""" +DOWNSAMPLE_LAYERS_FOUR_STAGES = [partial(Downsampling, + kernel_size=7, stride=4, padding=2, + post_norm=partial(LayerNormGeneral, bias=False, eps=1e-6) + )] + \ + [partial(Downsampling, + kernel_size=3, stride=2, padding=1, + pre_norm=partial(LayerNormGeneral, bias=False, eps=1e-6), pre_permute=True + )]*3 + + +class MetaFormer(nn.Module): + r""" MetaFormer + A PyTorch impl of : `MetaFormer Baselines for Vision` - + https://arxiv.org/abs/2210.13452 + + Args: + in_chans (int): Number of input image channels. Default: 3. + num_classes (int): Number of classes for classification head. Default: 1000. + depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2]. + dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512]. + downsample_layers: (list or tuple): Downsampling layers before each stage. + token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity. + mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp. + norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False). + drop_path_rate (float): Stochastic depth rate. Default: 0. + head_dropout (float): dropout for MLP classifier. Default: 0. + layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: None. + None means not use the layer scale. Form: https://arxiv.org/abs/2103.17239. + res_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: [None, None, 1.0, 1.0]. + None means not use the layer scale. From: https://arxiv.org/abs/2110.09456. + output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6). + head_fn: classification head. Default: nn.Linear. 
+ """ + def __init__(self, in_chans=3, num_classes=1000, + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES, + token_mixers=nn.Identity, + mlps=Mlp, + norm_layers=partial(LayerNormWithoutBias, eps=1e-6), # partial(LayerNormGeneral, eps=1e-6, bias=False), + drop_path_rate=0., + head_dropout=0.0, + layer_scale_init_values=None, + res_scale_init_values=[None, None, 1.0, 1.0], + output_norm=partial(nn.LayerNorm, eps=1e-6), + head_fn=nn.Linear, + **kwargs, + ): + super().__init__() + self.num_classes = num_classes + + if not isinstance(depths, (list, tuple)): + depths = [depths] # it means the model has only one stage + if not isinstance(dims, (list, tuple)): + dims = [dims] + + num_stage = len(depths) + self.num_stage = num_stage + + if not isinstance(downsample_layers, (list, tuple)): + downsample_layers = [downsample_layers] * num_stage + down_dims = [in_chans] + dims + self.downsample_layers = nn.ModuleList( + [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)] + ) + + if not isinstance(token_mixers, (list, tuple)): + token_mixers = [token_mixers] * num_stage + + if not isinstance(mlps, (list, tuple)): + mlps = [mlps] * num_stage + + if not isinstance(norm_layers, (list, tuple)): + norm_layers = [norm_layers] * num_stage + + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + if not isinstance(layer_scale_init_values, (list, tuple)): + layer_scale_init_values = [layer_scale_init_values] * num_stage + if not isinstance(res_scale_init_values, (list, tuple)): + res_scale_init_values = [res_scale_init_values] * num_stage + + self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks + cur = 0 + for i in range(num_stage): + stage = nn.Sequential( + *[MetaFormerBlock(dim=dims[i], + token_mixer=token_mixers[i], + mlp=mlps[i], + norm_layer=norm_layers[i], + drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_values[i], + res_scale_init_value=res_scale_init_values[i], + ) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = output_norm(dims[-1]) + + if head_dropout > 0.0: + self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout) + else: + self.head = head_fn(dims[-1], num_classes) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'norm'} + + def forward_features(self, x): + for i in range(self.num_stage): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return self.norm(x.mean([1, 2])) # (B, H, W, C) -> (B, C) + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + + +@register_model +def identityformer_s12(pretrained=False, **kwargs): + model = MetaFormer( + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['identityformer_s12'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def identityformer_s24(pretrained=False, **kwargs): + model = MetaFormer( + depths=[4, 4, 12, 4], + dims=[64, 128, 320, 512], + 
token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['identityformer_s24'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def identityformer_s36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[64, 128, 320, 512], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['identityformer_s36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def identityformer_m36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[96, 192, 384, 768], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['identityformer_m36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def identityformer_m48(pretrained=False, **kwargs): + model = MetaFormer( + depths=[8, 8, 24, 8], + dims=[96, 192, 384, 768], + token_mixers=nn.Identity, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['identityformer_m48'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def randformer_s12(pretrained=False, **kwargs): + model = MetaFormer( + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['randformer_s12'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def randformer_s24(pretrained=False, **kwargs): + model = MetaFormer( + depths=[4, 4, 12, 4], + dims=[64, 128, 320, 512], + token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['randformer_s24'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def randformer_s36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[64, 128, 320, 512], + token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['randformer_s36'] + if pretrained: + state_dict = 
torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def randformer_m36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[96, 192, 384, 768], + token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['randformer_m36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def randformer_m48(pretrained=False, **kwargs): + model = MetaFormer( + depths=[8, 8, 24, 8], + dims=[96, 192, 384, 768], + token_mixers=[nn.Identity, nn.Identity, RandomMixing, partial(RandomMixing, num_tokens=49)], + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['randformer_m48'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + + +@register_model +def poolformerv2_s12(pretrained=False, **kwargs): + model = MetaFormer( + depths=[2, 2, 6, 2], + dims=[64, 128, 320, 512], + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['poolformerv2_s12'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def poolformerv2_s24(pretrained=False, **kwargs): + model = MetaFormer( + depths=[4, 4, 12, 4], + dims=[64, 128, 320, 512], + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['poolformerv2_s24'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def poolformerv2_s36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[64, 128, 320, 512], + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['poolformerv2_s36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def poolformerv2_m36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[6, 6, 18, 6], + dims=[96, 192, 384, 768], + token_mixers=Pooling, + norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['poolformerv2_m36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def poolformerv2_m48(pretrained=False, **kwargs): + model = MetaFormer( + depths=[8, 8, 24, 8], + dims=[96, 192, 384, 768], + token_mixers=Pooling, + 
norm_layers=partial(LayerNormGeneral, normalized_dim=(1, 2, 3), eps=1e-6, bias=False), + **kwargs) + model.default_cfg = default_cfgs['poolformerv2_m48'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s18(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s18'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s18_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s18_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s18_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s18_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s18_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s18_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s18_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s18_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def 
convformer_s36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_s36_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_s36_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_m36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_m36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_m36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_m36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_m36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_m36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_m36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_m36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_m36_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_m36_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], 
map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_b36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_b36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_b36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_b36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_b36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_b36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_b36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_b36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def convformer_b36_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=SepConv, + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['convformer_b36_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s18(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s18'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s18_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s18_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s18_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + 
model.default_cfg = default_cfgs['caformer_s18_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s18_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s18_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s18_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 3, 9, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s18_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_s36_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[64, 128, 320, 512], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_s36_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], 
map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_m36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_m36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_m36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_m36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_m36_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_m36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_m36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_m36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_m364_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[96, 192, 384, 576], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_m364_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_b36(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_b36'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_b36_384(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_b36_384'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_b36_in21ft1k(pretrained=False, **kwargs): + model = 
MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_b36_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_b36_384_in21ft1k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_b36_384_in21ft1k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def caformer_b36_in21k(pretrained=False, **kwargs): + model = MetaFormer( + depths=[3, 12, 18, 3], + dims=[128, 256, 512, 768], + token_mixers=[SepConv, SepConv, Attention, Attention], + head_fn=MlpHead, + **kwargs) + model.default_cfg = default_cfgs['caformer_b36_in21k'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model \ No newline at end of file diff --git a/models/neuron.py b/models/neuron.py new file mode 100644 index 0000000000000000000000000000000000000000..86a24f84cd087566e9699112df8af2a82dc5e1a4 --- /dev/null +++ b/models/neuron.py @@ -0,0 +1,1587 @@ +from abc import abstractmethod +from typing import Callable, overload +import torch +import torch.nn as nn +from spikingjelly.clock_driven import surrogate, base, lava_exchange +from spikingjelly import configure +import math +import numpy as np +import logging +import cupy +from spikingjelly.clock_driven import neuron_kernel, cu_kernel_opt + + +try: + import lava.lib.dl.slayer as slayer + +except BaseException as e: + logging.info(f'spikingjelly.clock_driven.neuron: {e}') + slayer = None + + +def check_backend(backend: str): + if backend == 'torch': + return + elif backend == 'cupy': + assert cupy is not None, 'CuPy is not installed! You can install it from "https://github.com/cupy/cupy".' + elif backend == 'lava': + assert slayer is not None, 'Lava-DL is not installed! You can install it from "https://github.com/lava-nc/lava-dl".' + else: + raise NotImplementedError(backend) + + +class BaseNode(base.MemoryModule): + def __init__(self, v_threshold: float = 1., v_reset: float = 0., + surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): + """ + * :ref:`API in English ` + + .. _BaseNode.__init__-cn: + + :param v_threshold: 神经元的阈值电压 + :type v_threshold: float + + :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; + 如果设置为 ``None``,则电压会被减去 ``v_threshold`` + :type v_reset: float + + :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 + :type surrogate_function: Callable + + :param detach_reset: 是否将reset过程的计算图分离 + :type detach_reset: bool + + 可微分SNN神经元的基类神经元。 + + * :ref:`中文API ` + + .. _BaseNode.__init__-en: + + :param v_threshold: threshold voltage of neurons + :type v_threshold: float + + :param v_reset: reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to + ``v_reset``. 
If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold`` + :type v_reset: float + + :param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation + :type surrogate_function: Callable + + :param detach_reset: whether detach the computation graph of reset + :type detach_reset: bool + + This class is the base class of differentiable spiking neurons. + """ + assert isinstance(v_reset, float) or v_reset is None + assert isinstance(v_threshold, float) + assert isinstance(detach_reset, bool) + super().__init__() + + if v_reset is None: + self.register_memory('v', 0.) + else: + self.register_memory('v', v_reset) + + self.register_memory('v_threshold', v_threshold) + self.register_memory('v_reset', v_reset) + + self.detach_reset = detach_reset + self.surrogate_function = surrogate_function + + @abstractmethod + def neuronal_charge(self, x: torch.Tensor): + """ + * :ref:`API in English ` + + .. _BaseNode.neuronal_charge-cn: + + 定义神经元的充电差分方程。子类必须实现这个函数。 + + * :ref:`中文API ` + + .. _BaseNode.neuronal_charge-en: + + + Define the charge difference equation. The sub-class must implement this function. + """ + raise NotImplementedError + + def neuronal_fire(self): + """ + * :ref:`API in English ` + + .. _BaseNode.neuronal_fire-cn: + + 根据当前神经元的电压、阈值,计算输出脉冲。 + + * :ref:`中文API ` + + .. _BaseNode.neuronal_fire-en: + + + Calculate out spikes of neurons by their current membrane potential and threshold voltage. + """ + + return self.surrogate_function(self.v - self.v_threshold) + + def neuronal_reset(self, spike): + """ + * :ref:`API in English ` + + .. _BaseNode.neuronal_reset-cn: + + 根据当前神经元释放的脉冲,对膜电位进行重置。 + + * :ref:`中文API ` + + .. _BaseNode.neuronal_reset-en: + + + Reset the membrane potential according to neurons' output spikes. + """ + if self.detach_reset: + spike_d = spike.detach() + else: + spike_d = spike + + if self.v_reset is None: + # soft reset + self.v = self.v - spike_d * self.v_threshold + + else: + # hard reset + self.v = (1. - spike_d) * self.v + spike_d * self.v_reset + + def extra_repr(self): + return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}' + + def forward(self, x: torch.Tensor): + """ + + * :ref:`API in English ` + + .. _BaseNode.forward-cn: + + :param x: 输入到神经元的电压增量 + :type x: torch.Tensor + + :return: 神经元的输出脉冲 + :rtype: torch.Tensor + + 按照充电、放电、重置的顺序进行前向传播。 + + * :ref:`中文API ` + + .. _BaseNode.forward-en: + + :param x: increment of voltage inputted to neurons + :type x: torch.Tensor + + :return: out spikes of neurons + :rtype: torch.Tensor + + Forward by the order of `neuronal_charge`, `neuronal_fire`, and `neuronal_reset`. 
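+
+        Example (a minimal sketch using the concrete ``IFNode`` subclass defined later in
+        this file; it assumes the repository root is on the Python path so that
+        ``models.neuron`` is importable):
+
+        .. code-block:: python
+
+            import torch
+            from models.neuron import IFNode
+
+            node = IFNode(v_threshold=1., v_reset=0.)
+            x = torch.full((2,), 0.6)
+            print(node(x))   # charge: v = 0.6 < 1.0, no spike       -> tensor([0., 0.])
+            print(node(x))   # charge: v = 1.2 >= 1.0, spike + reset -> tensor([1., 1.])
+            node.reset()     # clear the membrane state before the next sample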
+ + """ + self.neuronal_charge(x) + spike = self.neuronal_fire() + self.neuronal_reset(spike) + return spike + + +class AdaptiveBaseNode(BaseNode): + def __init__(self, v_threshold: float = 1., v_reset: float = 0., + v_rest: float = 0., w_rest: float = 0, tau_w: float = 2., a: float = 0., b: float = 0., + surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): + # b: jump amplitudes + # a: subthreshold coupling + assert isinstance(w_rest, float) + assert isinstance(v_rest, float) + assert isinstance(tau_w, float) + assert isinstance(a, float) + assert isinstance(b, float) + + super.__init__(v_threshold, v_reset, surrogate_function, detach_reset) + + self.register_memory('w', w_rest) + + self.w_rest = w_rest + self.v_rest = v_rest + self.tau_w = tau_w + self.a = a + self.b = b + + def neuronal_adaptation(self, spike): + self.w = self.w + 1. / self.tau_w * (self.a * (self.v - self.v_rest) - self.w) + self.b * spike + + def extra_repr(self): + return super().extra_repr() + f', v_rest={self.v_rest}, w_rest={self.w_rest}, tau_w={self.tau_w}, a={self.a}, b={self.b}' + + @overload + def forward(self, x: torch.Tensor): + self.neuronal_charge(x) + spike = self.neuronal_fire() + self.neuronal_adaptation(spike) + self.neuronal_reset(spike) + return spike + + +class IFNode(BaseNode): + def __init__(self, v_threshold: float = 1., v_reset: float = 0., + surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, + cupy_fp32_inference=False): + """ + * :ref:`API in English ` + + .. _IFNode.__init__-cn: + + :param v_threshold: 神经元的阈值电压 + :type v_threshold: float + + :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; + 如果设置为 ``None``,则电压会被减去 ``v_threshold`` + :type v_reset: float + + :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 + :type surrogate_function: Callable + + :param detach_reset: 是否将reset过程的计算图分离 + :type detach_reset: bool + + :param cupy_fp32_inference: 若为 `True`,在 `eval` 模式下,使用float32,却在GPU上运行,并且 `cupy` 已经安装,则会自动使用 `cupy` 进行加速 + :type cupy_fp32_inference: bool + + Integrate-and-Fire 神经元模型,可以看作理想积分器,无输入时电压保持恒定,不会像LIF神经元那样衰减。其阈下神经动力学方程为: + + .. math:: + V[t] = V[t-1] + X[t] + + * :ref:`中文API ` + + .. _IFNode.__init__-en: + + :param v_threshold: threshold voltage of neurons + :type v_threshold: float + + :param v_reset: reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to + ``v_reset``. If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold`` + :type v_reset: float + + :param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation + :type surrogate_function: Callable + + :param detach_reset: whether detach the computation graph of reset + :type detach_reset: bool + + :param cupy_fp32_inference: If `True`, if this module is in `eval` mode, using float32, running on GPU, and `cupy` is installed, then this + module will use `cupy` to accelerate + :type cupy_fp32_inference: bool + + The Integrate-and-Fire neuron, which can be seen as a ideal integrator. The voltage of the IF neuron will not decay + as that of the LIF neuron. The subthreshold neural dynamics of it is as followed: + + .. 
math:: + V[t] = V[t-1] + X[t] + + """ + super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) + + if cupy_fp32_inference: + check_backend('cupy') + self.cupy_fp32_inference = cupy_fp32_inference + + def neuronal_charge(self, x: torch.Tensor): + self.v = self.v + x + + def forward(self, x: torch.Tensor): + if self.cupy_fp32_inference and cupy is not None and not self.training and x.dtype == torch.float32: + # cupy is installed && eval mode && fp32 + device_id = x.get_device() + if device_id < 0: + return super().forward(x) + + # use cupy to accelerate + if isinstance(self.v, float): + v = torch.zeros_like(x) + if self.v != 0.: + torch.fill_(v, self.v) + self.v = v + + if self.v_reset is None: + hard_reset = False + else: + hard_reset = True + + code = rf''' + extern "C" __global__ + void IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward( + const float * x, const float & v_threshold, {'const float & v_reset,' if hard_reset else ''} + float * spike, float * v, + const int & numel) + ''' + + code += r''' + { + const int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < numel) + { + v[index] += x[index]; + spike[index] = (float) (v[index] >= v_threshold); + ''' + + code += rf''' + {'v[index] = (1.0f - spike[index]) * v[index] + spike[index] * v_reset;' if hard_reset else 'v[index] -= spike[index] * v_threshold;'} + ''' + + code += r''' + } + } + ''' + if hasattr(self, 'cp_kernel'): + if self.cp_kernel.code != code: + # replace codes + del self.cp_kernel + self.cp_kernel = cupy.RawKernel(code, + f"IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward", + options=configure.cuda_compiler_options, + backend=configure.cuda_compiler_backend) + else: + self.cp_kernel = cupy.RawKernel(code, + f"IFNode_{'hard' if hard_reset else 'soft'}_reset_inference_forward", + options=configure.cuda_compiler_options, + backend=configure.cuda_compiler_backend) + + with cu_kernel_opt.DeviceEnvironment(device_id): + numel = x.numel() + threads = configure.cuda_threads + blocks = cu_kernel_opt.cal_blocks(numel) + cp_numel = cupy.asarray(numel) + cp_v_threshold = cupy.asarray(self.v_threshold, dtype=np.float32) + if hard_reset: + cp_v_reset = cupy.asarray(self.v_reset, dtype=np.float32) + + spike = torch.zeros_like(x) + if hard_reset: + x, cp_v_threshold, cp_v_reset, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, + cp_v_threshold, + cp_v_reset, + spike, self.v, + cp_numel) + kernel_args = [x, cp_v_threshold, cp_v_reset, spike, self.v, cp_numel] + else: + x, cp_v_threshold, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, cp_v_threshold, spike, + self.v, cp_numel) + kernel_args = [x, cp_v_threshold, spike, self.v, cp_numel] + self.cp_kernel( + (blocks,), (threads,), + cu_kernel_opt.wrap_args_to_raw_kernel( + device_id, + *kernel_args + ) + ) + return spike + else: + return super().forward(x) + + +class MultiStepIFNode(IFNode): + def __init__(self, v_threshold: float = 1., v_reset: float = 0., + surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch', + lava_s_cale=1 << 6): + """ + * :ref:`API in English ` + + .. 
_MultiStepIFNode.__init__-cn: + + :param v_threshold: 神经元的阈值电压 + :type v_threshold: float + + :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; + 如果设置为 ``None``,则电压会被减去 ``v_threshold`` + :type v_reset: float + + :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 + :type surrogate_function: Callable + + :param detach_reset: 是否将reset过程的计算图分离 + :type detach_reset: bool + + :param backend: 使用哪种计算后端,可以为 ``'torch'`` 或 ``'cupy'``。``'cupy'`` 速度更快,但仅支持GPU。 + :type backend: str + + 多步版本的 :class:`spikingjelly.clock_driven.neuron.IFNode`。 + + .. tip:: + + 对于多步神经元,输入 ``x_seq.shape = [T, *]``,不仅可以使用 ``.v`` 和 ``.spike`` 获取 ``t = T - 1`` 时刻的电压和脉冲,还能够 + 使用 ``.v_seq`` 和 ``.spike_seq`` 获取完整的 ``T`` 个时刻的电压和脉冲。 + + .. tip:: + + 阅读 :doc:`传播模式 <./clock_driven/10_propagation_pattern>` 以获取更多关于单步和多步传播的信息。 + + * :ref:`中文API ` + + .. _MultiStepIFNode.__init__-en: + + :param v_threshold: threshold voltage of neurons + :type v_threshold: float + + :param v_reset: reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to + ``v_reset``. If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold`` + :type v_reset: float + + :param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation + :type surrogate_function: Callable + + :param detach_reset: whether detach the computation graph of reset + :type detach_reset: bool + + :param backend: use which backend, ``'torch'`` or ``'cupy'``. ``'cupy'`` is faster but only supports GPU + :type backend: str + + The multi-step version of :class:`spikingjelly.clock_driven.neuron.IFNode`. + + .. admonition:: Tip + :class: tip + + The input for multi-step neurons are ``x_seq.shape = [T, *]``. We can get membrane potential and spike at + time-step ``t = T - 1`` by ``.v`` and ``.spike``. We can also get membrane potential and spike at all ``T`` + time-steps by ``.v_seq`` and ``.spike_seq``. + + .. admonition:: Tip + :class: tip + + Read :doc:`Propagation Pattern <./clock_driven_en/10_propagation_pattern>` for more details about single-step + and multi-step propagation. 
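+
+        Example (a minimal sketch with the default ``'torch'`` backend; it assumes the
+        repository root is on the Python path so that ``models.neuron`` is importable):
+
+        .. code-block:: python
+
+            import torch
+            from models.neuron import MultiStepIFNode
+
+            node = MultiStepIFNode(backend='torch')
+            x_seq = torch.rand(4, 2, 8)      # [T, batch, features]
+            spike_seq = node(x_seq)          # spikes for all T steps, same shape as x_seq
+            print(spike_seq.shape, node.v_seq.shape)
+            node.reset()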
+ + """ + super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) + + self.register_memory('v_seq', None) + + check_backend(backend) + + self.backend = backend + + self.lava_s_cale = lava_s_cale + + if backend == 'lava': + self.lava_neuron = self.to_lava() + else: + self.lava_neuron = None + + def forward(self, x_seq: torch.Tensor): + assert x_seq.dim() > 1 + # x_seq.shape = [T, *] + + if self.backend == 'torch': + spike_seq = [] + self.v_seq = [] + for t in range(x_seq.shape[0]): + spike_seq.append(super().forward(x_seq[t]).unsqueeze(0)) + self.v_seq.append(self.v.unsqueeze(0)) + spike_seq = torch.cat(spike_seq, 0) + self.v_seq = torch.cat(self.v_seq, 0) + return spike_seq + + elif self.backend == 'cupy': + if isinstance(self.v, float): + v_init = self.v + self.v = torch.zeros_like(x_seq[0].data) + if v_init != 0.: + torch.fill_(self.v, v_init) + + spike_seq, self.v_seq = neuron_kernel.MultiStepIFNodePTT.apply( + x_seq.flatten(1), self.v.flatten(0), self.v_threshold, self.v_reset, self.detach_reset, + self.surrogate_function.cuda_code) + + spike_seq = spike_seq.reshape(x_seq.shape) + self.v_seq = self.v_seq.reshape(x_seq.shape) + + self.v = self.v_seq[-1].clone() + + return spike_seq + + elif self.backend == 'lava': + if self.lava_neuron is None: + self.lava_neuron = self.to_lava() + + spike, self.v = lava_exchange.lava_neuron_forward(self.lava_neuron, x_seq, self.v) + + return spike + + else: + raise NotImplementedError(self.backend) + + def extra_repr(self): + return super().extra_repr() + f', backend={self.backend}' + + def to_lava(self): + return lava_exchange.to_lava_neuron(self) + + def reset(self): + super().reset() + if self.lava_neuron is not None: + self.lava_neuron.current_state.zero_() + self.lava_neuron.voltage_state.zero_() + + +class LIFNode(BaseNode): + def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1., + v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), + detach_reset: bool = False, cupy_fp32_inference=False): + """ + * :ref:`API in English ` + + .. _LIFNode.__init__-cn: + + :param tau: 膜电位时间常数 + :type tau: float + + :param decay_input: 输入是否会衰减 + :type decay_input: bool + + :param v_threshold: 神经元的阈值电压 + :type v_threshold: float + + :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; + 如果设置为 ``None``,则电压会被减去 ``v_threshold`` + :type v_reset: float + + :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 + :type surrogate_function: Callable + + :param detach_reset: 是否将reset过程的计算图分离 + :type detach_reset: bool + + :param cupy_fp32_inference: 若为 `True`,在 `eval` 模式下,使用float32,却在GPU上运行,并且 `cupy` 已经安装,则会自动使用 `cupy` 进行加速 + :type cupy_fp32_inference: bool + + Leaky Integrate-and-Fire 神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为: + + 若 ``decay_input == True``: + + .. math:: + V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) + + 若 ``decay_input == False``: + + .. math:: + V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] + + .. tip:: + + 在 `eval` 模式下,使用float32,却在GPU上运行,并且 `cupy` 已经安装,则会自动使用 `cupy` 进行加速。 + + * :ref:`中文API ` + + .. _LIFNode.__init__-en: + + :param tau: membrane time constant + :type tau: float + + :param decay_input: whether the input will decay + :type decay_input: bool + + :param v_threshold: threshold voltage of neurons + :type v_threshold: float + + :param v_reset: reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to + ``v_reset``. 
If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold`` + :type v_reset: float + + :param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation + :type surrogate_function: Callable + + :param detach_reset: whether detach the computation graph of reset + :type detach_reset: bool + + :param cupy_fp32_inference: If `True`, if this module is in `eval` mode, using float32, running on GPU, and `cupy` is installed, then this + module will use `cupy` to accelerate + :type cupy_fp32_inference: bool + + The Leaky Integrate-and-Fire neuron, which can be seen as a leaky integrator. + The subthreshold neural dynamics of it is as followed: + + IF ``decay_input == True``: + + .. math:: + V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) + + IF ``decay_input == False``: + + .. math:: + V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] + + .. admonition:: Tip + :class: tip + + If this module is in `eval` mode, using float32, running on GPU, and `cupy` is installed, then this + module will use `cupy` to accelerate. + + """ + assert isinstance(tau, float) and tau > 1. + + super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) + self.tau = tau + self.decay_input = decay_input + + if cupy_fp32_inference: + check_backend('cupy') + self.cupy_fp32_inference = cupy_fp32_inference + + def extra_repr(self): + return super().extra_repr() + f', tau={self.tau}' + + def neuronal_charge(self, x: torch.Tensor): + if self.decay_input: + if self.v_reset is None or self.v_reset == 0.: + self.v = self.v + (x - self.v) / self.tau + else: + self.v = self.v + (x - (self.v - self.v_reset)) / self.tau + + else: + if self.v_reset is None or self.v_reset == 0.: + self.v = self.v * (1. - 1. 
/ self.tau) + x + else: + self.v = self.v - (self.v - self.v_reset) / self.tau + x + + def forward(self, x: torch.Tensor): + if self.cupy_fp32_inference and cupy is not None and not self.training and x.dtype == torch.float32: + # cupy is installed && eval mode && fp32 + device_id = x.get_device() + if device_id < 0: + return super().forward(x) + + # use cupy to accelerate + if isinstance(self.v, float): + v = torch.zeros_like(x) + if self.v != 0.: + torch.fill_(v, self.v) + self.v = v + + if self.v_reset is None: + hard_reset = False + else: + hard_reset = True + + code = rf''' + extern "C" __global__ + void LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward( + const float * x, const float & v_threshold, {'const float & v_reset,' if hard_reset else ''} const float & tau, + float * spike, float * v, + const int & numel) + ''' + + code += r''' + { + const int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < numel) + { + + ''' + + if self.decay_input: + if hard_reset: + code += r''' + v[index] += (x[index] - (v[index] - v_reset)) / tau; + ''' + else: + code += r''' + v[index] += (x[index] - v[index]) / tau; + ''' + else: + if hard_reset: + code += r''' + v[index] = x[index] + v[index] - (v[index] - v_reset) / tau; + ''' + else: + code += r''' + v[index] = x[index] + v[index] * (1.0f - 1.0f / tau); + ''' + + code += rf''' + spike[index] = (float) (v[index] >= v_threshold); + {'v[index] = (1.0f - spike[index]) * v[index] + spike[index] * v_reset;' if hard_reset else 'v[index] -= spike[index] * v_threshold;'} + ''' + + code += r''' + } + } + ''' + if hasattr(self, 'cp_kernel'): + if self.cp_kernel.code != code: + # replace codes + del self.cp_kernel + self.cp_kernel = cupy.RawKernel(code, + f"LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward", + options=configure.cuda_compiler_options, + backend=configure.cuda_compiler_backend) + else: + self.cp_kernel = cupy.RawKernel(code, + f"LIFNode_{'hard' if hard_reset else 'soft'}_reset_decayInput_{self.decay_input}_inference_forward", + options=configure.cuda_compiler_options, + backend=configure.cuda_compiler_backend) + + with cu_kernel_opt.DeviceEnvironment(device_id): + numel = x.numel() + threads = configure.cuda_threads + blocks = cu_kernel_opt.cal_blocks(numel) + cp_numel = cupy.asarray(numel) + cp_v_threshold = cupy.asarray(self.v_threshold, dtype=np.float32) + if hard_reset: + cp_v_reset = cupy.asarray(self.v_reset, dtype=np.float32) + cp_tau = cupy.asarray(self.tau, dtype=np.float32) + spike = torch.zeros_like(x) + if hard_reset: + x, cp_v_threshold, cp_v_reset, cp_tau, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, + cp_v_threshold, + cp_v_reset, + cp_tau, + spike, + self.v, + cp_numel) + kernel_args = [x, cp_v_threshold, cp_v_reset, cp_tau, spike, self.v, cp_numel] + else: + x, cp_v_threshold, cp_tau, spike, self.v, cp_numel = cu_kernel_opt.get_contiguous(x, cp_v_threshold, + cp_tau, spike, + self.v, cp_numel) + kernel_args = [x, cp_v_threshold, cp_tau, spike, self.v, cp_numel] + + self.cp_kernel( + (blocks,), (threads,), + cu_kernel_opt.wrap_args_to_raw_kernel( + device_id, + *kernel_args + ) + ) + return spike + else: + return super().forward(x) + + +class MultiStepLIFNode(LIFNode): + def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1., + v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), + detach_reset: bool = False, backend='torch', lava_s_cale=1 << 6): + """ + * 
:ref:`API in English ` + + .. _MultiStepLIFNode.__init__-cn: + + :param tau: 膜电位时间常数 + :type tau: float + + :param decay_input: 输入是否会衰减 + :type decay_input: bool + + :param v_threshold: 神经元的阈值电压 + :type v_threshold: float + + :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; + 如果设置为 ``None``,则电压会被减去 ``v_threshold`` + :type v_reset: float + + :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 + :type surrogate_function: Callable + + :param detach_reset: 是否将reset过程的计算图分离 + :type detach_reset: bool + + :param backend: 使用哪种计算后端,可以为 ``'torch'`` 或 ``'cupy'``。``'cupy'`` 速度更快,但仅支持GPU。 + :type backend: str + + 多步版本的 :class:`spikingjelly.clock_driven.neuron.LIFNode`。 + + .. tip:: + + 对于多步神经元,输入 ``x_seq.shape = [T, *]``,不仅可以使用 ``.v`` 和 ``.spike`` 获取 ``t = T - 1`` 时刻的电压和脉冲,还能够 + 使用 ``.v_seq`` 和 ``.spike_seq`` 获取完整的 ``T`` 个时刻的电压和脉冲。 + + .. tip:: + + 阅读 :doc:`传播模式 <./clock_driven/10_propagation_pattern>` 以获取更多关于单步和多步传播的信息。 + + * :ref:`中文API ` + + .. _MultiStepLIFNode.__init__-en: + + :param tau: membrane time constant + :type tau: float + + :param decay_input: whether the input will decay + :type decay_input: bool + + :param v_threshold: threshold voltage of neurons + :type v_threshold: float + + :param v_reset: reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to + ``v_reset``. If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold`` + :type v_reset: float + + :param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation + :type surrogate_function: Callable + + :param detach_reset: whether detach the computation graph of reset + :type detach_reset: bool + + :param backend: use which backend, ``'torch'`` or ``'cupy'``. ``'cupy'`` is faster but only supports GPU + :type backend: str + + The multi-step version of :class:`spikingjelly.clock_driven.neuron.LIFNode`. + + .. admonition:: Tip + :class: tip + + The input for multi-step neurons are ``x_seq.shape = [T, *]``. We can get membrane potential and spike at + time-step ``t = T - 1`` by ``.v`` and ``.spike``. We can also get membrane potential and spike at all ``T`` + time-steps by ``.v_seq`` and ``.spike_seq``. + + .. admonition:: Tip + :class: tip + + Read :doc:`Propagation Pattern <./clock_driven_en/10_propagation_pattern>` for more details about single-step + and multi-step propagation. 
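+
+        Example (a minimal sketch with the default ``'torch'`` backend; it assumes the
+        repository root is on the Python path so that ``models.neuron`` is importable):
+
+        .. code-block:: python
+
+            import torch
+            from models.neuron import MultiStepLIFNode
+
+            node = MultiStepLIFNode(tau=2., backend='torch')
+            x_seq = torch.ones(4, 1, 3)      # constant drive over T = 4 time-steps
+            spike_seq = node(x_seq)
+            # with decay_input=True, v_threshold=1 and v_reset=0, the membrane charges as
+            # v <- v + (x - v) / tau, i.e. 0.5, 0.75, 0.875, 0.9375, and never reaches threshold here
+            print(spike_seq.sum().item())    # 0.0
+            print(node.v_seq[:, 0, 0])       # tensor([0.5000, 0.7500, 0.8750, 0.9375])
+            node.reset()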
+ + """ + super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset) + self.register_memory('v_seq', None) + + check_backend(backend) + + self.backend = backend + + self.lava_s_cale = lava_s_cale + + if backend == 'lava': + self.lava_neuron = self.to_lava() + else: + self.lava_neuron = None + + def forward(self, x_seq: torch.Tensor): + assert x_seq.dim() > 1 + # x_seq.shape = [T, *] + + if self.backend == 'torch': + spike_seq = [] + self.v_seq = [] + for t in range(x_seq.shape[0]): + spike_seq.append(super().forward(x_seq[t]).unsqueeze(0)) + self.v_seq.append(self.v.unsqueeze(0)) + spike_seq = torch.cat(spike_seq, 0) + self.v_seq = torch.cat(self.v_seq, 0) + return spike_seq + + elif self.backend == 'cupy': + if isinstance(self.v, float): + v_init = self.v + self.v = torch.zeros_like(x_seq[0].data) + if v_init != 0.: + torch.fill_(self.v, v_init) + + spike_seq, self.v_seq = neuron_kernel.MultiStepLIFNodePTT.apply( + x_seq.flatten(1), self.v.flatten(0), self.decay_input, self.tau, self.v_threshold, self.v_reset, + self.detach_reset, self.surrogate_function.cuda_code) + + spike_seq = spike_seq.reshape(x_seq.shape) + self.v_seq = self.v_seq.reshape(x_seq.shape) + + self.v = self.v_seq[-1].clone() + + return spike_seq + + elif self.backend == 'lava': + if self.lava_neuron is None: + self.lava_neuron = self.to_lava() + + spike, self.v = lava_exchange.lava_neuron_forward(self.lava_neuron, x_seq, self.v) + + return spike + + else: + raise NotImplementedError(self.backend) + + def extra_repr(self): + return super().extra_repr() + f', backend={self.backend}' + + def to_lava(self): + return lava_exchange.to_lava_neuron(self) + + def reset(self): + super().reset() + if self.lava_neuron is not None: + self.lava_neuron.current_state.zero_() + self.lava_neuron.voltage_state.zero_() + + +class ParametricLIFNode(BaseNode): + def __init__(self, init_tau: float = 2.0, decay_input: bool = True, v_threshold: float = 1., + v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), + detach_reset: bool = False): + """ + * :ref:`API in English ` + + .. _ParametricLIFNode.__init__-cn: + + :param init_tau: 膜电位时间常数的初始值 + :type init_tau: float + + :param decay_input: 输入是否会衰减 + :type decay_input: bool + + :param v_threshold: 神经元的阈值电压 + :type v_threshold: float + + :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; + 如果设置为 ``None``,则电压会被减去 ``v_threshold`` + :type v_reset: float + + :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 + :type surrogate_function: Callable + + :param detach_reset: 是否将reset过程的计算图分离 + :type detach_reset: bool + + `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks `_ + 提出的 Parametric Leaky Integrate-and-Fire (PLIF)神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为: + + 若 ``decay_input == True``: + + .. math:: + V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) + + 若 ``decay_input == False``: + + .. math:: + V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] + + 其中 :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`,:math:`w` 是可学习的参数。 + + * :ref:`中文API ` + + .. _ParametricLIFNode.__init__-en: + + :param init_tau: the initial value of membrane time constant + :type init_tau: float + + :param decay_input: whether the input will decay + :type decay_input: bool + + :param v_threshold: threshold voltage of neurons + :type v_threshold: float + + :param v_reset: reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to + ``v_reset``. 
If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold`` + :type v_reset: float + + :param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation + :type surrogate_function: Callable + + :param detach_reset: whether detach the computation graph of reset + :type detach_reset: bool + + The Parametric Leaky Integrate-and-Fire (PLIF) neuron, which is proposed by `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks `_ and can be seen as a leaky integrator. + The subthreshold neural dynamics of it is as followed: + + IF ``decay_input == True``: + + .. math:: + V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) + + IF ``decay_input == False``: + + .. math:: + V[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] + + where :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`, :math:`w` is a learnable parameter. + """ + + assert isinstance(init_tau, float) and init_tau > 1. + super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) + self.decay_input = decay_input + init_w = - math.log(init_tau - 1.) + self.w = nn.Parameter(torch.as_tensor(init_w)) + + def extra_repr(self): + with torch.no_grad(): + tau = 1. / self.w.sigmoid() + return super().extra_repr() + f', tau={tau}' + + def neuronal_charge(self, x: torch.Tensor): + if self.decay_input: + if self.v_reset is None or self.v_reset == 0.: + self.v = self.v + (x - self.v) * self.w.sigmoid() + else: + self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid() + else: + if self.v_reset is None or self.v_reset == 0.: + self.v = self.v * (1. - self.w.sigmoid()) + x + else: + self.v = self.v - (self.v - self.v_reset) * self.w.sigmoid() + x + + +class MultiStepParametricLIFNode(ParametricLIFNode): + def __init__(self, init_tau: float = 2., decay_input: bool = True, v_threshold: float = 1., + v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), + detach_reset: bool = False, backend='torch'): + """ + * :ref:`API in English ` + + .. _MultiStepParametricLIFNode.__init__-cn: + + :param init_tau: 膜电位时间常数的初始值 + :type init_tau: float + + :param decay_input: 输入是否会衰减 + :type decay_input: bool + + :param v_threshold: 神经元的阈值电压 + :type v_threshold: float + + :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; + 如果设置为 ``None``,则电压会被减去 ``v_threshold`` + :type v_reset: float + + :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 + :type surrogate_function: Callable + + :param detach_reset: 是否将reset过程的计算图分离 + :type detach_reset: bool + + 多步版本的 `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks `_ + 提出的 Parametric Leaky Integrate-and-Fire (PLIF)神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为: + + .. math:: + V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}) + + 其中 :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`,:math:`w` 是可学习的参数。 + + .. tip:: + + 对于多步神经元,输入 ``x_seq.shape = [T, *]``,不仅可以使用 ``.v`` 和 ``.spike`` 获取 ``t = T - 1`` 时刻的电压和脉冲,还能够 + 使用 ``.v_seq`` 和 ``.spike_seq`` 获取完整的 ``T`` 个时刻的电压和脉冲。 + + .. tip:: + + 阅读 :doc:`传播模式 <./clock_driven/10_propagation_pattern>` 以获取更多关于单步和多步传播的信息。 + + * :ref:`中文API ` + + .. _MultiStepParametricLIFNode.__init__-en: + + :param init_tau: the initial value of membrane time constant + :type init_tau: float + + :param decay_input: whether the input will decay + :type decay_input: bool + + :param v_threshold: threshold voltage of neurons + :type v_threshold: float + + :param v_reset: reset voltage of neurons. 
If not ``None``, voltage of neurons that just fired spikes will be set to + ``v_reset``. If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold`` + :type v_reset: float + + :param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation + :type surrogate_function: Callable + + :param detach_reset: whether detach the computation graph of reset + :type detach_reset: bool + + :param backend: use which backend, ``'torch'`` or ``'cupy'``. ``'cupy'`` is faster but only supports GPU + :type backend: str + + The multi-step Parametric Leaky Integrate-and-Fire (PLIF) neuron, which is proposed by `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks `_ and can be seen as a leaky integrator. + The subthreshold neural dynamics of it is as followed: + + .. math:: + V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}) + + where :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`, :math:`w` is a learnable parameter. + + .. admonition:: Tip + :class: tip + + The input for multi-step neurons are ``x_seq.shape = [T, *]``. We can get membrane potential and spike at + time-step ``t = T - 1`` by ``.v`` and ``.spike``. We can also get membrane potential and spike at all ``T`` + time-steps by ``.v_seq`` and ``.spike_seq``. + + .. admonition:: Tip + :class: tip + + Read :doc:`Propagation Pattern <./clock_driven_en/10_propagation_pattern>` for more details about single-step + and multi-step propagation. + """ + super().__init__(init_tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset) + self.register_memory('v_seq', None) + + check_backend(backend) + + self.backend = backend + + def forward(self, x_seq: torch.Tensor): + assert x_seq.dim() > 1 + # x_seq.shape = [T, *] + + if self.backend == 'torch': + spike_seq = [] + self.v_seq = [] + for t in range(x_seq.shape[0]): + spike_seq.append(super().forward(x_seq[t]).unsqueeze(0)) + self.v_seq.append(self.v.unsqueeze(0)) + spike_seq = torch.cat(spike_seq, 0) + self.v_seq = torch.cat(self.v_seq, 0) + return spike_seq + + elif self.backend == 'cupy': + if isinstance(self.v, float): + v_init = self.v + self.v = torch.zeros_like(x_seq[0].data) + if v_init != 0.: + torch.fill_(self.v, v_init) + + spike_seq, self.v_seq = neuron_kernel.MultiStepParametricLIFNodePTT.apply( + x_seq.flatten(1), self.v.flatten(0), self.w.sigmoid(), self.decay_input, self.v_threshold, self.v_reset, + self.detach_reset, self.surrogate_function.cuda_code) + + spike_seq = spike_seq.reshape(x_seq.shape) + self.v_seq = self.v_seq.reshape(x_seq.shape) + + self.v = self.v_seq[-1].clone() + + return spike_seq + else: + raise NotImplementedError + + def extra_repr(self): + return super().extra_repr() + f', backend={self.backend}' + + +class QIFNode(BaseNode): + def __init__(self, tau: float = 2., v_c: float = 0.8, a0: float = 1., v_threshold: float = 1., v_rest: float = 0., + v_reset: float = -0.1, + surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): + """ + * :ref:`API in English ` + + .. 
_QIFNode.__init__-cn: + + :param tau: 膜电位时间常数 + :type tau: float + + :param v_c: 关键电压 + :type v_c: float + + :param a0: + :type a0: float + + :param v_threshold: 神经元的阈值电压 + :type v_threshold: float + + :param v_rest: 静息电位 + :type v_rest: float + + :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; + 如果设置为 ``None``,则电压会被减去 ``v_threshold`` + :type v_reset: float + + :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 + :type surrogate_function: Callable + + :param detach_reset: 是否将reset过程的计算图分离 + :type detach_reset: bool + + + Quadratic Integrate-and-Fire 神经元模型,一种非线性积分发放神经元模型,也是指数积分发放神经元(Exponential Integrate-and-Fire)的近似版本。其阈下神经动力学方程为: + + .. math:: + V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c)) + + * :ref:`中文API ` + + .. _QIFNode.__init__-en: + + :param tau: membrane time constant + :type tau: float + + :param v_c: critical voltage + :type v_c: float + + :param a0: + :type a0: float + + :param v_threshold: threshold voltage of neurons + :type v_threshold: float + + :param v_rest: resting potential + :type v_rest: float + + :param v_reset: reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to + ``v_reset``. If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold`` + :type v_reset: float + + :param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation + :type surrogate_function: Callable + + :param detach_reset: whether detach the computation graph of reset + :type detach_reset: bool + + The Quadratic Integrate-and-Fire neuron is a kind of nonlinear integrate-and-fire models and also an approximation of the Exponential Integrate-and-Fire model. + The subthreshold neural dynamics of it is as followed: + + .. math:: + V[t] = V[t-1] + \\frac{1}{\\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c)) + """ + + assert isinstance(tau, float) and tau > 1. + if v_reset is not None: + assert v_threshold > v_reset + assert v_rest >= v_reset + assert a0 > 0 + + super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) + self.tau = tau + self.v_c = v_c + self.v_rest = v_rest + self.a0 = a0 + + def extra_repr(self): + return super().extra_repr() + f', tau={self.tau}, v_c={self.v_c}, a0={self.a0}, v_rest={self.v_rest}' + + def neuronal_charge(self, x: torch.Tensor): + self.v = self.v + (x + self.a0 * (self.v - self.v_rest) * (self.v - self.v_c)) / self.tau + + +class EIFNode(BaseNode): + def __init__(self, tau: float = 2., delta_T: float = 1., theta_rh: float = .8, v_threshold: float = 1., + v_rest: float = 0., v_reset: float = -0.1, + surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): + """ + * :ref:`API in English ` + + .. _EIFNode.__init__-cn: + + :param tau: 膜电位时间常数 + :type tau: float + + :param delta_T: 陡峭度参数 + :type delta_T: float + + :param theta_rh: 基强度电压阈值 + :type theta_rh: float + + :param v_threshold: 神经元的阈值电压 + :type v_threshold: float + + :param v_rest: 静息电位 + :type v_rest: float + + :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; + 如果设置为 ``None``,则电压会被减去 ``v_threshold`` + :type v_reset: float + + :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 + :type surrogate_function: Callable + + :param detach_reset: 是否将reset过程的计算图分离 + :type detach_reset: bool + + + Exponential Integrate-and-Fire 神经元模型,一种非线性积分发放神经元模型,是由HH神经元模型(Hodgkin-Huxley model)简化后推导出的一维模型。在 :math:`\\Delta_T\\to 0` 时退化为LIF模型。其阈下神经动力学方程为: + + .. 
math:: + V[t] = V[t-1] + \\frac{1}{\\tau}\\left(X[t] - (V[t-1] - V_{rest}) + \\Delta_T\\exp\\left(\\frac{V[t-1] - \\theta_{rh}}{\\Delta_T}\\right)\\right) + + * :ref:`中文API ` + + .. _EIFNode.__init__-en: + + :param tau: membrane time constant + :type tau: float + + :param delta_T: sharpness parameter + :type delta_T: float + + :param theta_rh: rheobase threshold + :type theta_rh: float + + :param v_threshold: threshold voltage of neurons + :type v_threshold: float + + :param v_rest: resting potential + :type v_rest: float + + :param v_reset: reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to + ``v_reset``. If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold`` + :type v_reset: float + + :param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation + :type surrogate_function: Callable + + :param detach_reset: whether detach the computation graph of reset + :type detach_reset: bool + + The Exponential Integrate-and-Fire neuron is a kind of nonlinear integrate-and-fire models and also an one-dimensional model derived from the Hodgkin-Huxley model. It degenerates to the LIF model when :math:`\\Delta_T\\to 0`. + The subthreshold neural dynamics of it is as followed: + + .. math:: + V[t] = V[t-1] + \\frac{1}{\\tau}\\left(X[t] - (V[t-1] - V_{rest}) + \\Delta_T\\exp\\left(\\frac{V[t-1] - \\theta_{rh}}{\\Delta_T}\\right)\\right) + """ + + assert isinstance(tau, float) and tau > 1. + if v_reset is not None: + assert v_threshold > v_reset + assert v_rest >= v_reset + assert delta_T > 0 + + super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) + self.tau = tau + self.delta_T = delta_T + self.v_rest = v_rest + self.theta_rh = theta_rh + + def extra_repr(self): + return super().extra_repr() + f', tau={self.tau}, delta_T={self.delta_T}, theta_rh={self.theta_rh}' + + def neuronal_charge(self, x: torch.Tensor): + + with torch.no_grad(): + if not isinstance(self.v, torch.Tensor): + self.v = torch.as_tensor(self.v, device=x.device) + + self.v = self.v + (x + self.v_rest - self.v + self.delta_T * torch.exp( + (self.v - self.theta_rh) / self.delta_T)) / self.tau + + +class MultiStepEIFNode(EIFNode): + def __init__(self, tau: float = 2., delta_T: float = 1., theta_rh: float = .8, v_threshold: float = 1., + v_rest: float = 0., v_reset: float = -0.1, + surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, backend='torch'): + """ + * :ref:`API in English ` + + .. _MultiStepEIFNode.__init__-cn: + + ::param tau: 膜电位时间常数 + :type tau: float + + :param delta_T: 陡峭度参数 + :type delta_T: float + + :param theta_rh: 基强度电压阈值 + :type theta_rh: float + + :param v_threshold: 神经元的阈值电压 + :type v_threshold: float + + :param v_rest: 静息电位 + :type v_rest: float + + :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; + 如果设置为 ``None``,则电压会被减去 ``v_threshold`` + :type v_reset: float + + :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 + :type surrogate_function: Callable + + :param detach_reset: 是否将reset过程的计算图分离 + :type detach_reset: bool + + 多步版本的 :class:`spikingjelly.clock_driven.neuron.EIFNode`。 + + .. tip:: + + 对于多步神经元,输入 ``x_seq.shape = [T, *]``,不仅可以使用 ``.v`` 和 ``.spike`` 获取 ``t = T - 1`` 时刻的电压和脉冲,还能够 + 使用 ``.v_seq`` 和 ``.spike_seq`` 获取完整的 ``T`` 个时刻的电压和脉冲。 + + .. tip:: + + 阅读 :doc:`传播模式 <./clock_driven/10_propagation_pattern>` 以获取更多关于单步和多步传播的信息。 + + * :ref:`中文API ` + + .. 
_MultiStepEIFNode.__init__-en: + + :param tau: membrane time constant + :type tau: float + + :param delta_T: sharpness parameter + :type delta_T: float + + :param theta_rh: rheobase threshold + :type theta_rh: float + + :param v_threshold: threshold voltage of neurons + :type v_threshold: float + + :param v_rest: resting potential + :type v_rest: float + + :param v_reset: reset voltage of neurons. If not ``None``, voltage of neurons that just fired spikes will be set to + ``v_reset``. If ``None``, voltage of neurons that just fired spikes will subtract ``v_threshold`` + :type v_reset: float + + :param surrogate_function: surrogate function for replacing gradient of spiking functions during back-propagation + :type surrogate_function: Callable + + :param detach_reset: whether detach the computation graph of reset + :type detach_reset: bool + + :param backend: use which backend, ``'torch'`` or ``'cupy'``. ``'cupy'`` is faster but only supports GPU + :type backend: str + + .. admonition:: Tip + :class: tip + + The input for multi-step neurons are ``x_seq.shape = [T, *]``. We can get membrane potential and spike at + time-step ``t = T - 1`` by ``.v`` and ``.spike``. We can also get membrane potential and spike at all ``T`` + time-steps by ``.v_seq`` and ``.spike_seq``. + + .. admonition:: Tip + :class: tip + + Read :doc:`Propagation Pattern <./clock_driven_en/10_propagation_pattern>` for more details about single-step + and multi-step propagation. + """ + super().__init__(tau, delta_T, theta_rh, v_threshold, v_rest, v_reset, + surrogate_function, detach_reset) + self.register_memory('v_seq', None) + + check_backend(backend) + + self.backend = backend + + def forward(self, x_seq: torch.Tensor): + assert x_seq.dim() > 1 + # x_seq.shape = [T, *] + + if self.backend == 'torch': + spike_seq = [] + self.v_seq = [] + for t in range(x_seq.shape[0]): + spike_seq.append(super().forward(x_seq[t]).unsqueeze(0)) + self.v_seq.append(self.v.unsqueeze(0)) + spike_seq = torch.cat(spike_seq, 0) + self.v_seq = torch.cat(self.v_seq, 0) + return spike_seq + + elif self.backend == 'cupy': + if isinstance(self.v, float): + v_init = self.v + self.v = torch.zeros_like(x_seq[0].data) + if v_init != 0.: + torch.fill_(self.v, v_init) + + spike_seq, self.v_seq = neuron_kernel.MultiStepEIFNodePTT.apply( + x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest, + self.theta_rh, self.delta_T, self.detach_reset, self.surrogate_function.cuda_code) + + spike_seq = spike_seq.reshape(x_seq.shape) + self.v_seq = self.v_seq.reshape(x_seq.shape) + + self.v = self.v_seq[-1].clone() + + return spike_seq + else: + raise NotImplementedError + + def extra_repr(self): + return super().extra_repr() + f', backend={self.backend}' + + +class GeneralNode(BaseNode): + def __init__(self, a: float or torch.Tensor, b: float or torch.Tensor, c: float or torch.Tensor = 0., + v_threshold: float = 1., v_reset: float = 0., + surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False): + super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) + self.a = self.register_buffer('a', torch.as_tensor(a)) + self.b = self.register_buffer('b', torch.as_tensor(b)) + self.c = self.register_buffer('c', torch.as_tensor(c)) + + def neuronal_charge(self, x: torch.Tensor): + self.v = self.a * self.v + self.b * x + self.c + + +class MultiStepGeneralNode(GeneralNode): + def __init__(self, a: float, b: float, c: float, v_threshold: float = 1., v_reset: float = 0., + surrogate_function: Callable = 
surrogate.Sigmoid(), detach_reset: bool = False, backend='torch'): + + super().__init__(v_threshold, v_reset, surrogate_function, detach_reset) + + self.register_memory('v_seq', None) + + check_backend(backend) + + self.backend = backend + + def forward(self, x_seq: torch.Tensor): + assert x_seq.dim() > 1 + # x_seq.shape = [T, *] + + if self.backend == 'torch': + spike_seq = [] + self.v_seq = [] + for t in range(x_seq.shape[0]): + spike_seq.append(super().forward(x_seq[t]).unsqueeze(0)) + self.v_seq.append(self.v.unsqueeze(0)) + spike_seq = torch.cat(spike_seq, 0) + self.v_seq = torch.cat(self.v_seq, 0) + return spike_seq + + elif self.backend == 'cupy': + if isinstance(self.v, float): + v_init = self.v + self.v = torch.zeros_like(x_seq[0].data) + if v_init != 0.: + torch.fill_(self.v, v_init) + + raise NotImplementedError + + spike_seq = spike_seq.reshape(x_seq.shape) + self.v_seq = self.v_seq.reshape(x_seq.shape) + + self.v = self.v_seq[-1].clone() + + return spike_seq + else: + raise NotImplementedError + + def extra_repr(self): + return super().extra_repr() + f', backend={self.backend}' + + +class LIAFNode(LIFNode): + def __init__(self, act: Callable, threshold_related: bool, *args, **kwargs): + """ + :param act: the activation function + :type act: Callable + :param threshold_related: whether the neuron uses threshold related (TR mode). If true, `y = act(h - v_th)`, + otherwise `y = act(h)` + :type threshold_related: bool + + Other parameters in `*args, **kwargs` are same with :class:`LIFNode`. + + The LIAF neuron proposed in `LIAF-Net: Leaky Integrate and Analog Fire Network for Lightweight and Efficient Spatiotemporal Information Processing `_. + + .. admonition:: Warning + :class: warning + + The outputs of this neuron are not binary spikes. 
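+
+        A minimal usage sketch (illustrative only; the activation, the input shape and the
+        ``threshold_related`` choice below are assumptions, not fixed by this class):
+
+        .. code-block:: python
+
+            liaf = LIAFNode(act=torch.relu, threshold_related=False, tau=2.0)
+            x = torch.rand(8, 128)   # [batch, features], single-step input
+            y = liaf(x)              # analog output act(V); spiking still drives the reset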
+ + """ + super().__init__(*args, **kwargs) + self.act = act + self.threshold_related = threshold_related + + def forward(self, x: torch.Tensor): + self.neuronal_charge(x) + if self.threshold_related: + y = self.act(self.v - self.v_threshold) + else: + y = self.act(self.v) + spike = self.neuronal_fire() + self.neuronal_reset(spike) + return y + + diff --git a/models/q_vit/Quant.py b/models/q_vit/Quant.py new file mode 100644 index 0000000000000000000000000000000000000000..7cba6ed08d632ebe2e00c108db96030ce386ef16 --- /dev/null +++ b/models/q_vit/Quant.py @@ -0,0 +1,185 @@ +import torch +import torch.nn.functional as F +from torch.nn.modules.linear import Linear +import math +from torch.nn.parameter import Parameter +from ._quan_base import _Conv2dQ, Qmodes, _LinearQ, _ActQ + + +__all__ = ['Conv2dQ', 'LinearQ', 'ActQ'] + + +class FunQ(torch.autograd.Function): + @staticmethod + def forward(ctx, weight, alpha, g, Qn, Qp): + assert alpha > 0, 'alpha = {}'.format(alpha) + ctx.save_for_backward(weight, alpha) + ctx.other = g, Qn, Qp + q_w = (weight / alpha).round().clamp(Qn, Qp) + w_q = q_w * alpha + return w_q + + @staticmethod + def backward(ctx, grad_weight): + weight, alpha = ctx.saved_tensors + g, Qn, Qp = ctx.other + q_w = weight / alpha + indicate_small = (q_w < Qn).float() + indicate_big = (q_w > Qp).float() + # indicate_middle = torch.ones(indicate_small.shape).to(indicate_small.device) - indicate_small - indicate_big + indicate_middle = 1.0 - indicate_small - indicate_big # Thanks to @haolibai + grad_alpha = ((indicate_small * Qn + indicate_big * Qp + indicate_middle * ( + -q_w + q_w.round())) * grad_weight * g).sum().unsqueeze(dim=0) + grad_weight = indicate_middle * grad_weight + # The following operation can make sure that alpha is always greater than zero in any case and can also + # suppress the update speed of alpha. (Personal understanding) + # grad_alpha.clamp_(-alpha.item(), alpha.item()) # FYI + return grad_weight, grad_alpha, None, None, None + + +def grad_scale(x, scale): + y = x + y_grad = x * scale + return y.detach() - y_grad.detach() + y_grad + + +def round_pass(x): + y = x.round() + y_grad = x + return y.detach() - y_grad.detach() + y_grad + + +class Conv2dQ(_Conv2dQ): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, nbits_w=8, mode=Qmodes.kernel_wise, **kwargs): + super(Conv2dQ, self).__init__( + in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, + stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, + nbits=nbits_w, mode=mode) + self.act = ActQ(in_features=in_channels, nbits_a=nbits_w) + + def forward(self, x): + if self.alpha is None: + return F.conv2d(x, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + # w_reshape = self.weight.reshape([self.weight.shape[0], -1]).transpose(0, 1) + Qn = -2 ** (self.nbits - 1) + Qp = 2 ** (self.nbits - 1) - 1 + if self.training and self.init_state == 0: + # self.alpha.data.copy_(self.weight.abs().max() / 2 ** (self.nbits - 1)) + self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp)) + # self.alpha.data.copy_(self.weight.abs().max() * 2) + self.init_state.fill_(1) + """ + Implementation according to paper. + Feels wrong ... + When we initialize the alpha as a big number (e.g., self.weight.abs().max() * 2), + the clamp function can be skipped. 
+ Then we get w_q = w / alpha * alpha = w, and $\frac{\partial w_q}{\partial \alpha} = 0$ + As a result, I don't think the pseudo-code in the paper echoes the formula. + + Please see jupyter/STE_LSQ.ipynb fo detailed comparison. + """ + g = 1.0 / math.sqrt(self.weight.numel() * Qp) + + # Method1: 31GB GPU memory (AlexNet w4a4 bs 2048) 17min/epoch + alpha = grad_scale(self.alpha, g) + # print(alpha.shape) + # print(self.weight.shape) + alpha = alpha.unsqueeze(1).unsqueeze(2).unsqueeze(3) + w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha + + x = self.act(x) + # w = w.clamp(Qn, Qp) + # q_w = round_pass(w) + # w_q = q_w * alpha + + # Method2: 25GB GPU memory (AlexNet w4a4 bs 2048) 32min/epoch + # w_q = FunLSQ.apply(self.weight, self.alpha, g, Qn, Qp) + # wq = y.transpose(0, 1).reshape(self.weight.shape).detach() + self.weight - self.weight.detach() + return F.conv2d(x, w_q, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +class LinearQ(_LinearQ): + def __init__(self, in_features, out_features, bias=True, nbits_w=4, **kwargs): + super(LinearQ, self).__init__(in_features=in_features, + out_features=out_features, bias=bias, nbits=nbits_w, mode=Qmodes.kernel_wise) + self.act = ActQ(in_features=in_features, nbits_a=nbits_w) + + def forward(self, x): + if self.alpha is None: + return F.linear(x, self.weight, self.bias) + Qn = -2 ** (self.nbits - 1) + Qp = 2 ** (self.nbits - 1) - 1 + if self.training and self.init_state == 0: + self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp)) + # self.alpha.data.copy_(self.weight.abs().max() / 2 ** (self.nbits - 1)) + self.init_state.fill_(1) + g = 1.0 / math.sqrt(self.weight.numel() * Qp) + + # Method1: + alpha = grad_scale(self.alpha, g) + alpha = alpha.unsqueeze(1) + w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha + + x = self.act(x) + # w = self.weight / alpha + # w = w.clamp(Qn, Qp) + # q_w = round_pass(w) + # w_q = q_w * alpha + + # Method2: + # w_q = FunLSQ.apply(self.weight, self.alpha, g, Qn, Qp) + return F.linear(x, w_q, self.bias) + + +class ActQ(_ActQ): + def __init__(self, in_features, nbits_a=4, mode=Qmodes.kernel_wise, **kwargs): + super(ActQ, self).__init__(in_features=in_features, nbits=nbits_a, mode=mode) + # print(self.alpha.shape, self.zero_point.shape) + def forward(self, x): + if self.alpha is None: + return x + + if self.training and self.init_state == 0: + # The init alpha for activation is very very important as the experimental results shows. + # Please select a init_rate for activation. 
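+            # What the one-time initialization below does (first training batch only):
+            #   1. if any input is clearly negative, mark the quantizer as signed and use the
+            #      symmetric range [Qn, Qp], otherwise the unsigned range [0, 2**nbits - 1];
+            #   2. set the step size alpha = 2 * mean(|x|) / sqrt(Qp) (the LSQ-style init);
+            #   3. move zero_point towards min(x) - alpha * Qn with a 0.9 / 0.1 running average.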
+ # self.alpha.data.copy_(x.max() / 2 ** (self.nbits - 1) * self.init_rate) + if x.min() < -1e-5: + self.signed.data.fill_(1) + if self.signed == 1: + Qn = -2 ** (self.nbits - 1) + Qp = 2 ** (self.nbits - 1) - 1 + else: + Qn = 0 + Qp = 2 ** self.nbits - 1 + self.alpha.data.copy_(2 * x.abs().mean() / math.sqrt(Qp)) + self.zero_point.data.copy_(self.zero_point.data * 0.9 + 0.1 * (torch.min(x.detach()) - self.alpha.data * Qn)) + self.init_state.fill_(1) + + if self.signed == 1: + Qn = -2 ** (self.nbits - 1) + Qp = 2 ** (self.nbits - 1) - 1 + else: + Qn = 0 + Qp = 2 ** self.nbits - 1 + + g = 1.0 / math.sqrt(x.numel() * Qp) + + # Method1: + zero_point = (self.zero_point.round() - self.zero_point).detach() + self.zero_point + alpha = grad_scale(self.alpha, g) + zero_point = grad_scale(zero_point, g) + # x = round_pass((x / alpha).clamp(Qn, Qp)) * alpha + if len(x.shape)==2: + alpha = alpha.unsqueeze(0) + zero_point = zero_point.unsqueeze(0) + elif len(x.shape)==4: + alpha = alpha.unsqueeze(0).unsqueeze(2).unsqueeze(3) + zero_point = zero_point.unsqueeze(0).unsqueeze(2).unsqueeze(3) + + x = round_pass((x / alpha + zero_point).clamp(Qn, Qp)) + x = (x - zero_point) * alpha + + return x diff --git a/models/q_vit/__init__.py b/models/q_vit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/q_vit/__pycache__/Quant.cpython-311.pyc b/models/q_vit/__pycache__/Quant.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..592c832865c369342dbb30d29eac5841af0447df Binary files /dev/null and b/models/q_vit/__pycache__/Quant.cpython-311.pyc differ diff --git a/models/q_vit/__pycache__/Quant.cpython-312.pyc b/models/q_vit/__pycache__/Quant.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84197598bfcda427747b60783825b0ecdd82795d Binary files /dev/null and b/models/q_vit/__pycache__/Quant.cpython-312.pyc differ diff --git a/models/q_vit/__pycache__/__init__.cpython-311.pyc b/models/q_vit/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58b9438291dfaa0963db7cb6ffb02b147cc8d93d Binary files /dev/null and b/models/q_vit/__pycache__/__init__.cpython-311.pyc differ diff --git a/models/q_vit/__pycache__/__init__.cpython-312.pyc b/models/q_vit/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6251533bbc7325014aa8bd7150f2ac5ad7e23c02 Binary files /dev/null and b/models/q_vit/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/q_vit/__pycache__/_quan_base.cpython-311.pyc b/models/q_vit/__pycache__/_quan_base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..667fb0e0d4dc0e5b9ba9186a6fe092c2a990d989 Binary files /dev/null and b/models/q_vit/__pycache__/_quan_base.cpython-311.pyc differ diff --git a/models/q_vit/__pycache__/_quan_base.cpython-312.pyc b/models/q_vit/__pycache__/_quan_base.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ae56b2934ebd691785cbc356b76ece706725cc8 Binary files /dev/null and b/models/q_vit/__pycache__/_quan_base.cpython-312.pyc differ diff --git a/models/q_vit/__pycache__/quant_vision_transformer.cpython-311.pyc b/models/q_vit/__pycache__/quant_vision_transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..def0dbd5f3536625a38e3e349ed936e5b06e65c1 Binary files /dev/null and 
b/models/q_vit/__pycache__/quant_vision_transformer.cpython-311.pyc differ diff --git a/models/q_vit/__pycache__/quant_vision_transformer.cpython-312.pyc b/models/q_vit/__pycache__/quant_vision_transformer.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c589cfd7444a3b26b8a3a86a29786f1e40be6308 Binary files /dev/null and b/models/q_vit/__pycache__/quant_vision_transformer.cpython-312.pyc differ diff --git a/models/q_vit/_quan_base.py b/models/q_vit/_quan_base.py new file mode 100644 index 0000000000000000000000000000000000000000..57317d2bf5f1f2cfd0406a486bb9bb649a87ad23 --- /dev/null +++ b/models/q_vit/_quan_base.py @@ -0,0 +1,208 @@ +""" + Quantized modules: the base class +""" +import torch +import torch.nn as nn +from torch.nn.parameter import Parameter + +import math +from enum import Enum + +__all__ = ['Qmodes', '_Conv2dQ', '_LinearQ', '_ActQ', + 'truncation', 'get_sparsity_mask', 'FunStopGradient', 'round_pass', 'grad_scale'] + + +class Qmodes(Enum): + layer_wise = 1 + kernel_wise = 2 + + +def grad_scale(x, scale): + y = x + y_grad = x * scale + return y.detach() - y_grad.detach() + y_grad + + +def get_sparsity_mask(param, sparsity): + bottomk, _ = torch.topk(param.abs().view(-1), int(sparsity * param.numel()), largest=False, sorted=True) + threshold = bottomk.data[-1] # This is the largest element from the group of elements that we prune away + return torch.gt(torch.abs(param), threshold).type(param.type()) + + +def round_pass(x): + y = x.round() + y_grad = x + return y.detach() - y_grad.detach() + y_grad + + +class FunStopGradient(torch.autograd.Function): + + @staticmethod + def forward(ctx, weight, stopGradientMask): + ctx.save_for_backward(stopGradientMask) + return weight + + @staticmethod + def backward(ctx, grad_outputs): + stopGradientMask, = ctx.saved_tensors + grad_inputs = grad_outputs * stopGradientMask + return grad_inputs, None + + +def log_shift(value_fp): + value_shift = 2 ** (torch.log2(value_fp).ceil()) + return value_shift + + +def clamp(input, min, max, inplace=False): + if inplace: + input.clamp_(min, max) + return input + return torch.clamp(input, min, max) + + +def get_quantized_range(num_bits, signed=True): + if signed: + n = 2 ** (num_bits - 1) + return -n, n - 1 + return 0, 2 ** num_bits - 1 + + +def linear_quantize(input, scale_factor, inplace=False): + if inplace: + input.mul_(scale_factor).round_() + return input + return torch.round(scale_factor * input) + + +def linear_quantize_clamp(input, scale_factor, clamp_min, clamp_max, inplace=False): + output = linear_quantize(input, scale_factor, inplace) + return clamp(output, clamp_min, clamp_max, inplace) + + +def linear_dequantize(input, scale_factor, inplace=False): + if inplace: + input.div_(scale_factor) + return input + return input / scale_factor + + +def truncation(fp_data, nbits=8): + il = torch.log2(torch.max(fp_data.max(), fp_data.min().abs())) + 1 + il = math.ceil(il - 1e-5) + qcode = nbits - il + scale_factor = 2 ** qcode + clamp_min, clamp_max = get_quantized_range(nbits, signed=True) + q_data = linear_quantize_clamp(fp_data, scale_factor, clamp_min, clamp_max) + q_data = linear_dequantize(q_data, scale_factor) + return q_data, qcode + + +def get_default_kwargs_q(kwargs_q, layer_type): + default = { + 'nbits': 4 + } + if isinstance(layer_type, _Conv2dQ): + default.update({ + 'mode': Qmodes.layer_wise}) + elif isinstance(layer_type, _LinearQ): + pass + elif isinstance(layer_type, _ActQ): + pass + # default.update({ + # 'signed': 'Auto'}) + else: + assert 
NotImplementedError + return + for k, v in default.items(): + if k not in kwargs_q: + kwargs_q[k] = v + return kwargs_q + + +class _Conv2dQ(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True, **kwargs_q): + super(_Conv2dQ, self).__init__(in_channels, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self) + self.nbits = kwargs_q['nbits'] + if self.nbits < 0: + self.register_parameter('alpha', None) + return + self.q_mode = kwargs_q['mode'] + if self.q_mode == Qmodes.kernel_wise: + self.alpha = Parameter(torch.Tensor(out_channels)) + else: # layer-wise quantization + self.alpha = Parameter(torch.Tensor(1)) + self.register_buffer('init_state', torch.zeros(1)) + + def add_param(self, param_k, param_v): + self.kwargs_q[param_k] = param_v + + def set_bit(self, nbits): + self.kwargs_q['nbits'] = nbits + + def extra_repr(self): + s_prefix = super(_Conv2dQ, self).extra_repr() + if self.alpha is None: + return '{}, fake'.format(s_prefix) + return '{}, {}'.format(s_prefix, self.kwargs_q) + + +class _LinearQ(nn.Linear): + def __init__(self, in_features, out_features, bias=True, **kwargs_q): + super(_LinearQ, self).__init__(in_features=in_features, out_features=out_features, bias=bias) + self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self) + self.nbits = kwargs_q['nbits'] + if self.nbits < 0: + self.register_parameter('alpha', None) + return + self.q_mode = kwargs_q['mode'] + self.alpha = Parameter(torch.Tensor(1)) + if self.q_mode == Qmodes.kernel_wise: + self.alpha = Parameter(torch.Tensor(out_features)) + self.register_buffer('init_state', torch.zeros(1)) + + def add_param(self, param_k, param_v): + self.kwargs_q[param_k] = param_v + + def extra_repr(self): + s_prefix = super(_LinearQ, self).extra_repr() + if self.alpha is None: + return '{}, fake'.format(s_prefix) + return '{}, {}'.format(s_prefix, self.kwargs_q) + + +class _ActQ(nn.Module): + def __init__(self, in_features, **kwargs_q): + super(_ActQ, self).__init__() + self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self) + self.nbits = kwargs_q['nbits'] + if self.nbits < 0: + self.register_parameter('alpha', None) + self.register_parameter('zero_point', None) + return + # self.signed = kwargs_q['signed'] + self.q_mode = kwargs_q['mode'] + self.alpha = Parameter(torch.Tensor(1)) + self.zero_point = Parameter(torch.Tensor([0])) + if self.q_mode == Qmodes.kernel_wise: + self.alpha = Parameter(torch.Tensor(in_features)) + self.zero_point = Parameter(torch.Tensor(in_features)) + torch.nn.init.zeros_(self.zero_point) + # self.zero_point = Parameter(torch.Tensor([0])) + self.register_buffer('init_state', torch.zeros(1)) + self.register_buffer('signed', torch.zeros(1)) + + def add_param(self, param_k, param_v): + self.kwargs_q[param_k] = param_v + + def set_bit(self, nbits): + self.kwargs_q['nbits'] = nbits + + def extra_repr(self): + # s_prefix = super(_ActQ, self).extra_repr() + if self.alpha is None: + return 'fake' + return '{}'.format(self.kwargs_q) diff --git a/models/q_vit/quant_vision_transformer.py b/models/q_vit/quant_vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4156d7d4938b75d413b73f86c24a334c59a80e0f --- /dev/null +++ b/models/q_vit/quant_vision_transformer.py @@ -0,0 +1,527 @@ +import math +import logging +from functools import partial +from collections import OrderedDict + +import 
torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import load_pretrained +from timm.models.layers import Mlp +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.resnet import resnet26d, resnet50d +from timm.models.registry import register_model + +import numpy as np +from .Quant import * +from ._quan_base import * + + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # patch models (my experiments) + 'vit_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', + ), + + # patch models (weights ported from official Google JAX impl) + 'vit_base_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), + 'vit_base_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_base_patch16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + 'vit_base_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + 'vit_large_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch32_224': _cfg( + url='', # no official model weights for this combo, only for in21k + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch16_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + 'vit_large_patch32_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), + + # patch models, imagenet21k (weights ported from official Google JAX impl) + 'vit_base_patch16_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_base_patch32_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_large_patch16_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 
'vit_large_patch32_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + 'vit_huge_patch14_224_in21k': _cfg( + url='', # FIXME I have weights for this but > 2GB limit for github release binaries + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + + # hybrid models (weights ported from official Google JAX impl) + 'vit_base_resnet50_224_in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_resnet50_384': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), + + # hybrid models (my experiments) + 'vit_small_resnet26d_224': _cfg(), + 'vit_small_resnet50d_s3_224': _cfg(), + 'vit_base_resnet26d_224': _cfg(), + 'vit_base_resnet50d_224': _cfg(), + + # deit models (FB weights) + 'vit_deit_tiny_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), + 'vit_deit_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), + 'vit_deit_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), + 'vit_deit_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_deit_tiny_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'), + 'vit_deit_small_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'), + 'vit_deit_base_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ), + 'vit_deit_base_distilled_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', + input_size=(3, 384, 384), crop_pct=1.0), +} + +class Q_Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, nbits, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + drop_probs = to_2tuple(drop) + + self.fc1 = LinearQ(in_features, hidden_features, nbits_w=nbits, mode=Qmodes.kernel_wise) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.fc2 = LinearQ(hidden_features, out_features, nbits_w=nbits, mode=Qmodes.kernel_wise) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + # print(torch.max(x), torch.min(x)) + x = self.act(x) + + x = torch.clip(x, -10., 10.) 
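+        # bound the GELU output before the quantized fc2; presumably this keeps the dynamic
+        # range seen by fc2's activation quantizer (the ActQ inside LinearQ) limited, and the
+        # [-10, 10] limits appear to be empirical rather than derived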
+ # print(torch.clip(x, -10., 10.)) + x = self.drop1(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class Q_Attention(nn.Module): + + def __init__(self, nbits, dim, num_heads=8, quantize_attn=True, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.quantize_attn = quantize_attn + + self.norm_q = nn.LayerNorm(head_dim) + self.norm_k = nn.LayerNorm(head_dim) + + + if self.quantize_attn: + + self.qkv = LinearQ(dim, dim * 3, bias=qkv_bias, nbits_w=nbits, mode=Qmodes.kernel_wise) + + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = LinearQ(dim, dim, nbits_w=nbits, mode=Qmodes.kernel_wise) + self.q_act = ActQ(nbits_a=nbits, in_features=self.num_heads) + self.k_act = ActQ(nbits_a=nbits, in_features=self.num_heads) + self.v_act = ActQ(nbits_a=nbits, in_features=self.num_heads) + self.attn_act = ActQ(nbits_a=nbits, in_features=self.num_heads) + else: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.q_act = ActQ(nbits_a=nbits, in_features=self.num_heads) + self.k_act = ActQ(nbits_a=nbits, in_features=self.num_heads) + self.v_act = ActQ(nbits_a=nbits, in_features=self.num_heads) + self.attn_act = ActQ(nbits_a=nbits, in_features=self.num_heads) + + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + q = self.norm_q(q) + k = self.norm_k(k) + + q = self.q_act(q) + k = self.k_act(k) + v = self.v_act(v) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = self.attn_act(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Q_Block(nn.Module): + + def __init__(self, nbits, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Q_Attention(nbits, dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Q_Mlp(nbits=nbits, in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + +class Q_PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, nbits=4, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = Conv2dQ(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + # nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + +class lowbit_VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, nbits, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=Q_PatchEmbed, norm_layer=None, + act_layer=None, weight_init=''): + """ + Args: + nbits: nbits + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + distilled (bool): model includes a distillation token and head as in DeiT models + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.patch_embed = embed_layer( + nbits=nbits, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_drop = 
nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + Q_Block( + nbits=nbits, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # Representation layer + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + # Classifier head(s) + self.head = LinearQ(self.num_features, num_classes, nbits_w=8) if num_classes > 0 else nn.Identity() + # nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = LinearQ(self.embed_dim, self.num_classes, nbits_w=8) if num_classes > 0 else nn.Identity() + # self.head = LinearQ(self.embed_dim, self.num_classes, nbits_w=8) if num_classes > 0 else nn.Identity() + # nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. + trunc_normal_(self.pos_embed, std=.02) + if self.dist_token is not None: + trunc_normal_(self.dist_token, std=.02) + if mode.startswith('jax'): + # leave cls token as zeros to match jax impl + named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) + else: + trunc_normal_(self.cls_token, std=.02) + self.apply(_init_vit_weights) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + _init_vit_weights(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def get_classifier(self): + if self.dist_token is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + if self.num_tokens == 2: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) + x = self.norm(x) + if self.dist_token is None: + return self.pre_logits(x[:, 0]) + else: + return x[:, 0], x[:, 1] + + def forward(self, x): + x = self.forward_features(x) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 + else: + x = self.head(x) + return x + +def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., 
jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). + * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) + else: + if jax_impl: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + else: + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + +def resize_pos_embed(posemb, posemb_new): + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if True: + posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + ntok_new -= 1 + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + gs_new = int(math.sqrt(ntok_new)) + _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed(v, model.pos_embed) + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs): + default_cfg = default_cfgs[variant] + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-1] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != default_num_classes: + # Remove representation layer if fine-tuning. This may not always be the desired action, + # but I feel better than doing nothing by default for fine-tuning. 
Perhaps a better interface? + _logger.warning("Removing representation layer for fine-tuning.") + repr_size = None + + model_cls = DistilledVisionTransformer if distilled else VisionTransformer + model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs) + model.default_cfg = default_cfg + + if pretrained: + load_pretrained( + model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3), + filter_fn=partial(checkpoint_filter_fn, model=model)) + return model + + +@register_model +def fourbits_deit_small_patch16_224(pretrained=False, **kwargs): + model = lowbit_VisionTransformer( + nbits=4, patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + torch.hub.load_state_dict_from_url( + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', + map_location="cpu", check_hash=True + ) + return model + +@register_model +def threebits_deit_small_patch16_224(pretrained=False, **kwargs): + model = lowbit_VisionTransformer( + nbits=3, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + torch.hub.load_state_dict_from_url( + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', + map_location="cpu", check_hash=True + ) + return model + +@register_model +def twobits_deit_small_patch16_224(pretrained=False, **kwargs): + model = lowbit_VisionTransformer( + nbits=2, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + model.default_cfg = _cfg() + if pretrained: + torch.hub.load_state_dict_from_url( + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', + map_location="cpu", check_hash=True + ) + return model diff --git a/models/qk_model_v1_1003.py b/models/qk_model_v1_1003.py new file mode 100644 index 0000000000000000000000000000000000000000..017053d351186a7610a6052424d9afee6e72b0e8 --- /dev/null +++ b/models/qk_model_v1_1003.py @@ -0,0 +1,426 @@ +import torch +import torch.nn as nn +from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode +from timm.models.layers import to_2tuple, trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +from functools import partial +from timm.models import create_model + +__all__ = ['QKFormer'] + +class MLP(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.mlp1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1) + self.mlp1_bn = nn.BatchNorm2d(hidden_features) + self.mlp1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.mlp2_conv = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1) + self.mlp2_bn = nn.BatchNorm2d(out_features) + self.mlp2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.c_hidden = hidden_features + self.c_output = out_features + + def forward(self, x): + T, B, C, H, W = x.shape + + x = self.mlp1_conv(x.flatten(0, 1)) + x = self.mlp1_bn(x).reshape(T, B, self.c_hidden, H, W) + x = self.mlp1_lif(x) + + x = 
self.mlp2_conv(x.flatten(0, 1)) + x = self.mlp2_bn(x).reshape(T, B, C, H, W) + x = self.mlp2_lif(x) + return x + +class Token_QK_Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + + self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.q_bn = nn.BatchNorm1d(dim) + self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.k_bn = nn.BatchNorm1d(dim) + self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='torch') + + self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1) + self.proj_bn = nn.BatchNorm1d(dim) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + def forward(self, x): + T, B, C, H, W = x.shape + + x = x.flatten(3) + T, B, C, N = x.shape + x_for_qkv = x.flatten(0, 1) + + q_conv_out = self.q_conv(x_for_qkv) + q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N) + q_conv_out = self.q_lif(q_conv_out) + q = q_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N) + + k_conv_out = self.k_conv(x_for_qkv) + k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N) + k_conv_out = self.k_lif(k_conv_out) + k = k_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N) + + q = torch.sum(q, dim=3, keepdim=True) + attn = self.attn_lif(q) + x = torch.mul(attn, k) + + x = x.flatten(2, 3) + x = self.proj_bn(self.proj_conv(x.flatten(0, 1))).reshape(T, B, C, H, W) + # print(f"proj_conv out shape: {x.shape}") + x = self.proj_lif(x) + return x + +class Spiking_Self_Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
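+        # Spiking self-attention: Q, K and V each come from a 1x1 Conv1d + BatchNorm1d followed
+        # by a LIF neuron, so all three attention operands are spike tensors. The forward pass
+        # computes K^T @ V first and then Q @ (K^T V), scaled by a fixed 0.125, i.e. a
+        # linear-attention ordering that needs no softmax.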
+ self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = 0.125 + self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.q_bn = nn.BatchNorm1d(dim) + self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.k_bn = nn.BatchNorm1d(dim) + self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.v_bn = nn.BatchNorm1d(dim) + self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='torch') + + self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1) + self.proj_bn = nn.BatchNorm1d(dim) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.qkv_mp = nn.MaxPool1d(4) + + def forward(self, x): + T, B, C, H, W = x.shape + + x = x.flatten(3) + T, B, C, N = x.shape + x_for_qkv = x.flatten(0, 1) + + q_conv_out = self.q_conv(x_for_qkv) + q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous() + q_conv_out = self.q_lif(q_conv_out) + q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, + 4).contiguous() + + k_conv_out = self.k_conv(x_for_qkv) + k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous() + k_conv_out = self.k_lif(k_conv_out) + k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, + 4).contiguous() + + v_conv_out = self.v_conv(x_for_qkv) + v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous() + v_conv_out = self.v_lif(v_conv_out) + v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, + 4).contiguous() + + x = k.transpose(-2, -1) @ v + x = (q @ x) * self.scale + + x = x.transpose(3, 4).reshape(T, B, C, N).contiguous() + x = self.attn_lif(x) + x = x.flatten(0, 1) + x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, H, W) + return x + +class TokenSpikingTransformer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1): + super().__init__() + self.tssa = Token_QK_Attention(dim, num_heads) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features= dim, hidden_features=mlp_hidden_dim, drop=drop) + + def forward(self, x): + + x = x + self.tssa(x) + x = x + self.mlp(x) + + return x + +class SpikingTransformer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1): + super().__init__() + self.ssa = Spiking_Self_Attention(dim, num_heads) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features= dim, hidden_features=mlp_hidden_dim, drop=drop) + + def forward(self, x): + + x = x + self.ssa(x) + x = x + self.mlp(x) + + return x + +class PatchEmbedInit(nn.Module): + def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256): + super().__init__() + self.image_size = [img_size_h, img_size_w] + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.C = in_channels + self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1] + self.num_patches = self.H * self.W + + 
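+ # Convolutional stem: four Conv2d -> BN -> LIF stages widen the channels
+ # (in_channels -> embed_dims/8 -> /4 -> /2 -> embed_dims) while three
+ # stride-2 max-pools shrink the feature map by 8x overall. A stride-4 1x1
+ # convolution taken from the embed_dims/4 feature map forms the residual
+ # branch that is added back after the last stage (see the shortcut in forward()).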
self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False) + self.proj_bn = nn.BatchNorm2d(embed_dims // 8) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.proj1_conv = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False) + self.proj1_bn = nn.BatchNorm2d(embed_dims // 4) + self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.proj2_conv = nn.Conv2d(embed_dims//4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False) + self.proj2_bn = nn.BatchNorm2d(embed_dims // 2) + self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.proj3_conv = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.proj3_bn = nn.BatchNorm2d(embed_dims) + self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj3_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.proj_res_conv = nn.Conv2d(embed_dims // 4, embed_dims, kernel_size=1, stride=4, padding=0, bias=False) + self.proj_res_bn = nn.BatchNorm2d(embed_dims) + self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + + def forward(self, x): + T, B, C, H, W = x.shape + # Downsampling + Res + # x_feat = x.flatten(0, 1) + x = self.proj_conv(x.flatten(0, 1)) + x = self.proj_bn(x).reshape(T, B, -1, H, W) + x = self.proj_lif(x).flatten(0, 1).contiguous() + + x = self.proj1_conv(x) + x = self.proj1_bn(x) + x = self.maxpool1(x) + _, _, H1, W1 = x.shape + x = x.reshape(T, B, -1, H1, W1).contiguous() + x = self.proj1_lif(x).flatten(0, 1).contiguous() + + x_feat = x + x = self.proj2_conv(x) + x = self.proj2_bn(x) + x = self.maxpool2(x) + _, _, H2, W2 = x.shape + x = x.reshape(T, B, -1, H2, W2).contiguous() + x = self.proj2_lif(x).flatten(0, 1).contiguous() + + x = self.proj3_conv(x) + x = self.proj3_bn(x) + x = self.maxpool3(x) + _, _, H3, W3 = x.shape + x = x.reshape(T, B, -1, H3, W3).contiguous() + x = self.proj3_lif(x) + + x_feat = self.proj_res_conv(x_feat) + x_feat = self.proj_res_bn(x_feat) + _, _, Hres, Wres = x_feat.shape + x_feat = x_feat.reshape(T, B, -1, Hres, Wres).contiguous() + x_feat = self.proj_res_lif(x_feat) + x = x + x_feat # shortcut + + return x + +class PatchEmbeddingStage(nn.Module): + def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256): + super().__init__() + self.image_size = [img_size_h, img_size_w] + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.C = in_channels + self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1] + self.num_patches = self.H * self.W + + self.proj_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.proj_bn = nn.BatchNorm2d(embed_dims) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + self.proj4_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.proj4_bn = nn.BatchNorm2d(embed_dims) + self.proj4_maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj4_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, 
backend='torch') + + self.proj_res_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=1, stride=2, padding=0, bias=False) + self.proj_res_bn = nn.BatchNorm2d(embed_dims) + self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='torch') + + def forward(self, x): + T, B, C, H, W = x.shape + # Downsampling + Res + + x = x.flatten(0, 1).contiguous() + x_feat = x + + x = self.proj_conv(x) + x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous() + x = self.proj_lif(x).flatten(0, 1).contiguous() + + x = self.proj4_conv(x) + x = self.proj4_bn(x) + x = self.proj4_maxpool(x) + _, _, H4, W4 = x.shape + x = x.reshape(T, B, -1, H4, W4).contiguous() + x = self.proj4_lif(x) + + x_feat = self.proj_res_conv(x_feat) + x_feat = self.proj_res_bn(x_feat) + _, _, Hres, Wres = x_feat.shape + x_feat = x_feat.reshape(T, B, -1, Hres, Wres).contiguous() + x_feat = self.proj_res_lif(x_feat) + + x = x + x_feat # shortcut + + return x + + +class vit_snn(nn.Module): + def __init__(self, + img_size_h=128, img_size_w=128, patch_size=16, in_channels=2, num_classes=11, + embed_dims=[64, 128, 256], num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[6, 8, 6], sr_ratios=[8, 4, 2], T=4, pretrained_cfg=None, in_chans = 3, no_weight_decay = None + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.T = T + num_heads = [16, 16, 16] + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)] # stochastic depth decay rule + + # + patch_embed1 = PatchEmbedInit(img_size_h=img_size_h, + img_size_w=img_size_w, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims // 2) + + stage1 = nn.ModuleList([TokenSpikingTransformer( + dim=embed_dims // 2, num_heads=num_heads[0], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j], + norm_layer=norm_layer, sr_ratio=sr_ratios) + for j in range(1)]) + + + patch_embed2 = PatchEmbeddingStage(img_size_h=img_size_h, + img_size_w=img_size_w, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims) + + + stage2 = nn.ModuleList([SpikingTransformer( + dim=embed_dims, num_heads=num_heads[1], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j], + norm_layer=norm_layer, sr_ratio=sr_ratios) + for j in range(1)]) + + + setattr(self, f"patch_embed1", patch_embed1) + setattr(self, f"stage1", stage1) + setattr(self, f"patch_embed2", patch_embed2) + setattr(self, f"stage2", stage2) + + + # classification head + self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity() + self.apply(self._init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pose_embed'} + + @torch.jit.ignore + def _get_pos_embed(self, pos_embed, patch_embed, H, W): + return None + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + stage1 = getattr(self, f"stage1") + patch_embed1 = getattr(self, f"patch_embed1") + stage2 = getattr(self, f"stage2") + patch_embed2 = getattr(self, f"patch_embed2") + + x = patch_embed1(x) + for blk in stage1: + x = blk(x) + + x = patch_embed2(x) + for blk 
in stage2: + x = blk(x) + + return x.flatten(3).mean(3) + + def forward(self, x): + x = x.permute(1, 0, 2, 3, 4) # [T, N, 2, *, *] + x = self.forward_features(x) + x = self.head(x.mean(0)) + return x + + +@register_model +def QKFormer_1003(pretrained=False, **kwargs): + model = vit_snn( + patch_size=16, embed_dims=256, num_heads=16, mlp_ratios=1, + in_channels=2, num_classes=101, qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=4, sr_ratios=1, + **kwargs + ) + model.default_cfg = _cfg() + return model + + +from timm.models import create_model + +if __name__ == '__main__': + x = torch.randn(1, 1, 2, 128, 128).cuda() + model = create_model( + 'QKFormer_1003', + pretrained=False, + drop_rate=0, + drop_path_rate=0.1, + drop_block_rate=None, + ).cuda() + model.eval() + + from torchinfo import summary + summary(model, input_size=(1, 1, 2, 128, 128)) + y = model(x) + print(y.shape) + print('Test Good!') diff --git a/models/qk_model_with_delay/__init__.py b/models/qk_model_with_delay/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/qk_model_with_delay/__pycache__/__init__.cpython-311.pyc b/models/qk_model_with_delay/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a6ed8e494d6f2f6a21ec2a6ffda0954af636058 Binary files /dev/null and b/models/qk_model_with_delay/__pycache__/__init__.cpython-311.pyc differ diff --git a/models/qk_model_with_delay/__pycache__/delay_synaptic_func_inter.cpython-311.pyc b/models/qk_model_with_delay/__pycache__/delay_synaptic_func_inter.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd80c9d5d42c0694a42ba586c5d285a2e3bf20ca Binary files /dev/null and b/models/qk_model_with_delay/__pycache__/delay_synaptic_func_inter.cpython-311.pyc differ diff --git a/models/qk_model_with_delay/__pycache__/delay_synaptic_inter_model.cpython-311.pyc b/models/qk_model_with_delay/__pycache__/delay_synaptic_inter_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d88d8991a93c4901143316b36217ee8816eeffc Binary files /dev/null and b/models/qk_model_with_delay/__pycache__/delay_synaptic_inter_model.cpython-311.pyc differ diff --git a/models/qk_model_with_delay/delay_synaptic_func_inter.py b/models/qk_model_with_delay/delay_synaptic_func_inter.py new file mode 100644 index 0000000000000000000000000000000000000000..8f289448c837e41b791916e4205ed7451dba0e32 --- /dev/null +++ b/models/qk_model_with_delay/delay_synaptic_func_inter.py @@ -0,0 +1,169 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +def set_sigma_for_DCLS(model, s): + for name, module in model.named_modules(): + if module.__class__.__name__ == 'DelayConv': + if hasattr(module, 'sigma'): + module.sigma = s + print('Set sigma to ',s) + +class DropoutNd(nn.Module): + def __init__(self, p: float = 0.5, tie=True, transposed=True): + """ + tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) + """ + super().__init__() + if p < 0 or p >= 1: + raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) + self.p = p + self.tie = tie + self.transposed = transposed + self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) + + def forward(self, X): + """X: (batch, dim, lengths...).""" + if self.training: + if not self.transposed: X = rearrange(X, 'b ... 
d -> b d ...') + # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow because of CPU -> GPU copying + mask_shape = X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape + # mask = self.binomial.sample(mask_shape) + mask = torch.rand(*mask_shape, device=X.device) < 1. - self.p + X = X * mask * (1.0 / (1 - self.p)) + if not self.transposed: X = rearrange(X, 'b d ... -> b ... d') + return X + return X + +class DelayConv(nn.Module): + def __init__( + self, + in_c, + k, + dropout=0.0, + n_delay=1, + dilation=1, + kernel_type='triangle_r_temp' + ): + super().__init__() + self.C = in_c # 输入和输出通道数 + self.win_len = k + self.dilation = dilation + self.n_delay = n_delay + self.kernel_type = kernel_type + + self.t = torch.arange(self.win_len).float().unsqueeze(0) # [1, k] + self.sigma = self.win_len // 2 + + self.delay_kernel = None + self.bump = None + + # ========== 修改:d 形状 -> [C_out, C_in, n_delay] ========== + d = torch.rand(self.C, self.C, self.n_delay) + with torch.no_grad(): + for co in range(self.C): + for ci in range(self.C): + d[co, ci, :] = torch.randperm(self.win_len - 2)[:self.n_delay] + 1 + self.register("d", d, lr=1e-2) + + # 初始化权重: [C_out, C_in, k] + weight = torch.ones([self.C, self.C, k]) + with torch.no_grad(): + for co in range(self.C): # output channel + for ci in range(self.C): # input channel + for i in range(k - 2, -1, -1): + weight[co, ci, i] = weight[co, ci, i + 1] / 2 + + self.weight = nn.Parameter(weight) + + self.dropout = nn.Dropout(dropout / 5) if dropout > 0.0 else nn.Identity() + + def register(self, name, tensor, lr=None): + """注册可训练或固定参数""" + if lr == 0.0: + self.register_buffer(name, tensor) + else: + self.register_parameter(name, nn.Parameter(tensor)) + optim = {"weight_decay": 0} + if lr is not None: + optim["lr"] = lr + setattr(getattr(self, name), "_optim", optim) + + def update_kernel(self, device): + """ + 输出 delay kernel: shape [C_out, C_in, k] + """ + t = self.t.to(device).view(1, 1, 1, -1) # [1,1,1,k] + d = self.d.to(device) # [C_out, C_in, n_delay] + + # ---------- 计算 bump ---------- + if self.kernel_type == 'gauss': + bump = torch.exp(-0.5 * ((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma) ** 2) + bump = (bump - 1e-3).relu() + 1e-3 + bump = bump / (bump.sum(dim=-1, keepdim=True) + 1e-7) + + elif self.kernel_type == 'triangle': + bump = torch.relu(1 - torch.abs((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma)) + bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7) + + elif self.kernel_type == 'triangle_r': + d_int = (d.round() - d).detach() + d + bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma)) + bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7) + + elif self.kernel_type == 'triangle_r_temp': + scale = min(1.0, 1.0 / self.sigma) + d_int = (d.round() - d).detach() * scale + d + bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma)) + bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7) # [C_out, C_in, n_delay, k] + # ------ 在eval模式硬化bump ------ + if not self.training: + max_idx = bump.argmax(dim=-1, keepdim=True) # 找最大值索引 + hard_mask = torch.zeros_like(bump) + hard_mask.scatter_(-1, max_idx, 1.0) + bump = bump * hard_mask + # -------------------------------- + else: + raise ValueError(f"Unknown kernel_type: {self.kernel_type}") + + # bump: [C_out, C_in, n_delay, k] + self.bump = bump.detach().clone().to(device) + + # ---------- 沿 n_delay 维度求和: [C_out, C_in, k] ---------- + bump_sum = 
bump.sum(dim=2) + + # ---------- 生成最终卷积核 ---------- + # weight: [C_out, C_in, k] + self.delay_kernel = (self.weight * bump_sum).to(device) # [C_out, C_in, k] + + def forward(self, x): + """ + x: (T, B, N, C) + return: (T*B, C, N) + """ + # 调整维度 + x = x.permute(0, 1, 3, 2).contiguous() # (T, B, N, C) + T, B, N, C = x.shape + assert C == self.C, f"Input channel mismatch: {C} vs {self.C}" + x = x.permute(1, 2, 3, 0).contiguous() # (B, N, C, T) + + # 合并 B*N 作为 batch + x_reshaped = x.view(B * N, C, T) # (B*N, C, T) + device = x.device + + # 更新 kernel + self.update_kernel(device) # -> [C_out, C_in, k] + kernel = self.delay_kernel + + # padding + pad_left = (self.win_len - 1) * self.dilation + x_padded = F.pad(x_reshaped, (pad_left, 0)) # (B*N, C, T+pad) + + # 全通道卷积: groups=1 (跨通道交互) + y = F.conv1d(x_padded, kernel, stride=1, dilation=self.dilation, groups=1) # (B*N, C, T) + + # 还原到原始形状 + y = y.view(B, N, C, T).permute(3, 0, 2, 1).contiguous().view(-1, C, N) # (T*B, C, N) + + return self.dropout(y) \ No newline at end of file diff --git a/models/qk_model_with_delay/delay_synaptic_inter_model.py b/models/qk_model_with_delay/delay_synaptic_inter_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e2284479492c46f66d22a78866d69a706eb1201c --- /dev/null +++ b/models/qk_model_with_delay/delay_synaptic_inter_model.py @@ -0,0 +1,459 @@ +import torch +import torch.nn as nn +from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode +from timm.models.layers import to_2tuple, trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +from functools import partial +from timm.models import create_model +from .delay_synaptic_func_inter import DelayConv + +__all__ = ['delay_QKFormer'] + +class MLP(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.mlp1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1) + self.mlp1_bn = nn.BatchNorm2d(hidden_features) + self.mlp1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.mlp2_conv = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1) + self.mlp2_bn = nn.BatchNorm2d(out_features) + self.mlp2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.c_hidden = hidden_features + self.c_output = out_features + + def forward(self, x): + T, B, C, H, W = x.shape + + x = self.mlp1_conv(x.flatten(0, 1)) + x = self.mlp1_bn(x).reshape(T, B, self.c_hidden, H, W) + x = self.mlp1_lif(x) + + x = self.mlp2_conv(x.flatten(0, 1)) + x = self.mlp2_bn(x).reshape(T, B, C, H, W) + x = self.mlp2_lif(x) + return x + +class Token_QK_Attention(nn.Module): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + sr_ratio=1, + k=16): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
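+ # Delay-augmented token attention: compared with the baseline QKFormer block,
+ # the 1x1 Conv1d that produced K is swapped for DelayConv, which mixes the
+ # channels through learnable per-synapse delays applied along the T time
+ # steps (causal window of size k). The Q path and the output projection stay
+ # as pointwise Conv1d + BN + LIF.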
+ + self.dim = dim + self.num_heads = num_heads + + self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.q_bn = nn.BatchNorm1d(dim) + self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + # self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.k_proj_delay = DelayConv(in_c=self.dim, k=k) + self.k_bn = nn.BatchNorm1d(dim) + self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy') + + self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1) + self.proj_bn = nn.BatchNorm1d(dim) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + def forward(self, x): + T, B, C, H, W = x.shape + + x = x.flatten(3) + T, B, C, N = x.shape + x_for_qkv = x.flatten(0, 1) + + q_conv_out = self.q_conv(x_for_qkv) + q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N) + q_conv_out = self.q_lif(q_conv_out) + q = q_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N) + + # k_conv_out = self.k_conv(x_for_qkv) + k_conv_out = self.k_proj_delay(x_for_qkv.reshape(T,B,C,N)) + k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N) + k_conv_out = self.k_lif(k_conv_out) + k = k_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N) + + q = torch.sum(q, dim=3, keepdim=True) + attn = self.attn_lif(q) + x = torch.mul(attn, k) + + x = x.flatten(2, 3) + x = self.proj_bn(self.proj_conv(x.flatten(0, 1))).reshape(T, B, C, H, W) + x = self.proj_lif(x) + return x + +class Spiking_Self_Attention(nn.Module): + def __init__(self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + sr_ratio=1, + k=16): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
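+ # Delay-augmented self-attention for stage 2: both the K and the V projections
+ # are DelayConv modules (learnable temporal delays over the T steps) instead
+ # of the baseline 1x1 Conv1d, while Q keeps the pointwise convolution. The
+ # attention itself is unchanged: (K^T @ V) first, then Q @ (K^T @ V) with a
+ # fixed 0.125 scale.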
+ self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = 0.125 + self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.q_bn = nn.BatchNorm1d(dim) + self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + # self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.k_proj_delay = DelayConv(in_c=self.dim, k=k) + self.k_bn = nn.BatchNorm1d(dim) + self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + # self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.v_proj_delay = DelayConv(in_c=self.dim, k=k) + self.v_bn = nn.BatchNorm1d(dim) + self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy') + + self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1) + self.proj_bn = nn.BatchNorm1d(dim) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.qkv_mp = nn.MaxPool1d(4) + + def forward(self, x): + T, B, C, H, W = x.shape + + x = x.flatten(3) + T, B, C, N = x.shape + x_for_qkv = x.flatten(0, 1) + + q_conv_out = self.q_conv(x_for_qkv) + q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N).contiguous() + q_conv_out = self.q_lif(q_conv_out) + q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, + 4).contiguous() + + k_conv_out = self.k_proj_delay(x_for_qkv.reshape(T,B,C,N)) + k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N).contiguous() + k_conv_out = self.k_lif(k_conv_out) + k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, + 4).contiguous() + + v_conv_out = self.v_proj_delay(x_for_qkv.reshape(T,B,C,N)) + v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, N).contiguous() + v_conv_out = self.v_lif(v_conv_out) + v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, + 4).contiguous() + + x = k.transpose(-2, -1) @ v + x = (q @ x) * self.scale + + x = x.transpose(3, 4).reshape(T, B, C, N).contiguous() + x = self.attn_lif(x) + x = x.flatten(0, 1) + x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T, B, C, H, W) + return x + +class TokenSpikingTransformer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1): + super().__init__() + self.tssa = Token_QK_Attention(dim, num_heads) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features= dim, hidden_features=mlp_hidden_dim, drop=drop) + + def forward(self, x): + + x = x + self.tssa(x) + x = x + self.mlp(x) + + return x + +class SpikingTransformer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1): + super().__init__() + self.ssa = Spiking_Self_Attention(dim, num_heads) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features= dim, hidden_features=mlp_hidden_dim, drop=drop) + + def forward(self, x): + + x = x + self.ssa(x) + x = x + self.mlp(x) + + return x + +class PatchEmbedInit(nn.Module): + def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256): + super().__init__() + self.image_size = [img_size_h, img_size_w] + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.C = 
in_channels + self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1] + self.num_patches = self.H * self.W + + self.proj_conv = nn.Conv2d(in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False) + self.proj_bn = nn.BatchNorm2d(embed_dims // 8) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj1_conv = nn.Conv2d(embed_dims // 8, embed_dims // 4, kernel_size=3, stride=1, padding=1, bias=False) + self.proj1_bn = nn.BatchNorm2d(embed_dims // 4) + self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj2_conv = nn.Conv2d(embed_dims//4, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False) + self.proj2_bn = nn.BatchNorm2d(embed_dims // 2) + self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj3_conv = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.proj3_bn = nn.BatchNorm2d(embed_dims) + self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj3_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj_res_conv = nn.Conv2d(embed_dims // 4, embed_dims, kernel_size=1, stride=4, padding=0, bias=False) + self.proj_res_bn = nn.BatchNorm2d(embed_dims) + self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + + def forward(self, x): + T, B, C, H, W = x.shape + # Downsampling + Res + # x_feat = x.flatten(0, 1) + x = self.proj_conv(x.flatten(0, 1)) + x = self.proj_bn(x).reshape(T, B, -1, H, W) + x = self.proj_lif(x).flatten(0, 1).contiguous() + + x = self.proj1_conv(x) + x = self.proj1_bn(x) + x = self.maxpool1(x) + _, _, H1, W1 = x.shape + x = x.reshape(T, B, -1, H1, W1).contiguous() + x = self.proj1_lif(x).flatten(0, 1).contiguous() + + x_feat = x + x = self.proj2_conv(x) + x = self.proj2_bn(x) + x = self.maxpool2(x) + _, _, H2, W2 = x.shape + x = x.reshape(T, B, -1, H2, W2).contiguous() + x = self.proj2_lif(x).flatten(0, 1).contiguous() + + x = self.proj3_conv(x) + x = self.proj3_bn(x) + x = self.maxpool3(x) + _, _, H3, W3 = x.shape + x = x.reshape(T, B, -1, H3, W3).contiguous() + x = self.proj3_lif(x) + + x_feat = self.proj_res_conv(x_feat) + x_feat = self.proj_res_bn(x_feat) + _, _, Hres, Wres = x_feat.shape + x_feat = x_feat.reshape(T, B, -1, Hres, Wres).contiguous() + x_feat = self.proj_res_lif(x_feat) + x = x + x_feat # shortcut + + return x + +class PatchEmbeddingStage(nn.Module): + def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256): + super().__init__() + self.image_size = [img_size_h, img_size_w] + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.C = in_channels + self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1] + self.num_patches = self.H * self.W + + self.proj_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.proj_bn = nn.BatchNorm2d(embed_dims) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj4_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.proj4_bn = nn.BatchNorm2d(embed_dims) + self.proj4_maxpool = 
torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj4_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj_res_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=1, stride=2, padding=0, bias=False) + self.proj_res_bn = nn.BatchNorm2d(embed_dims) + self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + def forward(self, x): + T, B, C, H, W = x.shape + # Downsampling + Res + + x = x.flatten(0, 1).contiguous() + x_feat = x + + x = self.proj_conv(x) + x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous() + x = self.proj_lif(x).flatten(0, 1).contiguous() + + x = self.proj4_conv(x) + x = self.proj4_bn(x) + x = self.proj4_maxpool(x) + _, _, H4, W4 = x.shape + x = x.reshape(T, B, -1, H4, W4).contiguous() + x = self.proj4_lif(x) + + x_feat = self.proj_res_conv(x_feat) + x_feat = self.proj_res_bn(x_feat) + _, _, Hres, Wres = x_feat.shape + x_feat = x_feat.reshape(T, B, -1, Hres, Wres).contiguous() + x_feat = self.proj_res_lif(x_feat) + + x = x + x_feat # shortcut + + return x + + +class vit_snn(nn.Module): + def __init__(self, + img_size_h=128, img_size_w=128, patch_size=16, in_channels=2, num_classes=11, + embed_dims=[64, 128, 256], num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[6, 8, 6], sr_ratios=[8, 4, 2], T=4, pretrained_cfg=None, in_chans = 3, no_weight_decay = None + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.T = T + num_heads = [16, 16, 16] + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)] # stochastic depth decay rule + + # + patch_embed1 = PatchEmbedInit(img_size_h=img_size_h, + img_size_w=img_size_w, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims // 2) + + stage1 = nn.ModuleList([TokenSpikingTransformer( + dim=embed_dims // 2, num_heads=num_heads[0], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j], + norm_layer=norm_layer, sr_ratio=sr_ratios) + for j in range(1)]) + + + patch_embed2 = PatchEmbeddingStage(img_size_h=img_size_h, + img_size_w=img_size_w, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims) + + + stage2 = nn.ModuleList([SpikingTransformer( + dim=embed_dims, num_heads=num_heads[1], mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j], + norm_layer=norm_layer, sr_ratio=sr_ratios) + for j in range(1)]) + + + setattr(self, f"patch_embed1", patch_embed1) + setattr(self, f"stage1", stage1) + setattr(self, f"patch_embed2", patch_embed2) + setattr(self, f"stage2", stage2) + + + # classification head + self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity() + self.apply(self._init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pose_embed'} + + @torch.jit.ignore + def _get_pos_embed(self, pos_embed, patch_embed, H, W): + return None + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + stage1 = getattr(self, f"stage1") + patch_embed1 = getattr(self, f"patch_embed1") + stage2 = getattr(self, 
f"stage2") + patch_embed2 = getattr(self, f"patch_embed2") + + x = patch_embed1(x) + for blk in stage1: + x = blk(x) + + x = patch_embed2(x) + for blk in stage2: + x = blk(x) + + return x.flatten(3).mean(3) + + def forward(self, x): + x = x.permute(1, 0, 2, 3, 4) # [T, N, 2, *, *] + # print("torch.unique", torch.unique(x)) + # print("torch.count_nonzero", torch.count_nonzero(x)) + # print("numel()", x.numel()) + x = self.forward_features(x) + x = self.head(x.mean(0)) + return x + + +@register_model +def delay_QKFormer(pretrained=False, **kwargs): + model = vit_snn( + patch_size=16, embed_dims=256, num_heads=16, mlp_ratios=4, + in_channels=2, num_classes=101, qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=4, sr_ratios=1, + **kwargs + ) + model.default_cfg = _cfg() + return model + + +from timm.models import create_model + +if __name__ == '__main__': + x = torch.randn(1, 1, 2, 128, 128).cuda() + model = create_model( + 'delay_QKFormer', + pretrained=False, + drop_rate=0, + drop_path_rate=0.1, + drop_block_rate=None, + ).cuda() + model.eval() + + from torchinfo import summary + summary(model, input_size=(1, 1, 2, 128, 128)) + # y = model(x) + # print(y.shape) + # print('Test Good!') + + + + + + + + + + diff --git a/models/qkformer.py b/models/qkformer.py new file mode 100644 index 0000000000000000000000000000000000000000..96eb98481c001121ad9925d0dab36b22846e6ccb --- /dev/null +++ b/models/qkformer.py @@ -0,0 +1,448 @@ +# from visualizer import get_local +import torch +import torch.nn as nn +from spikingjelly.clock_driven.neuron import MultiStepParametricLIFNode, MultiStepLIFNode +from spikingjelly.clock_driven import layer +from timm.models.layers import to_2tuple, trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +from einops.layers.torch import Rearrange +import torch.nn.functional as F +from functools import partial + +__all__ = ['QKFormer_10_512',] + + +def compute_non_zero_rate(x): + x_shape = torch.tensor(list(x.shape)) + all_neural = torch.prod(x_shape) + z = torch.nonzero(x) + print("After attention proj the none zero rate is", z.shape[0]/all_neural) + + +class MLP(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + # self.fc1 = linear_unit(in_features, hidden_features) + self.fc1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1) + self.fc1_bn = nn.BatchNorm2d(hidden_features) + self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + # self.fc2 = linear_unit(hidden_features, out_features) + self.fc2_conv = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1) + self.fc2_bn = nn.BatchNorm2d(out_features) + self.fc2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + # self.drop = nn.Dropout(0.1) + + self.c_hidden = hidden_features + self.c_output = out_features + def forward(self, x): + T,B,C,W,H = x.shape + x = self.fc1_conv(x.flatten(0,1)) + x = self.fc1_bn(x).reshape(T,B,self.c_hidden,W,H).contiguous() + x = self.fc1_lif(x) + + x = self.fc2_conv(x.flatten(0,1)) + x = self.fc2_bn(x).reshape(T,B,C,W,H).contiguous() + x = self.fc2_lif(x) + return x + +class Token_QK_Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} 
should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + + self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.q_bn = nn.BatchNorm1d(dim) + self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1, bias=False) + self.k_bn = nn.BatchNorm1d(dim) + self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy') + + self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1) + self.proj_bn = nn.BatchNorm1d(dim) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + + def forward(self, x): + T, B, C, H, W = x.shape + + x = x.flatten(3) + T, B, C, N = x.shape + x_for_qkv = x.flatten(0, 1) + + q_conv_out = self.q_conv(x_for_qkv) + q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, N) + q_conv_out = self.q_lif(q_conv_out) + q = q_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N) + + k_conv_out = self.k_conv(x_for_qkv) + k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, N) + k_conv_out = self.k_lif(k_conv_out) + k = k_conv_out.unsqueeze(2).reshape(T, B, self.num_heads, C // self.num_heads, N) + + q = torch.sum(q, dim = 3, keepdim = True) + attn = self.attn_lif(q) + x = torch.mul(attn, k) + + x = x.flatten(2, 3) + x = self.proj_bn(self.proj_conv(x.flatten(0, 1))).reshape(T, B, C, H, W) + x = self.proj_lif(x) + + return x + + +class Spiking_Self_Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
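+ # Full spiking self-attention used in the last stage of the hierarchical
+ # model: Q, K and V each come from a pointwise Conv1d -> BatchNorm1d -> LIF
+ # applied to the flattened token sequence, and the output is computed as
+ # Q @ (K^T @ V) scaled by 0.125, so no explicit N x N attention matrix is
+ # ever formed.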
+ self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = 0.125 + self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1,bias=False) + self.q_bn = nn.BatchNorm1d(dim) + self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1,bias=False) + self.k_bn = nn.BatchNorm1d(dim) + self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1,bias=False) + self.v_bn = nn.BatchNorm1d(dim) + self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy') + + self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1) + self.proj_bn = nn.BatchNorm1d(dim) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + def forward(self, x): + T, B, C, H, W = x.shape + + x = x.flatten(3) + T, B, C, N = x.shape + x_for_qkv = x.flatten(0, 1) + x_feat = x + q_conv_out = self.q_conv(x_for_qkv) + q_conv_out = self.q_bn(q_conv_out).reshape(T,B,C,N).contiguous() + q_conv_out = self.q_lif(q_conv_out) + q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous() + + k_conv_out = self.k_conv(x_for_qkv) + k_conv_out = self.k_bn(k_conv_out).reshape(T,B,C,N).contiguous() + k_conv_out = self.k_lif(k_conv_out) + k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous() + + v_conv_out = self.v_conv(x_for_qkv) + v_conv_out = self.v_bn(v_conv_out).reshape(T,B,C,N).contiguous() + v_conv_out = self.v_lif(v_conv_out) + v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous() + + x = k.transpose(-2,-1) @ v + x = (q @ x) * self.scale + + x = x.transpose(3, 4).reshape(T, B, C, N).contiguous() + x = self.attn_lif(x) + x = x.flatten(0,1) + x = self.proj_lif(self.proj_bn(self.proj_conv(x))).reshape(T,B,C,H,W) + + return x + +class TokenSpikingTransformer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1): + super().__init__() + self.tssa = Token_QK_Attention(dim, num_heads) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features= dim, hidden_features=mlp_hidden_dim, drop=drop) + + def forward(self, x): + + x = x + self.tssa(x) + x = x + self.mlp(x) + + return x + + +class SpikingTransformer(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1): + super().__init__() + self.attn = Spiking_Self_Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) + + def forward(self, x): + x = x + self.attn(x) + x = x + self.mlp(x) + + return x + + +class PatchEmbedInit(nn.Module): + def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256): + super().__init__() + self.image_size = [img_size_h, img_size_w] + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.C = in_channels + self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // 
patch_size[1] + self.num_patches = self.H * self.W + # Downsampling + Res 0 + self.proj_conv = nn.Conv2d(in_channels, embed_dims // 2, kernel_size=3, stride=1, padding=1, bias=False) + self.proj_bn = nn.BatchNorm2d(embed_dims // 2) + self.proj_maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj1_conv = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.proj1_bn = nn.BatchNorm2d(embed_dims) + self.proj1_maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj2_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.proj2_bn = nn.BatchNorm2d(embed_dims) + self.proj2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj_res_conv = nn.Conv2d(embed_dims // 2, embed_dims, kernel_size=1, stride=2, padding=0, bias=False) + self.proj_res_bn = nn.BatchNorm2d(embed_dims) + self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + + def forward(self, x): + T, B, C, H, W = x.shape + # Downsampling + Res + x = self.proj_conv(x.flatten(0, 1)) + x = self.proj_bn(x) + x = self.proj_maxpool(x).reshape(T, B, -1, H//2, W//2).contiguous() + x = self.proj_lif(x).flatten(0, 1).contiguous() + + x_feat = x + x = self.proj1_conv(x) + x = self.proj1_bn(x) + x = self.proj1_maxpool(x).reshape(T, B, -1, H // 4, W // 4).contiguous() + x = self.proj1_lif(x).flatten(0, 1).contiguous() + + x = self.proj2_conv(x) + x = self.proj2_bn(x).reshape(T, B, -1, H//4, W//4).contiguous() + x = self.proj2_lif(x) + + x_feat = self.proj_res_conv(x_feat) + x_feat = self.proj_res_bn(x_feat).reshape(T, B, -1, H//4, W//4).contiguous() + x_feat = self.proj_res_lif(x_feat) + + x = x + x_feat # shortcut + + return x + +class PatchEmbeddingStage(nn.Module): + def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256): + super().__init__() + self.image_size = [img_size_h, img_size_w] + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.C = in_channels + self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1] + self.num_patches = self.H * self.W + + self.proj3_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.proj3_bn = nn.BatchNorm2d(embed_dims) + self.proj3_maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + self.proj3_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj4_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.proj4_bn = nn.BatchNorm2d(embed_dims) + self.proj4_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.proj_res_conv = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=1, stride=2, padding=0, bias=False) + self.proj_res_bn = nn.BatchNorm2d(embed_dims) + self.proj_res_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + def forward(self, x): + T, B, C, H, W = x.shape + # Downsampling + Res + + x = x.flatten(0, 1).contiguous() + x_feat = x + + x = self.proj3_conv(x) + x = self.proj3_bn(x) + x = self.proj3_maxpool(x).reshape(T, B, -1, H//2, W//2).contiguous() + x = self.proj3_lif(x).flatten(0, 1).contiguous() + + x = self.proj4_conv(x) + x 
= self.proj4_bn(x).reshape(T, B, -1, H//2, W//2).contiguous() + x = self.proj4_lif(x) + + x_feat = self.proj_res_conv(x_feat) + x_feat = self.proj_res_bn(x_feat).reshape(T, B, -1, H//2, W//2).contiguous() + x_feat = self.proj_res_lif(x_feat) + + x = x + x_feat # shortcut + + return x + +class hierarchical_spiking_transformer(nn.Module): + def __init__(self, + T=4, + img_size_h=128, img_size_w=128, patch_size=16, in_channels=2, num_classes=11, + embed_dims=[64, 128, 256], num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[6, 8, 6], sr_ratios=[8, 4, 2] + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.T = T + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)] # stochastic depth decay rule + + patch_embed1 = PatchEmbedInit(img_size_h=img_size_h, + img_size_w=img_size_w, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims // 4) + + stage1 = nn.ModuleList([TokenSpikingTransformer( + dim=embed_dims // 4, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j], + norm_layer=norm_layer, sr_ratio=sr_ratios) + for j in range(1)]) + + patch_embed2 = PatchEmbeddingStage(img_size_h=img_size_h, + img_size_w=img_size_w, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims // 2) + + + stage2 = nn.ModuleList([TokenSpikingTransformer( + dim=embed_dims // 2, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j], + norm_layer=norm_layer, sr_ratio=sr_ratios) + for j in range(2)]) + + + patch_embed3 = PatchEmbeddingStage(img_size_h=img_size_h, + img_size_w=img_size_w, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims) + + stage3 = nn.ModuleList([SpikingTransformer( + dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j], + norm_layer=norm_layer, sr_ratio=sr_ratios) + for j in range(depths - 3)]) + + setattr(self, f"patch_embed1", patch_embed1) + setattr(self, f"patch_embed2", patch_embed2) + setattr(self, f"patch_embed3", patch_embed3) + setattr(self, f"stage1", stage1) + setattr(self, f"stage2", stage2) + setattr(self, f"stage3", stage3) + + # classification head 这里不需要脉冲,因为输入的是在T时长平均发射值 + self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity() + self.apply(self._init_weights) + + @torch.jit.ignore + def _get_pos_embed(self, pos_embed, patch_embed3, H, W): + if H * W == self.patch_embed3.num_patches: + return pos_embed + else: + return F.interpolate( + pos_embed.reshape(1, patch_embed3.H, patch_embed3.W, -1).permute(0, 3, 1, 2), + size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + + stage1 = getattr(self, f"stage1") + stage2 = getattr(self, f"stage2") + stage3 = getattr(self, f"stage3") + patch_embed1 = getattr(self, f"patch_embed1") + patch_embed2 = getattr(self, f"patch_embed2") + patch_embed3 = getattr(self, f"patch_embed3") + + x = 
patch_embed1(x) + for blk in stage1: + x = blk(x) + + x = patch_embed2(x) + for blk in stage2: + x = blk(x) + + x = patch_embed3(x) + for blk in stage3: + x = blk(x) + + return x.flatten(3).mean(3) + + def forward(self, x): + T = self.T + x = (x.unsqueeze(0)).repeat(T, 1, 1, 1, 1) + x = self.forward_features(x) + x = self.head(x.mean(0)) + return x + +def QKFormer_10_384(T=1, **kwargs): + model = hierarchical_spiking_transformer( + T=T, + img_size_h=224, img_size_w=224, + patch_size=16, embed_dims=384, num_heads=6, mlp_ratios=4, + in_channels=3, num_classes=1000, qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=10, sr_ratios=1, + **kwargs + ) + return model + +def QKFormer_10_512(T=1, **kwargs): + model = hierarchical_spiking_transformer( + T=T, + img_size_h=224, img_size_w=224, + patch_size=16, embed_dims=512, num_heads=8, mlp_ratios=4, + in_channels=3, num_classes=1000, qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=10, sr_ratios=1, + **kwargs + ) + return model + + +def QKFormer_10_768(T=1, **kwargs): + model = hierarchical_spiking_transformer( + T=T, + img_size_h=224, img_size_w=224, + patch_size=16, embed_dims=768, num_heads=12, mlp_ratios=4, + in_channels=3, num_classes=1000, qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=10, sr_ratios=1, + **kwargs + ) + return model + + +if __name__ == '__main__': + H = 128 + W = 128 + x = torch.randn(2, 3, 224, 224).cuda() + model = QKFormer_10_768(T = 4).cuda() + + model.eval() + from torchinfo import summary + summary(model, input_size=(1, 3, 224, 224)) diff --git a/models/sd_former_v1.py b/models/sd_former_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..492b2e369f23e522f586af7a5f5b1351a2c6b512 --- /dev/null +++ b/models/sd_former_v1.py @@ -0,0 +1,633 @@ +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_ +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +from spikingjelly.clock_driven.neuron import ( + MultiStepLIFNode, + MultiStepParametricLIFNode, +) + +from timm.models.layers import to_2tuple + + +class MS_SPS(nn.Module): + def __init__( + self, + img_size_h=128, + img_size_w=128, + patch_size=4, + in_channels=2, + embed_dims=256, + pooling_stat="1111", + spike_mode="lif", + ): + super().__init__() + self.image_size = [img_size_h, img_size_w] + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.pooling_stat = pooling_stat + + self.C = in_channels + self.H, self.W = ( + self.image_size[0] // patch_size[0], + self.image_size[1] // patch_size[1], + ) + self.num_patches = self.H * self.W + self.proj_conv = nn.Conv2d( + in_channels, embed_dims // 8, kernel_size=3, stride=1, padding=1, bias=False + ) + self.proj_bn = nn.BatchNorm2d(embed_dims // 8) + if spike_mode == "lif": + self.proj_lif = MultiStepLIFNode(tau=2.0, v_threshold=1.0, detach_reset=True, backend="cupy") + elif spike_mode == "plif": + self.proj_lif = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + self.maxpool = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False + ) + + self.proj_conv1 = nn.Conv2d( + embed_dims // 8, + embed_dims // 4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.proj_bn1 = nn.BatchNorm2d(embed_dims // 4) + if spike_mode == "lif": + self.proj_lif1 = MultiStepLIFNode( + tau=2.0, detach_reset=True, backend="cupy" + ) + elif 
spike_mode == "plif": + self.proj_lif1 = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + self.maxpool1 = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False + ) + + self.proj_conv2 = nn.Conv2d( + embed_dims // 4, + embed_dims // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.proj_bn2 = nn.BatchNorm2d(embed_dims // 2) + if spike_mode == "lif": + self.proj_lif2 = MultiStepLIFNode( + tau=2.0, detach_reset=True, backend="cupy" + ) + elif spike_mode == "plif": + self.proj_lif2 = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + self.maxpool2 = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False + ) + + self.proj_conv3 = nn.Conv2d( + embed_dims // 2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False + ) + self.proj_bn3 = nn.BatchNorm2d(embed_dims) + if spike_mode == "lif": + self.proj_lif3 = MultiStepLIFNode( + tau=2.0, detach_reset=True, backend="cupy" + ) + elif spike_mode == "plif": + self.proj_lif3 = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + self.maxpool3 = nn.MaxPool2d( + kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False + ) + + self.rpe_conv = nn.Conv2d( + embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False + ) + self.rpe_bn = nn.BatchNorm2d(embed_dims) + if spike_mode == "lif": + self.rpe_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") + elif spike_mode == "plif": + self.rpe_lif = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + + def forward(self, x, hook=None): + T, B, _, H, W = x.shape + ratio = 1 + x = self.proj_conv(x.flatten(0, 1)) # have some fire value + x = self.proj_bn(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous() + x = self.proj_lif(x) + if hook is not None: + hook[self._get_name() + "_lif"] = x.detach() + x = x.flatten(0, 1).contiguous() + if self.pooling_stat[0] == "1": + x = self.maxpool(x) + ratio *= 2 + + x = self.proj_conv1(x) + x = self.proj_bn1(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous() + x = self.proj_lif1(x) + if hook is not None: + hook[self._get_name() + "_lif1"] = x.detach() + x = x.flatten(0, 1).contiguous() + if self.pooling_stat[1] == "1": + x = self.maxpool1(x) + ratio *= 2 + + x = self.proj_conv2(x) + x = self.proj_bn2(x).reshape(T, B, -1, H // ratio, W // ratio).contiguous() + x = self.proj_lif2(x) + if hook is not None: + hook[self._get_name() + "_lif2"] = x.detach() + x = x.flatten(0, 1).contiguous() + if self.pooling_stat[2] == "1": + x = self.maxpool2(x) + ratio *= 2 + + x = self.proj_conv3(x) + x = self.proj_bn3(x) + if self.pooling_stat[3] == "1": + x = self.maxpool3(x) + ratio *= 2 + + x_feat = x + x = self.proj_lif3(x.reshape(T, B, -1, H // ratio, W // ratio).contiguous()) + if hook is not None: + hook[self._get_name() + "_lif3"] = x.detach() + x = x.flatten(0, 1).contiguous() + x = self.rpe_conv(x) + x = self.rpe_bn(x) + x = (x + x_feat).reshape(T, B, -1, H // ratio, W // ratio).contiguous() + + H, W = H // self.patch_size[0], W // self.patch_size[1] + return x, (H, W), hook + +class Erode(nn.Module): + def __init__(self) -> None: + super().__init__() + self.pool = nn.MaxPool3d( + kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1) + ) + + def forward(self, x): + return self.pool(x) + + +class MS_MLP_Conv(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + drop=0.0, + spike_mode="lif", + 
layer=0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.res = in_features == hidden_features + self.fc1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1) + self.fc1_bn = nn.BatchNorm2d(hidden_features) + if spike_mode == "lif": + self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") + elif spike_mode == "plif": + self.fc1_lif = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + + self.fc2_conv = nn.Conv2d( + hidden_features, out_features, kernel_size=1, stride=1 + ) + self.fc2_bn = nn.BatchNorm2d(out_features) + if spike_mode == "lif": + self.fc2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") + elif spike_mode == "plif": + self.fc2_lif = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + + self.c_hidden = hidden_features + self.c_output = out_features + self.layer = layer + + def forward(self, x, hook=None): + T, B, C, H, W = x.shape + identity = x + + x = self.fc1_lif(x) + if hook is not None: + hook[self._get_name() + str(self.layer) + "_fc1_lif"] = x.detach() + x = self.fc1_conv(x.flatten(0, 1)) + x = self.fc1_bn(x).reshape(T, B, self.c_hidden, H, W).contiguous() + if self.res: + x = identity + x + identity = x + x = self.fc2_lif(x) + if hook is not None: + hook[self._get_name() + str(self.layer) + "_fc2_lif"] = x.detach() + x = self.fc2_conv(x.flatten(0, 1)) + x = self.fc2_bn(x).reshape(T, B, C, H, W).contiguous() + + x = x + identity + return x, hook + + +class MS_SSA_Conv(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + sr_ratio=1, + mode="direct_xor", + spike_mode="lif", + dvs=False, + layer=0, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
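+ # Spike-driven self-attention: rather than matrix products, K and V are
+ # combined element-wise (k * v), summed over the token dimension, fired
+ # through a 0.5-threshold LIF neuron and then used to gate Q element-wise.
+ # With dvs=True the spike maps are additionally smoothed by the Erode
+ # max-pool, and a shortcut adds the block input back to the output.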
+ self.dim = dim + self.dvs = dvs + self.num_heads = num_heads + if dvs: + self.pool = Erode() + self.scale = 0.125 + self.q_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False) + self.q_bn = nn.BatchNorm2d(dim) + if spike_mode == "lif": + self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") + elif spike_mode == "plif": + self.q_lif = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + + self.k_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False) + self.k_bn = nn.BatchNorm2d(dim) + if spike_mode == "lif": + self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") + elif spike_mode == "plif": + self.k_lif = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + + self.v_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1, bias=False) + self.v_bn = nn.BatchNorm2d(dim) + if spike_mode == "lif": + self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") + elif spike_mode == "plif": + self.v_lif = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + + if spike_mode == "lif": + self.attn_lif = MultiStepLIFNode( + tau=2.0, v_threshold=0.5, detach_reset=True, backend="cupy" + ) + elif spike_mode == "plif": + self.attn_lif = MultiStepParametricLIFNode( + init_tau=2.0, v_threshold=0.5, detach_reset=True, backend="cupy" + ) + + self.talking_heads = nn.Conv1d( + num_heads, num_heads, kernel_size=1, stride=1, bias=False + ) + if spike_mode == "lif": + self.talking_heads_lif = MultiStepLIFNode( + tau=2.0, v_threshold=0.5, detach_reset=True, backend="cupy" + ) + elif spike_mode == "plif": + self.talking_heads_lif = MultiStepParametricLIFNode( + init_tau=2.0, v_threshold=0.5, detach_reset=True, backend="cupy" + ) + + self.proj_conv = nn.Conv2d(dim, dim, kernel_size=1, stride=1) + self.proj_bn = nn.BatchNorm2d(dim) + + if spike_mode == "lif": + self.shortcut_lif = MultiStepLIFNode( + tau=2.0, detach_reset=True, backend="cupy" + ) + elif spike_mode == "plif": + self.shortcut_lif = MultiStepParametricLIFNode( + init_tau=2.0, detach_reset=True, backend="cupy" + ) + + self.mode = mode + self.layer = layer + + def forward(self, x, hook=None): + T, B, C, H, W = x.shape + identity = x + N = H * W + x = self.shortcut_lif(x) + if hook is not None: + hook[self._get_name() + str(self.layer) + "_first_lif"] = x.detach() + + x_for_qkv = x.flatten(0, 1) + q_conv_out = self.q_conv(x_for_qkv) + q_conv_out = self.q_bn(q_conv_out).reshape(T, B, C, H, W).contiguous() + q_conv_out = self.q_lif(q_conv_out) + + if hook is not None: + hook[self._get_name() + str(self.layer) + "_q_lif"] = q_conv_out.detach() + q = ( + q_conv_out.flatten(3) + .transpose(-1, -2) + .reshape(T, B, N, self.num_heads, C // self.num_heads) + .permute(0, 1, 3, 2, 4) + .contiguous() + ) + + k_conv_out = self.k_conv(x_for_qkv) + k_conv_out = self.k_bn(k_conv_out).reshape(T, B, C, H, W).contiguous() + k_conv_out = self.k_lif(k_conv_out) + if self.dvs: + k_conv_out = self.pool(k_conv_out) + if hook is not None: + hook[self._get_name() + str(self.layer) + "_k_lif"] = k_conv_out.detach() + k = ( + k_conv_out.flatten(3) + .transpose(-1, -2) + .reshape(T, B, N, self.num_heads, C // self.num_heads) + .permute(0, 1, 3, 2, 4) + .contiguous() + ) + + v_conv_out = self.v_conv(x_for_qkv) + v_conv_out = self.v_bn(v_conv_out).reshape(T, B, C, H, W).contiguous() + v_conv_out = self.v_lif(v_conv_out) + if self.dvs: + v_conv_out = self.pool(v_conv_out) + if hook is not None: + hook[self._get_name() + 
str(self.layer) + "_v_lif"] = v_conv_out.detach() + v = ( + v_conv_out.flatten(3) + .transpose(-1, -2) + .reshape(T, B, N, self.num_heads, C // self.num_heads) + .permute(0, 1, 3, 2, 4) + .contiguous() + ) # T B head N C//h + + kv = k.mul(v) + if hook is not None: + hook[self._get_name() + str(self.layer) + "_kv_before"] = kv + if self.dvs: + kv = self.pool(kv) + kv = kv.sum(dim=-2, keepdim=True) + kv = self.talking_heads_lif(kv) + if hook is not None: + hook[self._get_name() + str(self.layer) + "_kv"] = kv.detach() + x = q.mul(kv) + if self.dvs: + x = self.pool(x) + if hook is not None: + hook[self._get_name() + str(self.layer) + "_x_after_qkv"] = x.detach() + + x = x.transpose(3, 4).reshape(T, B, C, H, W).contiguous() + x = ( + self.proj_bn(self.proj_conv(x.flatten(0, 1))) + .reshape(T, B, C, H, W) + .contiguous() + ) + + x = x + identity + return x, v, hook + + +class MS_Block_Conv(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + sr_ratio=1, + attn_mode="direct_xor", + spike_mode="lif", + dvs=False, + layer=0, + ): + super().__init__() + self.attn = MS_SSA_Conv( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + mode=attn_mode, + spike_mode=spike_mode, + dvs=dvs, + layer=layer, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MS_MLP_Conv( + in_features=dim, + hidden_features=mlp_hidden_dim, + drop=drop, + spike_mode=spike_mode, + layer=layer, + ) + + def forward(self, x, hook=None): + x_attn, attn, hook = self.attn(x, hook=hook) + x, hook = self.mlp(x_attn, hook=hook) + return x, attn, hook + + +class SpikeDrivenTransformer(nn.Module): + def __init__( + self, + img_size_h=128, + img_size_w=128, + patch_size=16, + in_channels=2, + num_classes=1000, + embed_dims=512, + num_heads=8, + mlp_ratios=4, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[6, 8, 6], + sr_ratios=[8, 4, 2], + T=4, + pooling_stat="1111", + attn_mode="direct_xor", + spike_mode="lif", + get_embed=False, + dvs_mode=False, + TET=False, + cml=False, + pretrained=False, + pretrained_cfg=None, + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + self.T = T + self.TET = TET + self.dvs = dvs_mode + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + + patch_embed = MS_SPS( + img_size_h=img_size_h, + img_size_w=img_size_w, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims, + pooling_stat=pooling_stat, + spike_mode=spike_mode, + ) + + blocks = nn.ModuleList( + [ + MS_Block_Conv( + dim=embed_dims, + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + attn_mode=attn_mode, + spike_mode=spike_mode, + dvs=dvs_mode, + layer=j, + ) + for j in range(depths) + ] + ) + + setattr(self, f"patch_embed", patch_embed) + setattr(self, f"block", blocks) + + # classification head + if spike_mode in ["lif", "alif", "blif"]: + self.head_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend="cupy") + elif spike_mode == "plif": + self.head_lif = MultiStepParametricLIFNode( + init_tau=2.0, 
detach_reset=True, backend="cupy" + ) + self.head = ( + nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity() + ) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x, hook=None): + block = getattr(self, f"block") + patch_embed = getattr(self, f"patch_embed") + + x, _, hook = patch_embed(x, hook=hook) + for blk in block: + x, _, hook = blk(x, hook=hook) + + x = x.flatten(3).mean(3) + return x, hook + + def forward(self, x, hook=None): + if len(x.shape) < 5: + x = (x.unsqueeze(0)).repeat(self.T, 1, 1, 1, 1) + else: + x = x.transpose(0, 1).contiguous() + + x, hook = self.forward_features(x, hook=hook) + x = self.head_lif(x) + if hook is not None: + hook["head_lif"] = x.detach() + + x = self.head(x) + if not self.TET: + x = x.mean(0) + return x, hook + + +@register_model +def sdt(**kwargs): + model = SpikeDrivenTransformer( + **kwargs, + ) + model.default_cfg = _cfg() + return model + +#### test +if __name__ == "__main__": + import os + os.environ["CUDA_VISIBLE_DEVICES"] = '0' + + model = sdt( + img_size_h=224, + img_size_w=224, + patch_size=16, + in_channels=3, + embed_dims=384, + num_heads=8, + mlp_ratios=4, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=[8, 4, 2], + T=4, + pooling_stat="1111", + attn_mode="direct_xor", + spike_mode="lif", + ) # 或者您自己的模型 + model.eval() + for name, param in model.named_parameters(): + print(name, param.size()) + + # # 如果有预训练权重,加载它们 + # state_dict = torch.load('V3_5.1M_1x4.pth', map_location='cpu') + # model.load_state_dict(state_dict['model'], strict=False) + + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + x = torch.randn(1, 3, 224, 224).to(device) + x = x.unsqueeze(0).repeat(4, 1, 1, 1, 1) + x.requires_grad_(True) + _ = model(x) + # print(x.shape) # torch.Size([4, 1, 3, 224, 224]) + print("test success") \ No newline at end of file diff --git a/models/sdtv3.py b/models/sdtv3.py new file mode 100644 index 0000000000000000000000000000000000000000..0026c71939e1938df98e7bc20f59e3ecca8808ed --- /dev/null +++ b/models/sdtv3.py @@ -0,0 +1,1739 @@ +import torch +import torchinfo +import torch.nn as nn +from timm.models.layers import to_2tuple, trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +from einops.layers.torch import Rearrange +import torch.nn.functional as F +from functools import partial + +import os + +class Quant(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, i, min_value, max_value): + ctx.min = min_value + ctx.max = max_value + ctx.save_for_backward(i) + return torch.round(torch.clamp(i, min=min_value, max=max_value)) + + @staticmethod + @torch.cuda.amp.custom_fwd + def backward(ctx, grad_output): + grad_input = grad_output.clone() + i, = ctx.saved_tensors + grad_input[i < ctx.min] = 0 + grad_input[i > ctx.max] = 0 + return grad_input, None, None + +class MultiSpike(nn.Module): + def __init__( + self, + min_value=0, + max_value=4, + Norm=None, + ): + super().__init__() + if Norm == None: + self.Norm = max_value + else: + self.Norm = Norm + self.min_value = min_value + self.max_value = max_value + + 
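+    # MultiSpike emits integer "spike counts": the input is clamped to
+    # [min_value, max_value], rounded, and rescaled by 1/Norm, so activations
+    # behave like average firing rates over max_value virtual timesteps.
+    # Gradients use Quant's straight-through estimator (1 inside the clamp
+    # range, 0 outside).
+    #
+    # Illustrative behaviour with the defaults (min=0, max=4, Norm=4):
+    #   MultiSpike()(torch.tensor([-1.0, 0.3, 2.6, 9.0]))
+    #   -> tensor([0.0000, 0.0000, 0.7500, 1.0000])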
@staticmethod + def spike_function(x, min_value, max_value): + return Quant.apply(x, min_value, max_value) + + def __repr__(self): + return f"MultiSpike(Max_Value={self.max_value}, Min_Value={self.min_value}, Norm={self.Norm})" + + def forward(self, x): # B C H W + return self.spike_function(x, min_value=self.min_value, max_value=self.max_value) / (self.Norm) # original + # return self.spike_function(x, min_value=self.min_value, max_value=self.max_value) + +class BNAndPadLayer(nn.Module): + def __init__( + self, + pad_pixels, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + ): + super(BNAndPadLayer, self).__init__() + self.bn = nn.BatchNorm2d( + num_features, eps, momentum, affine, track_running_stats + ) + self.pad_pixels = pad_pixels + + def forward(self, input): + output = self.bn(input) + if self.pad_pixels > 0: + if self.bn.affine: + pad_values = ( + self.bn.bias.detach() + - self.bn.running_mean + * self.bn.weight.detach() + / torch.sqrt(self.bn.running_var + self.bn.eps) + ) + else: + pad_values = -self.bn.running_mean / torch.sqrt( + self.bn.running_var + self.bn.eps + ) + output = F.pad(output, [self.pad_pixels] * 4) + pad_values = pad_values.view(1, -1, 1, 1) + output[:, :, 0 : self.pad_pixels, :] = pad_values + output[:, :, -self.pad_pixels :, :] = pad_values + output[:, :, :, 0 : self.pad_pixels] = pad_values + output[:, :, :, -self.pad_pixels :] = pad_values + return output + + @property + def weight(self): + return self.bn.weight + + @property + def bias(self): + return self.bn.bias + + @property + def running_mean(self): + return self.bn.running_mean + + @property + def running_var(self): + return self.bn.running_var + + @property + def eps(self): + return self.bn.eps + +class RepConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + bias=False, + ): + super().__init__() + # hidden_channel = in_channel + conv1x1 = nn.Conv2d(in_channel, in_channel, 1, 1, 0, bias=False, groups=1) + bn = BNAndPadLayer(pad_pixels=1, num_features=in_channel) + conv3x3 = nn.Sequential( + nn.Conv2d(in_channel, in_channel, 3, 1, 0, groups=in_channel, bias=False), + nn.Conv2d(in_channel, out_channel, 1, 1, 0, groups=1, bias=False), + nn.BatchNorm2d(out_channel), + ) + + self.body = nn.Sequential(conv1x1, bn, conv3x3) + + def forward(self, x): + return self.body(x) + +class SepConv(nn.Module): + r""" + Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. + """ + + def __init__( + self, + dim, + expansion_ratio=2, + act2_layer=nn.Identity, + bias=False, + kernel_size=7, + padding=3, + ): + super().__init__() + med_channels = int(expansion_ratio * dim) + self.spike1 = MultiSpike() + self.pwconv1 = nn.Conv2d(dim, med_channels, kernel_size=1, stride=1, bias=bias) + self.bn1 = nn.BatchNorm2d(med_channels) + self.spike2 = MultiSpike() + self.dwconv = nn.Conv2d( + med_channels, + med_channels, + kernel_size=kernel_size, + padding=padding, + groups=med_channels, + bias=bias, + ) # depthwise conv + self.pwconv2 = nn.Conv2d(med_channels, dim, kernel_size=1, stride=1, bias=bias) + self.bn2 = nn.BatchNorm2d(dim) + + def forward(self, x): + + x = self.spike1(x) + + x = self.bn1(self.pwconv1(x)) + + x = self.spike2(x) + + x = self.dwconv(x) + x = self.bn2(self.pwconv2(x)) + return x + +class SepConv_Spike(nn.Module): + r""" + Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. 
+ """ + + def __init__( + self, + dim, + expansion_ratio=2, + act2_layer=nn.Identity, + bias=False, + kernel_size=7, + padding=3, + ): + super().__init__() + med_channels = int(expansion_ratio * dim) + self.spike1 = MultiSpike() + self.pwconv1 = nn.Sequential( + nn.Conv2d(dim, med_channels, kernel_size=1, stride=1, bias=bias), + nn.BatchNorm2d(med_channels) + ) + self.spike2 = MultiSpike() + self.dwconv = nn.Sequential( + nn.Conv2d(med_channels, med_channels, kernel_size=kernel_size, padding=padding, groups=med_channels, bias=bias), + nn.BatchNorm2d(med_channels) + ) + self.spike3 = MultiSpike() + self.pwconv2 = nn.Sequential( + nn.Conv2d(med_channels, dim, kernel_size=1, stride=1, bias=bias), + nn.BatchNorm2d(dim) + ) + + def forward(self, x): + + x = self.spike1(x) + + x = self.pwconv1(x) + + x = self.spike2(x) + + x = self.dwconv(x) + + x = self.spike3(x) + + x = self.pwconv2(x) + return x + + + +class MS_ConvBlock(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4.0, + ): + super().__init__() + + self.Conv = SepConv(dim=dim) + + self.mlp_ratio = mlp_ratio + + self.spike1 = MultiSpike() + self.conv1 = nn.Conv2d( + dim, dim * mlp_ratio, kernel_size=3, padding=1, groups=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(dim * mlp_ratio) # 这里可以进行改进 + self.spike2 = MultiSpike() + self.conv2 = nn.Conv2d( + dim * mlp_ratio, dim, kernel_size=3, padding=1, groups=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(dim) # 这里可以进行改进 + + def forward(self, x): + B, C, H, W = x.shape + + x = self.Conv(x) + x + x_feat = x + x = self.spike1(x) + x = self.bn1(self.conv1(x)).reshape(B, self.mlp_ratio * C, H, W) + x = self.spike2(x) + x = self.bn2(self.conv2(x)).reshape(B, C, H, W) + x = x_feat + x + + return x + +class MS_ConvBlock_spike_SepConv(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4.0, + ): + super().__init__() + + self.Conv = SepConv_Spike(dim=dim) + + self.mlp_ratio = mlp_ratio + + self.spike1 = MultiSpike() + self.conv1 = nn.Conv2d( + dim, dim * mlp_ratio, kernel_size=3, padding=1, groups=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(dim * mlp_ratio) + self.spike2 = MultiSpike() + self.conv2 = nn.Conv2d( + dim * mlp_ratio, dim, kernel_size=3, padding=1, groups=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(dim) + + def forward(self, x): + B, C, H, W = x.shape + + x = self.Conv(x) + x + x_feat = x + x = self.spike1(x) + x = self.bn1(self.conv1(x)).reshape(B, self.mlp_ratio * C, H, W) + x = self.spike2(x) + x = self.bn2(self.conv2(x)).reshape(B, C, H, W) + x = x_feat + x + + return x + +class MS_ConvBlock_spike_MLP(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4.0, + drop=0., + ): + super().__init__() + drop_probs = to_2tuple(drop) + + self.Conv = SepConv_Spike(dim=dim) + + self.mlp_ratio = mlp_ratio + + self.spike1 = MultiSpike() + self.fc1 = nn.Linear( + dim, dim * mlp_ratio, bias=False + ) + self.drop1 = nn.Dropout(drop_probs[0]) + self.spike2 = MultiSpike() + self.fc2 = nn.Linear( + dim * mlp_ratio, dim, bias=False + ) + self.drop2 = nn.Dropout(drop_probs[1]) + + # ############Version 2################ + # self.spike1 = MultiSpike() + # self.conv1 = nn.Conv2d( + # dim, dim * mlp_ratio, kernel_size=1, bias=False + # ) + # self.bn1 = nn.BatchNorm2d(dim * mlp_ratio) + # self.spike2 = MultiSpike() + # self.conv2 = nn.Conv2d( + # dim * mlp_ratio, dim, kernel_size=1, bias=False + # ) + # self.bn2 = nn.BatchNorm2d(dim) + # ##################################### + + def forward(self, x): + B, C, H, W = x.shape + + x = self.Conv(x) + x + x_feat = x + + x = self.spike1(x) + x = 
self.drop1(self.fc1(x.reshape(B, H * W, C))).reshape(B, self.mlp_ratio * C, H, W) + x = self.spike2(x) + x = self.drop2(self.fc2(x.reshape(B, H * W, self.mlp_ratio * C))).reshape(B, C, H, W) + + # ############Version 2################ + # x = self.spike1(x) + # x = self.bn1(self.conv1(x)).reshape(B, self.mlp_ratio * C, H, W) + # x = self.spike2(x) + # x = self.bn2(self.conv2(x)).reshape(B, C, H, W) + # ##################################### + + x = x_feat + x + + return x + +class MS_ConvBlock_spike_splash(nn.Module): + def __init__( + self, + dim, + mlp_ratio=8.0, + drop=0., + ): + super().__init__() + drop_probs = to_2tuple(drop) + + self.Conv = SepConv_Spike(dim=dim) + + self.mlp_ratio = mlp_ratio + + self.spike1 = MultiSpike() + self.fc1 = nn.Linear( + dim, dim * mlp_ratio, bias=False + ) + self.drop1 = nn.Dropout(drop_probs[0]) + self.spike2 = MultiSpike() + self.conv2 = nn.Conv2d( + dim * mlp_ratio, dim, kernel_size=3, padding=1, groups=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(dim) + + def forward(self, x): + B, C, H, W = x.shape + + x = self.Conv(x) + x + x_feat = x + x = self.spike1(x) + x = self.drop1(self.fc1(x.reshape(B, H * W, C))).reshape(B, self.mlp_ratio * C, H, W) + x = self.spike2(x) + x = self.bn2(self.conv2(x)).reshape(B, C, H, W) + x = x_feat + x + + return x + +class MS_MLP(nn.Module): + def __init__( + self, in_features, hidden_features=None, out_features=None, drop=0.0, layer=0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1) + self.fc1_bn = nn.BatchNorm1d(hidden_features) + self.fc1_spike = MultiSpike() + + self.fc2_conv = nn.Conv1d( + hidden_features, out_features, kernel_size=1, stride=1 + ) + self.fc2_bn = nn.BatchNorm1d(out_features) + self.fc2_spike = MultiSpike() + + self.c_hidden = hidden_features + self.c_output = out_features + + def forward(self, x): + B, C, H, W = x.shape + N = H * W + x = x.flatten(2) + x = self.fc1_spike(x) + x = self.fc1_conv(x) + x = self.fc1_bn(x).reshape(B, self.c_hidden, N).contiguous() + x = self.fc2_spike(x) + x = self.fc2_conv(x) + x = self.fc2_bn(x).reshape(B, C, H, W).contiguous() + + return x + + + +class MS_Attention_RepConv_qkv_id(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + sr_ratio=1, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
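+        # Spiking attention with the matmul order swapped: it first forms
+        # k^T @ v (a (C/heads) x (C/heads) matrix per head) and only then
+        # multiplies by q, so the cost grows linearly with the token count
+        # N = H*W instead of quadratically. Since q, k, v are quantised spike
+        # tensors after MultiSpike, the matmuls are effectively accumulate-only.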
+ self.dim = dim + self.num_heads = num_heads + self.scale = (dim//num_heads) ** -0.5 + + self.head_spike = MultiSpike() + + self.q_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim)) + + self.k_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim)) + + self.v_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim)) + + self.q_spike = MultiSpike() + + self.k_spike = MultiSpike() + + self.v_spike = MultiSpike() + + self.attn_spike = MultiSpike() + + self.proj_conv = nn.Sequential( + RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim) + ) + + # self.proj_conv = nn.Sequential( + # nn.Conv2d(dim, dim, 1, 1, bias=False), nn.BatchNorm2d(dim) + # ) + + + def forward(self, x): + B, C, H, W = x.shape + N = H * W + + x = self.head_spike(x) + + q = self.q_conv(x) + k = self.k_conv(x) + v = self.v_conv(x) + + q = self.q_spike(q) + q = q.flatten(2) + q = ( + q.transpose(-1, -2) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + k = self.k_spike(k) + k = k.flatten(2) + k = ( + k.transpose(-1, -2) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + v = self.v_spike(v) + v = v.flatten(2) + v = ( + v.transpose(-1, -2) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + x = k.transpose(-2, -1) @ v + x = (q @ x) * self.scale + + x = x.transpose(2, 3).reshape(B, C, N).contiguous() + x = self.attn_spike(x) + x = x.reshape(B, C, H, W) + x = self.proj_conv(x).reshape(B, C, H, W) + + return x + +class MS_Attention_linear(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + sr_ratio=1, + lamda_ratio=1, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
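+        # "Linear" spiking attention variant: only V (and hence the attention
+        # output) is widened to dim * lamda_ratio channels, and proj_conv maps
+        # it back to dim. Unlike MS_Attention_RepConv_qkv_id above, the mixing
+        # order here is q @ k^T first (an N x N map), then @ v, scaled by
+        # (dim // num_heads) ** -0.5.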
+ self.dim = dim + self.num_heads = num_heads + self.scale = (dim//num_heads) ** -0.5 + self.lamda_ratio = lamda_ratio + + self.head_spike = MultiSpike() + + self.q_conv = nn.Sequential(nn.Conv2d(dim, dim, 1, 1, bias=False), nn.BatchNorm2d(dim)) + + self.q_spike = MultiSpike() + + self.k_conv = nn.Sequential(nn.Conv2d(dim, dim, 1, 1, bias=False), nn.BatchNorm2d(dim)) + + self.k_spike = MultiSpike() + + self.v_conv = nn.Sequential(nn.Conv2d(dim, int(dim*lamda_ratio), 1, 1, bias=False), nn.BatchNorm2d(int(dim*lamda_ratio))) + + self.v_spike = MultiSpike() + + self.attn_spike = MultiSpike() + + + self.proj_conv = nn.Sequential( + nn.Conv2d(dim*lamda_ratio, dim, 1, 1, bias=False), nn.BatchNorm2d(dim) + ) + + + def forward(self, x): + B, C, H, W = x.shape + N = H * W + C_v = int(C*self.lamda_ratio) + + x = self.head_spike(x) + + q = self.q_conv(x) + k = self.k_conv(x) + v = self.v_conv(x) + + q = self.q_spike(q) + q = q.flatten(2) + q = ( + q.transpose(-1, -2) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + k = self.k_spike(k) + k = k.flatten(2) + k = ( + k.transpose(-1, -2) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + v = self.v_spike(v) + v = v.flatten(2) + v = ( + v.transpose(-1, -2) + .reshape(B, N, self.num_heads, C_v // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + x = q @ k.transpose(-2, -1) + # x = (x @ v) * (self.scale*2) + x = x @ v * self.scale + + x = x.transpose(2, 3).reshape(B, C_v, N).contiguous() + x = self.attn_spike(x) + x = x.reshape(B, C_v, H, W) + x = self.proj_conv(x).reshape(B, C, H, W) + + return x + + + + +class MS_Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ): + super().__init__() + + self.attn = MS_Attention_RepConv_qkv_id( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) + + def forward(self, x): + x = x + self.attn(x) + x = x + self.mlp(x) + + return x + +class MS_Block_Spike_SepConv(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + sr_ratio=1, + init_values = 1e-6 + ): + super().__init__() + + self.conv = SepConv_Spike(dim=dim, kernel_size=3, padding=1) + + self.attn = MS_Attention_linear( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + lamda_ratio=4, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) + + + def forward(self, x): + x = x + self.conv(x) + x = x + self.attn(x) + x = x + self.mlp(x) + + return x + + +class MS_DownSampling(nn.Module): + def __init__( + self, + in_channels=2, + embed_dims=256, + kernel_size=3, + stride=2, + padding=1, + first_layer=True, + T=None, + ): + super().__init__() + + self.encode_conv = nn.Conv2d( + in_channels, + embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + 
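+        # Plain Conv2d + BN downsampling stage. When first_layer=True the input
+        # is the raw image, so no MultiSpike is inserted in front of the conv;
+        # every later stage spike-encodes its input first (see forward below).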
self.encode_bn = nn.BatchNorm2d(embed_dims) + self.first_layer = first_layer + if not first_layer: + self.encode_spike = MultiSpike() + + def forward(self, x): + + if hasattr(self, "encode_spike"): + x = self.encode_spike(x) + x = self.encode_conv(x) + x = self.encode_bn(x) + + return x + + + +class Spiking_vit_MetaFormer_Spike_SepConv(nn.Module): + def __init__( + self, + img_size_h=128, + img_size_w=128, + in_channels=1, + num_classes=11, + embed_dim=[64, 128, 256], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[6, 8, 6], + sr_ratios=[8, 4, 2], + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + + self.downsample1_1 = MS_DownSampling( + in_channels=in_channels, + embed_dims=embed_dim[0] // 2, + kernel_size=7, + stride=2, + padding=3, + first_layer=True, + + ) + + self.ConvBlock1_1 = nn.ModuleList( + [MS_ConvBlock_spike_SepConv(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)] + ) + + self.downsample1_2 = MS_DownSampling( + in_channels=embed_dim[0] // 2, + embed_dims=embed_dim[0], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock1_2 = nn.ModuleList( + [MS_ConvBlock_spike_SepConv(dim=embed_dim[0], mlp_ratio=mlp_ratios)] + ) + + self.downsample2 = MS_DownSampling( + in_channels=embed_dim[0], + embed_dims=embed_dim[1], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock2_1 = nn.ModuleList( + [MS_ConvBlock_spike_SepConv(dim=embed_dim[1], mlp_ratio=mlp_ratios)] + ) + + self.ConvBlock2_2 = nn.ModuleList( + [MS_ConvBlock_spike_SepConv(dim=embed_dim[1], mlp_ratio=mlp_ratios)] + ) + + self.downsample3 = MS_DownSampling( + in_channels=embed_dim[1], + embed_dims=embed_dim[2], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.block3 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[2], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(6) + ] + ) + + self.downsample4 = MS_DownSampling( + in_channels=embed_dim[2], + embed_dims=embed_dim[3], + kernel_size=3, + stride=1, + padding=1, + first_layer=False, + + ) + + self.block4 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[3], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(2) + ] + ) + + self.head = ( + nn.Linear(embed_dim[3], num_classes) if num_classes > 0 else nn.Identity() + ) + self.spike = MultiSpike(Norm=1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + x = self.downsample1_1(x) + for blk in self.ConvBlock1_1: + x = blk(x) + x = self.downsample1_2(x) + for blk in self.ConvBlock1_2: + x = blk(x) + + x = self.downsample2(x) + for blk in self.ConvBlock2_1: + x = blk(x) + for blk in self.ConvBlock2_2: + x = blk(x) + 
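+        # Stages 1-2 above are convolutional (SepConv_Spike + ConvBlock) at 1/2,
+        # 1/4 and 1/8 input resolution; stages 3-4 below switch to the
+        # spike-driven attention blocks at 1/16 resolution (downsample4 uses
+        # stride 1, so the spatial size is kept).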
+ x = self.downsample3(x) + for blk in self.block3: + x = blk(x) + + x = self.downsample4(x) + for blk in self.block4: + x = blk(x) + + return x # T,B,C,N + + def forward(self, x): + x = self.forward_features(x) # B,C,H,W + x = x.flatten(2).mean(2) + x = self.spike(x) + x = self.head(x) + return x + +class Spiking_vit_MetaFormer_Spike_SepConv_ChannelMLP(nn.Module): + def __init__( + self, + img_size_h=128, + img_size_w=128, + in_channels=1, + num_classes=11, + embed_dim=[64, 128, 256], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.8, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[6, 8, 6], + sr_ratios=[8, 4, 2], + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + + self.downsample1_1 = MS_DownSampling( + in_channels=in_channels, + embed_dims=embed_dim[0] // 2, + kernel_size=7, + stride=2, + padding=3, + first_layer=True, + + ) + + self.ConvBlock1_1 = nn.ModuleList( + [MS_ConvBlock_spike_MLP( + dim=embed_dim[0] // 2, + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample1_2 = MS_DownSampling( + in_channels=embed_dim[0] // 2, + embed_dims=embed_dim[0], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock1_2 = nn.ModuleList( + [MS_ConvBlock_spike_MLP( + dim=embed_dim[0], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample2 = MS_DownSampling( + in_channels=embed_dim[0], + embed_dims=embed_dim[1], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock2_1 = nn.ModuleList( + [MS_ConvBlock_spike_MLP( + dim=embed_dim[1], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.ConvBlock2_2 = nn.ModuleList( + [MS_ConvBlock_spike_MLP( + dim=embed_dim[1], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample3 = MS_DownSampling( + in_channels=embed_dim[1], + embed_dims=embed_dim[2], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.block3 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[2], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(6) + ] + ) + + self.downsample4 = MS_DownSampling( + in_channels=embed_dim[2], + embed_dims=embed_dim[3], + kernel_size=3, + stride=1, + padding=1, + first_layer=False, + + ) + + self.block4 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[3], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(2) + ] + ) + + self.head = ( + nn.Linear(embed_dim[3], num_classes) if num_classes > 0 else nn.Identity() + ) + self.spike = MultiSpike(Norm=1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + x = self.downsample1_1(x) + for blk in self.ConvBlock1_1: + x = blk(x) + x = self.downsample1_2(x) + for blk in self.ConvBlock1_2: + x = blk(x) + + x = 
self.downsample2(x) + for blk in self.ConvBlock2_1: + x = blk(x) + for blk in self.ConvBlock2_2: + x = blk(x) + + x = self.downsample3(x) + for blk in self.block3: + x = blk(x) + + x = self.downsample4(x) + for blk in self.block4: + x = blk(x) + + return x # T,B,C,N + + def forward(self, x): + x = self.forward_features(x) # B,C,H,W + x = x.flatten(2).mean(2) + x = self.spike(x) + x = self.head(x) + return x + +class Spiking_vit_MetaFormer_FullAttention(nn.Module): + def __init__( + self, + img_size_h=128, + img_size_w=128, + in_channels=1, + num_classes=11, + embed_dim=[64, 128, 256], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.2, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[6, 8, 6], + sr_ratios=[8, 4, 2], + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + + self.downsample1_1 = MS_DownSampling( + in_channels=in_channels, + embed_dims=embed_dim[0] // 2, + kernel_size=7, + stride=2, + padding=3, + first_layer=True, + + ) + + self.ConvBlock1_1 = nn.ModuleList( + [MS_Block_Spike_SepConv( + dim=embed_dim[0] // 2, + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + sr_ratio=sr_ratios, + )] + ) + + self.downsample1_2 = MS_DownSampling( + in_channels=embed_dim[0] // 2, + embed_dims=embed_dim[0], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock1_2 = nn.ModuleList( + [MS_Block_Spike_SepConv( + dim=embed_dim[0], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + sr_ratio=sr_ratios, + )] + ) + + self.downsample2 = MS_DownSampling( + in_channels=embed_dim[0], + embed_dims=embed_dim[1], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock2_1 = nn.ModuleList( + [MS_Block_Spike_SepConv( + dim=embed_dim[1], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + sr_ratio=sr_ratios, + )] + ) + + self.ConvBlock2_2 = nn.ModuleList( + [MS_Block_Spike_SepConv( + dim=embed_dim[1], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + norm_layer=norm_layer, + sr_ratio=sr_ratios, + )] + ) + + self.downsample3 = MS_DownSampling( + in_channels=embed_dim[1], + embed_dims=embed_dim[2], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.block3 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[2], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(6) + ] + ) + + self.downsample4 = MS_DownSampling( + in_channels=embed_dim[2], + embed_dims=embed_dim[3], + kernel_size=3, + stride=1, + padding=1, + first_layer=False, + + ) + + self.block4 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[3], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in 
range(2) + ] + ) + + self.head = ( + nn.Linear(embed_dim[3], num_classes) if num_classes > 0 else nn.Identity() + ) + self.spike = MultiSpike(Norm=1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + x = self.downsample1_1(x) + for blk in self.ConvBlock1_1: + x = blk(x) + x = self.downsample1_2(x) + for blk in self.ConvBlock1_2: + x = blk(x) + + x = self.downsample2(x) + for blk in self.ConvBlock2_1: + x = blk(x) + for blk in self.ConvBlock2_2: + x = blk(x) + + x = self.downsample3(x) + for blk in self.block3: + x = blk(x) + + x = self.downsample4(x) + for blk in self.block4: + x = blk(x) + + return x # T,B,C,N + + def forward(self, x): + x = self.forward_features(x) # B,C,H,W + x = x.flatten(2).mean(2) + x = self.spike(x) + x = self.head(x) + return x + +class Spiking_vit_MetaFormer_B3(nn.Module): + def __init__( + self, + img_size_h=128, + img_size_w=128, + in_channels=1, + num_classes=11, + embed_dim=[64, 128, 256], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[6, 8, 6], + sr_ratios=[8, 4, 2], + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + self.block3 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[2], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + ) + for j in range(6) + ] + ) + + self.spike = MultiSpike(Norm=1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + for blk in self.block3: + x = blk(x) + + return x # T,B,C,N + + def forward(self, x): + x = self.forward_features(x) # B,C,H,W + return x + +class Spiking_vit_MetaFormer_Spike_SepConv_splash(nn.Module): + def __init__( + self, + img_size_h=128, + img_size_w=128, + in_channels=1, + num_classes=11, + embed_dim=[64, 128, 256], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.2, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[6, 8, 6], + sr_ratios=[8, 4, 2], + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + + self.downsample1_1 = MS_DownSampling( + in_channels=in_channels, + embed_dims=embed_dim[0] // 2, + kernel_size=7, + stride=2, + padding=3, + first_layer=True, + + ) + + self.ConvBlock1_1 = nn.ModuleList( + [MS_ConvBlock_spike_splash( + dim=embed_dim[0] // 2, + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample1_2 = MS_DownSampling( + in_channels=embed_dim[0] // 2, + embed_dims=embed_dim[0], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, 
+ ) + + self.ConvBlock1_2 = nn.ModuleList( + [MS_ConvBlock_spike_splash( + dim=embed_dim[0], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample2 = MS_DownSampling( + in_channels=embed_dim[0], + embed_dims=embed_dim[1], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock2_1 = nn.ModuleList( + [MS_ConvBlock_spike_splash( + dim=embed_dim[1], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.ConvBlock2_2 = nn.ModuleList( + [MS_ConvBlock_spike_splash( + dim=embed_dim[1], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample3 = MS_DownSampling( + in_channels=embed_dim[1], + embed_dims=embed_dim[2], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.block3 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[2], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(6) + ] + ) + + self.downsample4 = MS_DownSampling( + in_channels=embed_dim[2], + embed_dims=embed_dim[3], + kernel_size=3, + stride=1, + padding=1, + first_layer=False, + + ) + + self.block4 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[3], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(2) + ] + ) + + self.head = ( + nn.Linear(embed_dim[3], num_classes) if num_classes > 0 else nn.Identity() + ) + self.spike = MultiSpike(Norm=1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + x = self.downsample1_1(x) + for blk in self.ConvBlock1_1: + x = blk(x) + x = self.downsample1_2(x) + for blk in self.ConvBlock1_2: + x = blk(x) + + x = self.downsample2(x) + for blk in self.ConvBlock2_1: + x = blk(x) + for blk in self.ConvBlock2_2: + x = blk(x) + + x = self.downsample3(x) + for blk in self.block3: + x = blk(x) + + x = self.downsample4(x) + for blk in self.block4: + x = blk(x) + + return x # T,B,C,N + + def forward(self, x): + x = self.forward_features(x) # B,C,H,W + x = x.flatten(2).mean(2) + x = self.spike(x) + x = self.head(x) + return x + +def sdtv3_m_splash(**kwargs): + #4.8M + model = Spiking_vit_MetaFormer_Spike_SepConv_splash( + img_size_h=224, + img_size_w=224, + embed_dim=[48, 96, 192, 240], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_s_splash(**kwargs): + #4.8M + model = Spiking_vit_MetaFormer_Spike_SepConv_splash( + img_size_h=224, + img_size_w=224, + embed_dim=[32, 64, 128, 192], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_l(**kwargs): + #19.0M + model = Spiking_vit_MetaFormer_Spike_SepConv( + img_size_h=224, + img_size_w=224, + embed_dim=[64, 128, 256, 360], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + 
norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_m(**kwargs): + #10.0M + model = Spiking_vit_MetaFormer_Spike_SepConv( + img_size_h=224, + img_size_w=224, + embed_dim=[48, 96, 192, 240], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + + +def sdtv3_s(**kwargs): + #5.1M + model = Spiking_vit_MetaFormer_Spike_SepConv( + img_size_h=224, + img_size_w=224, + embed_dim=[32, 64, 128, 192], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_s_channelmlp(**kwargs): + #5.1M + model = Spiking_vit_MetaFormer_Spike_SepConv_ChannelMLP( + img_size_h=224, + img_size_w=224, + embed_dim=[32, 64, 128, 192], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_s_fullattn(**kwargs): + #5.1M + model = Spiking_vit_MetaFormer_FullAttention( + img_size_h=224, + img_size_w=224, + embed_dim=[32, 64, 128, 192], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_t(**kwargs): + model = Spiking_vit_MetaFormer_Spike_SepConv( + img_size_h=224, + img_size_w=224, + embed_dim=[24, 48, 96, 128], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_s_attn(**kwargs): + model = Spiking_vit_MetaFormer_B3( + img_size_h=224, + img_size_w=224, + embed_dim=[32, 64, 128, 192], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + + +from timm.models import create_model +# os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +if __name__ == "__main__": + model = sdtv3_t() + print(model) + x = torch.randn(1,3,224,224) + y = model(x) + torchinfo.summary(model, (1, 3, 224, 224), device='cpu', verbose=1, col_names=["input_size", "output_size", "num_params", "mult_adds"]) \ No newline at end of file diff --git a/models/sdtv3_large.py b/models/sdtv3_large.py new file mode 100644 index 0000000000000000000000000000000000000000..e7cecde51d5f037bf65c39ce76aa24fc1b2e4c18 --- /dev/null +++ b/models/sdtv3_large.py @@ -0,0 +1,511 @@ +from functools import partial +import torch +import torch.nn as nn +import torch +import torch.nn as nn +from spikingjelly.clock_driven import layer +from timm.models.layers import to_2tuple, trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +from einops.layers.torch import Rearrange +import torch.nn.functional as F +from timm.models.vision_transformer import PatchEmbed, Block +from .util.pos_embed import get_2d_sincos_pos_embed + +import copy +from torchvision import transforms +import matplotlib.pyplot as plt +import torch.nn as nn +#timestep 1x4 +T=4 + +class multispike(torch.autograd.Function): + @staticmethod + def forward(ctx, input, lens=T): + ctx.save_for_backward(input) + ctx.lens = lens + return 
torch.floor(torch.clamp(input, 0, lens) + 0.5) + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + grad_input = grad_output.clone() + temp1 = 0 < input + temp2 = input < ctx.lens + return grad_input * temp1.float() * temp2.float(), None + + +class Multispike(nn.Module): + def __init__(self, spike=multispike,norm=T): + super().__init__() + self.lens = norm + self.spike = spike + self.norm=norm + + def forward(self, inputs): + return self.spike.apply(inputs)/self.norm + + +def MS_conv_unit(in_channels, out_channels,kernel_size=1,padding=0,groups=1): + return nn.Sequential( + layer.SeqToANNContainer( + nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, groups=groups,bias=True), + nn.BatchNorm2d(out_channels) + ) + ) +class MS_ConvBlock(nn.Module): + def __init__(self, dim, + mlp_ratio=4.0): + super().__init__() + + self.neuron1 = Multispike() + self.conv1 = MS_conv_unit(dim, dim * mlp_ratio, 3, 1) + + self.neuron2 = Multispike() + self.conv2 = MS_conv_unit(dim*mlp_ratio, dim, 3, 1) + + + def forward(self, x, mask=None): + short_cut = x + x = self.neuron1(x) + x = self.conv1(x) + x = self.neuron2(x) + x = self.conv2(x) + x = x +short_cut + return x + +class MS_MLP(nn.Module): + def __init__( + self, in_features, hidden_features=None, out_features=None, drop=0.0, layer=0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1) + self.fc1_bn = nn.BatchNorm1d(hidden_features) + self.fc1_lif = Multispike() + + + self.fc2_conv = nn.Conv1d( + hidden_features, out_features, kernel_size=1, stride=1 + ) + self.fc2_bn = nn.BatchNorm1d(out_features) + self.fc2_lif = Multispike() + + self.c_hidden = hidden_features + self.c_output = out_features + + def forward(self, x): + T, B, C, N= x.shape + + x = self.fc1_lif(x) + x = self.fc1_conv(x.flatten(0, 1)) + x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous() + + x = self.fc2_lif(x) + x = self.fc2_conv(x.flatten(0, 1)) + x = self.fc2_bn(x).reshape(T, B, C, N).contiguous() + + return x + +class RepConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + bias=False, + ): + super().__init__() + # TODO in_channel-> 2*in_channel->in_channel + self.conv1 = nn.Sequential(nn.Conv1d(in_channel, int(in_channel*1.5), kernel_size=1, stride=1,bias=False), nn.BatchNorm1d(int(in_channel*1.5))) + self.conv2 = nn.Sequential(nn.Conv1d(int(in_channel*1.5), out_channel, kernel_size=1, stride=1,bias=False), nn.BatchNorm1d(out_channel)) + def forward(self, x): + return self.conv2(self.conv1(x)) +class RepConv2(nn.Module): + def __init__( + self, + in_channel, + out_channel, + bias=False, + ): + super().__init__() + # TODO in_channel-> 2*in_channel->in_channel + self.conv1 = nn.Sequential(nn.Conv1d(in_channel, int(in_channel), kernel_size=1, stride=1,bias=False), nn.BatchNorm1d(int(in_channel))) + self.conv2 = nn.Sequential(nn.Conv1d(int(in_channel), out_channel, kernel_size=1, stride=1,bias=False), nn.BatchNorm1d(out_channel)) + def forward(self, x): + return self.conv2(self.conv1(x)) + +class MS_Attention_Conv_qkv_id(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
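+        # Conv-QKV spiking attention on token-flattened inputs of shape
+        # [T, B, C, N]. V (and the pre-projection output) is widened to
+        # sr_ratio * C channels, and a fixed scale of 0.125 stands in for the
+        # usual (C // num_heads) ** -0.5. The mixing order is k^T @ v followed
+        # by q @ (.), so it is linear in the token count N.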
+ self.dim = dim + self.num_heads = num_heads + self.scale = 0.125 + self.sr_ratio=sr_ratio + + self.head_lif = Multispike() + + # track 1: split convs + self.q_conv = nn.Sequential(RepConv(dim,dim), nn.BatchNorm1d(dim)) + self.k_conv = nn.Sequential(RepConv(dim,dim), nn.BatchNorm1d(dim)) + self.v_conv = nn.Sequential(RepConv(dim,dim*sr_ratio), nn.BatchNorm1d(dim*sr_ratio)) + + # track 2: merge (prefer) NOTE: need `chunk` in forward + # self.qkv_conv = nn.Sequential(RepConv(dim,dim * 3), nn.BatchNorm2d(dim * 3)) + + self.q_lif = Multispike() + + self.k_lif = Multispike() + + self.v_lif = Multispike() + + self.attn_lif = Multispike() + + self.proj_conv = nn.Sequential(RepConv(sr_ratio*dim,dim), nn.BatchNorm1d(dim)) + + def forward(self, x): + T, B, C, N = x.shape + + x = self.head_lif(x) + + x_for_qkv = x.flatten(0, 1) + q_conv_out = self.q_conv(x_for_qkv).reshape(T, B, C, N) + + q_conv_out = self.q_lif(q_conv_out) + + q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, + 4) + + k_conv_out = self.k_conv(x_for_qkv).reshape(T, B, C, N) + + k_conv_out = self.k_lif(k_conv_out) + + k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, + 4) + + v_conv_out = self.v_conv(x_for_qkv).reshape(T, B, self.sr_ratio*C, N) + + v_conv_out = self.v_lif(v_conv_out) + + v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, self.sr_ratio*C // self.num_heads).permute(0, 1, 3, 2, + 4) + + x = k.transpose(-2, -1) @ v + x = (q @ x) * self.scale + x = x.transpose(3, 4).reshape(T, B, self.sr_ratio*C, N) + x = self.attn_lif(x) + + x = self.proj_conv(x.flatten(0, 1)).reshape(T, B, C, N) + return x + + + + +class MS_Block(nn.Module): + def __init__( + self, + dim, + choice, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + sr_ratio=1,init_values=1e-6,finetune=False, + ): + super().__init__() + self.model=choice + if self.model=="base": + self.rep_conv=RepConv2(dim,dim) #if have param==83M + self.lif = Multispike() + self.attn = MS_Attention_Conv_qkv_id( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) + self.finetune = finetune + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) + + if self.finetune: + self.layer_scale1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + self.layer_scale2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) + + def forward(self, x): + T, B, C, N = x.shape + if self.model=="base": + x= x + self.rep_conv(self.lif(x).flatten(0, 1)).reshape(T, B, C, N) + # TODO: need channel-wise layer scale, init as 1e-6 + if self.finetune: + x = x + self.drop_path(self.attn(x) * self.layer_scale1.unsqueeze(0).unsqueeze(0).unsqueeze(-1)) + x = x + self.drop_path(self.mlp(x) * self.layer_scale2.unsqueeze(0).unsqueeze(0).unsqueeze(-1)) + else: + x = x + self.attn(x) + x = x + self.mlp(x) + return x + + +class MS_DownSampling(nn.Module): + def __init__( + self, + in_channels=2, + embed_dims=256, + kernel_size=3, + stride=2, + padding=1, + first_layer=True, + ): + super().__init__() + + self.encode_conv = nn.Conv2d( + in_channels, + embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + self.encode_bn 
=nn.BatchNorm2d(embed_dims) + if not first_layer: + self.encode_lif = Multispike() + + def forward(self, x): + T, B, _, _, _ = x.shape + if hasattr(self, "encode_lif"): + x = self.encode_lif(x) + x = self.encode_conv(x.flatten(0, 1)) + _, _, H, W = x.shape + x = self.encode_bn(x).reshape(T, B, -1, H, W).contiguous() + return x + + +class Spikformer(nn.Module): + def __init__(self, T=1, + choice=None, + img_size_h=224, + img_size_w=224, + patch_size=16, + embed_dim=[128, 256, 512, 640], + num_heads=8, + mlp_ratios=4, + in_channels=3, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), #norm_layer=nn.LayerNorm shaokun + depths=8, + sr_ratios=1, + mlp_ratio=4., + nb_classes=1000, + kd=True): + super().__init__() + + ### MAE encoder spikformer + self.T = T + self.patch_size = patch_size + self.embed_dim =embed_dim + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + self.downsample1_1 = MS_DownSampling( + in_channels=in_channels, + embed_dims=embed_dim[0] // 2, + kernel_size=7, + stride=2, + padding=3, + first_layer=True, + ) + + self.ConvBlock1_1 = nn.ModuleList( + [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)] + ) + + self.downsample1_2 = MS_DownSampling( + in_channels=embed_dim[0] // 2, + embed_dims=embed_dim[0], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock1_2 = nn.ModuleList( + [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)] + ) + + self.downsample2 = MS_DownSampling( + in_channels=embed_dim[0], + embed_dims=embed_dim[1], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock2_1 = nn.ModuleList( + [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)] + ) + + self.ConvBlock2_2 = nn.ModuleList( + [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)] + ) + + self.downsample3 = MS_DownSampling( + in_channels=embed_dim[1], + embed_dims=embed_dim[2], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.block3 = nn.ModuleList( + [ + MS_Block( + dim=embed_dim[2], + choice=choice, + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + finetune=True, + ) + for j in range(depths) + ] + ) + self.head = nn.Linear(embed_dim[2], nb_classes) + self.lif = Multispike(norm=1) + self.kd = kd + if self.kd: + self.head_kd = ( + nn.Linear(embed_dim[-1], num_classes) + if num_classes > 0 + else nn.Identity() + ) + self.initialize_weights() + + def initialize_weights(self): + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + + def forward_encoder(self, x ): + x = (x.unsqueeze(0)).repeat(self.T, 1, 1, 1, 1) + + + x = self.downsample1_1(x) + + for blk in self.ConvBlock1_1: + x = blk(x) + + x = self.downsample1_2(x) + for blk in self.ConvBlock1_2: + x = blk(x) + + x = self.downsample2(x) + + for blk in self.ConvBlock2_1: + x = blk(x) + + for blk in self.ConvBlock2_2: + x = blk(x) + x = self.downsample3(x) + x = x.flatten(3) # T,B,C,N + + for blk in self.block3: + x = 
blk(x) + + return x + + + def forward(self, imgs): + x = self.forward_encoder(imgs) + + x = x.flatten(3).mean(3) + x_lif = self.lif(x) + x = self.head(x).mean(0) + + if self.kd: + x_kd = self.head_kd(x_lif).mean(0) + if self.training: + return x, x_kd + else: + return (x + x_kd) / 2 + return x + + + +def spikformer12_512(**kwargs): + model = Spikformer( + T=1, + choice="base", + img_size_h=32, + img_size_w=32, + patch_size=16, + embed_dim=[128,256,512], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=12, + **kwargs) + return model +def spikformer12_768(**kwargs): + model = Spikformer( + T=1, + choice="large", + img_size_h=32, + img_size_w=32, + patch_size=16, + embed_dim=[196, 384, 768], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=12, + **kwargs) + return model + + + + + +if __name__ == "__main__": + # from encoder import SparseEncoder,nn.Conv2d + import torchinfo + + model = spikformer12_512() + + + print(f"number of params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}") + diff --git a/models/spikformer.py b/models/spikformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4cc1b6846f0188c485dad2b79831b7d62c27bb --- /dev/null +++ b/models/spikformer.py @@ -0,0 +1,374 @@ +# from visualizer import get_local +import torch +import torch.nn as nn +from .neuron import MultiStepParametricLIFNode, MultiStepLIFNode +from spikingjelly.clock_driven import layer +from timm.models.layers import to_2tuple, trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +from einops.layers.torch import Rearrange +import torch.nn.functional as F +from functools import partial + +__all__ = ['vit_snn',] + + +def compute_non_zero_rate(x): + x_shape = torch.tensor(list(x.shape)) + all_neural = torch.prod(x_shape) + z = torch.nonzero(x) + print("After attention proj the none zero rate is", z.shape[0]/all_neural) + + +class TemporalChangeScorer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, prev_drop_mask=None): # x: [T,B,C,H,W] + T,B,C,H,W = x.shape + x_mean = torch.mean(x, dim=2) # [T,B,H,W] + temporal_diff = x_mean[1:] - x_mean[:-1] # [T-1,B,H,W] + avg_temporal_change = torch.abs(temporal_diff).mean(dim=0) # [B,H,W] + scores = avg_temporal_change.view(avg_temporal_change.shape[0], -1) # [B,H*W] + + if prev_drop_mask is not None: + scores = scores.masked_fill(prev_drop_mask, float('-inf')) + + scores = F.softmax(scores, dim=1) + + return scores.reshape(B,H,W) # [B,H,W] + +class LocalSpatialSimilarity(nn.Module): + def __init__(self,embedding_dim=None): + super().__init__() + self.cosine = nn.CosineSimilarity(dim=1, eps=1e-6) + # self.local_weight = nn.Parameter(torch.tensor(0.5), requires_grad=True) + # self.global_weight = nn.Parameter(torch.tensor(0.5), requires_grad=True) + + # self.conv = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, stride=1, padding=1) + + def forward(self, x, prev_drop_mask=None): # x: [T,B,C,H,W] + T,B,C,H,W = x.shape + + x = torch.mean(x, dim=0) # [B,C,H,W] + avg_kernel = torch.ones(C, C, 3, 3).to(x.device)/9.0 + local_mean = F.conv2d(x, avg_kernel, padding=1) # [B,C,H,W] + # local_mean = self.conv(x) # [B,C,H,W] + + x_flat = x.view(B, C, -1) # [B,C,H*W] + local_mean_flat = local_mean.view(B, C, -1) # [B,C,H*W] + + sim = self.cosine(x_flat, 
local_mean_flat) # [B,H*W] + scores = -sim + + if prev_drop_mask is not None: + scores = scores.masked_fill(prev_drop_mask, float('-inf')) + scores = F.softmax(scores, dim=1) + + return scores.reshape(B,H,W) # [B,H,W] + +class MLP(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + # self.fc1 = linear_unit(in_features, hidden_features) + self.fc1_conv = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1) + self.fc1_bn = nn.BatchNorm2d(hidden_features) + self.fc1_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + # self.fc2 = linear_unit(hidden_features, out_features) + self.fc2_conv = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1) + self.fc2_bn = nn.BatchNorm2d(out_features) + self.fc2_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + # self.drop = nn.Dropout(0.1) + + self.c_hidden = hidden_features + self.c_output = out_features + def forward(self, x): + T,B,C,H,W = x.shape + x = self.fc1_conv(x.flatten(0,1)) + x = self.fc1_bn(x).reshape(T,B,self.c_hidden,H,W).contiguous() + x = self.fc1_lif(x) + + x = self.fc2_conv(x.flatten(0,1)) + x = self.fc2_bn(x).reshape(T,B,C,H,W).contiguous() + x = self.fc2_lif(x) + return x + + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = 0.125 + self.q_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1,bias=False) + self.q_bn = nn.BatchNorm1d(dim) + self.q_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.k_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1,bias=False) + self.k_bn = nn.BatchNorm1d(dim) + self.k_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.v_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1,bias=False) + self.v_bn = nn.BatchNorm1d(dim) + self.v_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + self.attn_lif = MultiStepLIFNode(tau=2.0, v_threshold=0.5, detach_reset=True, backend='cupy') + + self.proj_conv = nn.Conv1d(dim, dim, kernel_size=1, stride=1) + self.proj_bn = nn.BatchNorm1d(dim) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + self.qkv_mp = nn.MaxPool1d(4) + + def forward(self, x): + T,B,C,H,W = x.shape + x = x.flatten(3) + T, B, C, N = x.shape + x_for_qkv = x.flatten(0, 1) + x_feat = x + q_conv_out = self.q_conv(x_for_qkv) + q_conv_out = self.q_bn(q_conv_out).reshape(T,B,C,N).contiguous() + q_conv_out = self.q_lif(q_conv_out) + q = q_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous() + + k_conv_out = self.k_conv(x_for_qkv) + k_conv_out = self.k_bn(k_conv_out).reshape(T,B,C,N).contiguous() + k_conv_out = self.k_lif(k_conv_out) + k = k_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous() + + + v_conv_out = self.v_conv(x_for_qkv) + v_conv_out = self.v_bn(v_conv_out).reshape(T,B,C,N).contiguous() + v_conv_out = self.v_lif(v_conv_out) + v = v_conv_out.transpose(-1, -2).reshape(T, B, N, self.num_heads, C//self.num_heads).permute(0, 1, 3, 2, 4).contiguous() + +# if res_attn != None: +# v = 
v + res_attn + + x = k.transpose(-2, -1) @ v + x = (q @ x) * self.scale + + x = x.transpose(3, 4).reshape(T, B, C, N).contiguous() + x = self.attn_lif(x) + x = x.flatten(0,1) + x = self.proj_lif(self.proj_bn(self.proj_conv(x)).reshape(T,B,C,H,W)) + + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, sr_ratio=1, k_value=1.): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) + + self.temporal_scorer = TemporalChangeScorer() + self.spatial_scorer = LocalSpatialSimilarity() + self.k_value = k_value + + def forward(self, x): + T,B,C,H,W = x.shape + temporal_score = self.temporal_scorer(x) # [B,H,W] + spatial_score = self.spatial_scorer(x) # [B,H,W] + final_score = temporal_score+spatial_score + + self.k = int(H*self.k_value) + + flat_scores = final_score.view(B, -1) # [B, H*W] + _, indices = torch.topk(flat_scores, k=self.k*self.k, dim=1) # [B, kk] + # indices = torch.tensor([[_ for _ in range(indices.shape[1])] for __ in range(indices.shape[0])]).cuda() + token_indices = indices.unsqueeze(0).expand(T,-1,-1) # [T,B,kk] + + x = x.flatten(3) # [T,B,C,N] where N = H*W + original_x = x.clone() + + # slow_path + informative_tokens = x.gather(3, token_indices.unsqueeze(2).expand(-1, -1, C, -1)) # [T,B,C,kk] + slow_x = informative_tokens.reshape(T, B, C, self.k, self.k) + slow_x = slow_x + self.attn(slow_x) + slow_x = slow_x + self.mlp(slow_x) + + x = original_x.scatter_(3, token_indices.unsqueeze(2).expand(-1, -1, C, -1), slow_x.reshape(T,B,C,self.k*self.k)) + x = x.reshape(T,B,C,H,W) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, img_size_h=128, img_size_w=128, patch_size=4, in_channels=2, embed_dims=256): + super().__init__() + self.image_size = [img_size_h, img_size_w] + patch_size = to_2tuple(patch_size) + self.patch_size = patch_size + self.C = in_channels + self.H, self.W = self.image_size[0] // patch_size[0], self.image_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj_conv = nn.Conv2d(in_channels, embed_dims//8, kernel_size=3, stride=1, padding=1, bias=False) + self.proj_bn = nn.BatchNorm2d(embed_dims//8) + self.proj_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + + self.proj_conv1 = nn.Conv2d(embed_dims//8, embed_dims//4, kernel_size=3, stride=1, padding=1, bias=False) + self.proj_bn1 = nn.BatchNorm2d(embed_dims//4) + self.proj_lif1 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + + self.proj_conv2 = nn.Conv2d(embed_dims//4, embed_dims//2, kernel_size=3, stride=1, padding=1, bias=False) + self.proj_bn2 = nn.BatchNorm2d(embed_dims//2) + self.proj_lif2 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + self.maxpool2 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + + self.proj_conv3 = nn.Conv2d(embed_dims//2, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) 
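+        # The four proj_conv stages, each followed by a stride-2 maxpool, reduce the spatial
+        # resolution by 16x in total, so the forward pass yields (H/16) x (W/16) tokens; this
+        # matches self.num_patches only when patch_size=16 (as used by vit_snn below).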
+ self.proj_bn3 = nn.BatchNorm2d(embed_dims) + self.proj_lif3 = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + self.maxpool3 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) + + self.rpe_conv = nn.Conv2d(embed_dims, embed_dims, kernel_size=3, stride=1, padding=1, bias=False) + self.rpe_bn = nn.BatchNorm2d(embed_dims) + self.rpe_lif = MultiStepLIFNode(tau=2.0, detach_reset=True, backend='cupy') + + def forward(self, x): + T, B, C, H, W = x.shape + x = self.proj_conv(x.flatten(0, 1)) # have some fire value + x = self.proj_bn(x).reshape(T, B, -1, H, W).contiguous() + x = self.proj_lif(x).flatten(0,1).contiguous() + x = self.maxpool(x) + + x = self.proj_conv1(x) + x = self.proj_bn1(x).reshape(T, B, -1, H//2, W//2).contiguous() + x = self.proj_lif1(x).flatten(0, 1).contiguous() + x = self.maxpool1(x) + + x = self.proj_conv2(x) + x = self.proj_bn2(x).reshape(T, B, -1, H//4, W//4).contiguous() + x = self.proj_lif2(x).flatten(0, 1).contiguous() + x = self.maxpool2(x) + + x = self.proj_conv3(x) + x = self.proj_bn3(x).reshape(T, B, -1, H//8, W//8).contiguous() + x = self.proj_lif3(x).flatten(0, 1).contiguous() + x = self.maxpool3(x) + + x_feat = x.reshape(T, B, -1, H//16, W//16).contiguous() + x = self.rpe_conv(x) + x = self.rpe_bn(x).reshape(T, B, -1, H//16, W//16).contiguous() + x = self.rpe_lif(x) + x = x + x_feat + + H, W = H // self.patch_size[0], W // self.patch_size[1] + return x, (H, W) + +class Spiking_vit(nn.Module): + def __init__(self, + img_size_h=128, img_size_w=128, patch_size=16, in_channels=2, num_classes=11, + embed_dims=[64, 128, 256], num_heads=[1, 2, 4], mlp_ratios=[4, 4, 4], qkv_bias=False, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, + depths=[6, 8, 6], sr_ratios=[8, 4, 2], k_values = [1,1,1,1,1,1,1,1] #[0.8,0.8,0.7,0.7,0.7,0.7,0.6,0.6] + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.k_values = k_values + print("k_values", k_values) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)] # stochastic depth decay rule + + patch_embed = PatchEmbed(img_size_h=img_size_h, + img_size_w=img_size_w, + patch_size=patch_size, + in_channels=in_channels, + embed_dims=embed_dims) + num_patches = patch_embed.num_patches + + block = nn.ModuleList([Block( + dim=embed_dims, num_heads=num_heads, mlp_ratio=mlp_ratios, qkv_bias=qkv_bias, + qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[j], + norm_layer=norm_layer, sr_ratio=sr_ratios, k_value=k_values[j]) + for j in range(depths)]) + + setattr(self, f"patch_embed", patch_embed) + setattr(self, f"block", block) + + # classification head 这里不需要脉冲,因为输入的是在T时长平均发射值 + self.head = nn.Linear(embed_dims, num_classes) if num_classes > 0 else nn.Identity() + self.apply(self._init_weights) + + @torch.jit.ignore + def _get_pos_embed(self, pos_embed, patch_embed, H, W): + if H * W == self.patch_embed1.num_patches: + return pos_embed + else: + return F.interpolate( + pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), + size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + + block = getattr(self, f"block") + patch_embed = 
getattr(self, f"patch_embed") + + x, (H, W) = patch_embed(x) + for blk in block: + x = blk(x) + return x.flatten(3).mean(3) + + def forward(self, x): + T = 4 + x = (x.unsqueeze(0)).repeat(T, 1, 1, 1, 1) + x = self.forward_features(x) + x = self.head(x.mean(0)) + return x + + +@register_model +def vit_snn(pretrained=False, **kwargs): + model = Spiking_vit( + img_size_h=224, img_size_w=224, + patch_size=16, embed_dims=512, num_heads=8, mlp_ratios=4, + in_channels=3, num_classes=1000, qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=8, sr_ratios=1, + **kwargs + ) + model.default_cfg = _cfg() + return model + + +from timm.models import create_model + +if __name__ == '__main__': + H = 128 + W = 128 + x = torch.randn(2, 3, 224, 224).cuda() + # new_patch = PatchEmbed() + # new_patch.cuda() + # y, _ = new_patch(x) + model = create_model( + 'vit_snn', + pretrained=False, + drop_rate=0, + drop_path_rate=0.1, + drop_block_rate=None, + ).cuda() + model.eval() + y = model(x) + print(y.shape) + print('Test Good!') \ No newline at end of file diff --git a/models/sterf_models.py b/models/sterf_models.py new file mode 100644 index 0000000000000000000000000000000000000000..31c0c14944c136643c89ea560ffd3ec3c34eacdd --- /dev/null +++ b/models/sterf_models.py @@ -0,0 +1,1461 @@ +import torch +import torchinfo +import torch.nn as nn +from timm.models.layers import to_2tuple, trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.models.vision_transformer import _cfg +from einops.layers.torch import Rearrange +import torch.nn.functional as F +from functools import partial + +import os + +class Quant(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, i, min_value, max_value): + ctx.min = min_value + ctx.max = max_value + ctx.save_for_backward(i) + return torch.round(torch.clamp(i, min=min_value, max=max_value)) + + @staticmethod + @torch.cuda.amp.custom_fwd + def backward(ctx, grad_output): + grad_input = grad_output.clone() + i, = ctx.saved_tensors + grad_input[i < ctx.min] = 0 + grad_input[i > ctx.max] = 0 + return grad_input, None, None + +class MultiSpike(nn.Module): + def __init__( + self, + min_value=0, + max_value=8, + Norm=None, + ): + super().__init__() + if Norm == None: + self.Norm = max_value + else: + self.Norm = Norm + self.min_value = min_value + self.max_value = max_value + + @staticmethod + def spike_function(x, min_value, max_value): + return Quant.apply(x, min_value, max_value) + + def __repr__(self): + return f"MultiSpike(Max_Value={self.max_value}, Min_Value={self.min_value}, Norm={self.Norm})" + + def forward(self, x): # B C H W + return self.spike_function(x, min_value=self.min_value, max_value=self.max_value) / (self.Norm) + +class BNAndPadLayer(nn.Module): + def __init__( + self, + pad_pixels, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + ): + super(BNAndPadLayer, self).__init__() + self.bn = nn.BatchNorm2d( + num_features, eps, momentum, affine, track_running_stats + ) + self.pad_pixels = pad_pixels + + def forward(self, input): + output = self.bn(input) + if self.pad_pixels > 0: + if self.bn.affine: + pad_values = ( + self.bn.bias.detach() + - self.bn.running_mean + * self.bn.weight.detach() + / torch.sqrt(self.bn.running_var + self.bn.eps) + ) + else: + pad_values = -self.bn.running_mean / torch.sqrt( + self.bn.running_var + self.bn.eps + ) + output = F.pad(output, [self.pad_pixels] * 4) + pad_values = pad_values.view(1, -1, 1, 1) + output[:, :, 0 : 
self.pad_pixels, :] = pad_values + output[:, :, -self.pad_pixels :, :] = pad_values + output[:, :, :, 0 : self.pad_pixels] = pad_values + output[:, :, :, -self.pad_pixels :] = pad_values + return output + + @property + def weight(self): + return self.bn.weight + + @property + def bias(self): + return self.bn.bias + + @property + def running_mean(self): + return self.bn.running_mean + + @property + def running_var(self): + return self.bn.running_var + + @property + def eps(self): + return self.bn.eps + +class RepConv(nn.Module): + def __init__( + self, + in_channel, + out_channel, + bias=False, + ): + super().__init__() + # hidden_channel = in_channel + conv1x1 = nn.Conv2d(in_channel, in_channel, 1, 1, 0, bias=False, groups=1) + bn = BNAndPadLayer(pad_pixels=1, num_features=in_channel) + conv3x3 = nn.Sequential( + nn.Conv2d(in_channel, in_channel, 3, 1, 0, groups=in_channel, bias=False), + nn.Conv2d(in_channel, out_channel, 1, 1, 0, groups=1, bias=False), + nn.BatchNorm2d(out_channel), + ) + + self.body = nn.Sequential(conv1x1, bn, conv3x3) + + def forward(self, x): + return self.body(x) + +class SepConv(nn.Module): + r""" + Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. + """ + + def __init__( + self, + dim, + expansion_ratio=2, + act2_layer=nn.Identity, + bias=False, + kernel_size=7, + padding=3, + ): + super().__init__() + med_channels = int(expansion_ratio * dim) + self.spike1 = MultiSpike() + self.pwconv1 = nn.Conv2d(dim, med_channels, kernel_size=1, stride=1, bias=bias) + self.bn1 = nn.BatchNorm2d(med_channels) + self.spike2 = MultiSpike() + self.dwconv = nn.Conv2d( + med_channels, + med_channels, + kernel_size=kernel_size, + padding=padding, + groups=med_channels, + bias=bias, + ) # depthwise conv + self.pwconv2 = nn.Conv2d(med_channels, dim, kernel_size=1, stride=1, bias=bias) + self.bn2 = nn.BatchNorm2d(dim) + + def forward(self, x): + + x = self.spike1(x) + + x = self.bn1(self.pwconv1(x)) + + x = self.spike2(x) + + x = self.dwconv(x) + x = self.bn2(self.pwconv2(x)) + return x + +class SepConv_Spike(nn.Module): + r""" + Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. 
+ """ + + def __init__( + self, + dim, + expansion_ratio=2, + act2_layer=nn.Identity, + bias=False, + kernel_size=7, + padding=3, + ): + super().__init__() + med_channels = int(expansion_ratio * dim) + self.spike1 = MultiSpike() + self.pwconv1 = nn.Sequential( + nn.Conv2d(dim, med_channels, kernel_size=1, stride=1, bias=bias), + nn.BatchNorm2d(med_channels) + ) + self.spike2 = MultiSpike() + self.dwconv = nn.Sequential( + nn.Conv2d(med_channels, med_channels, kernel_size=kernel_size, padding=padding, groups=med_channels, bias=bias), + nn.BatchNorm2d(med_channels) + ) + self.spike3 = MultiSpike() + self.pwconv2 = nn.Sequential( + nn.Conv2d(med_channels, dim, kernel_size=1, stride=1, bias=bias), + nn.BatchNorm2d(dim) + ) + + def forward(self, x): + + x = self.spike1(x) + + x = self.pwconv1(x) + + x = self.spike2(x) + + x = self.dwconv(x) + + x = self.spike3(x) + + x = self.pwconv2(x) + return x + + + +class MS_ConvBlock(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4.0, + ): + super().__init__() + + self.Conv = SepConv(dim=dim) + + self.mlp_ratio = mlp_ratio + + self.spike1 = MultiSpike() + self.conv1 = nn.Conv2d( + dim, dim * mlp_ratio, kernel_size=3, padding=1, groups=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(dim * mlp_ratio) # 这里可以进行改进 + self.spike2 = MultiSpike() + self.conv2 = nn.Conv2d( + dim * mlp_ratio, dim, kernel_size=3, padding=1, groups=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(dim) # 这里可以进行改进 + + def forward(self, x): + B, C, H, W = x.shape + + x = self.Conv(x) + x + x_feat = x + x = self.spike1(x) + x = self.bn1(self.conv1(x)).reshape(B, self.mlp_ratio * C, H, W) + x = self.spike2(x) + x = self.bn2(self.conv2(x)).reshape(B, C, H, W) + x = x_feat + x + + return x + +class MS_ConvBlock_spike_SepConv(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4.0, + ): + super().__init__() + + self.Conv = SepConv_Spike(dim=dim) + + self.mlp_ratio = mlp_ratio + + self.spike1 = MultiSpike() + self.conv1 = nn.Conv2d( + dim, dim * mlp_ratio, kernel_size=3, padding=1, groups=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(dim * mlp_ratio) + self.spike2 = MultiSpike() + self.conv2 = nn.Conv2d( + dim * mlp_ratio, dim, kernel_size=3, padding=1, groups=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(dim) + + def forward(self, x): + B, C, H, W = x.shape + + x = self.Conv(x) + x + x_feat = x + x = self.spike1(x) + x = self.bn1(self.conv1(x)).reshape(B, self.mlp_ratio * C, H, W) + x = self.spike2(x) + x = self.bn2(self.conv2(x)).reshape(B, C, H, W) + x = x_feat + x + + return x + +class MS_ConvBlock_spike_MLP(nn.Module): + def __init__( + self, + dim, + mlp_ratio=8.0, + drop=0., + ): + super().__init__() + drop_probs = to_2tuple(drop) + + self.Conv = SepConv_Spike(dim=dim) + + self.mlp_ratio = mlp_ratio + + self.spike1 = MultiSpike() + self.fc1 = nn.Linear( + dim, dim * mlp_ratio, bias=False + ) + self.drop1 = nn.Dropout(drop_probs[0]) + self.spike2 = MultiSpike() + self.fc2 = nn.Linear( + dim * mlp_ratio, dim, bias=False + ) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + B, C, H, W = x.shape + + x = self.Conv(x) + x + x_feat = x + x = self.spike1(x) + x = self.drop1(self.fc1(x.reshape(B, H * W, C))).reshape(B, self.mlp_ratio * C, H, W) + x = self.spike2(x) + x = self.drop2(self.fc2(x.reshape(B, H * W, self.mlp_ratio * C))).reshape(B, C, H, W) + x = x_feat + x + + return x + +class MS_ConvBlock_spike_splash(nn.Module): + def __init__( + self, + dim, + mlp_ratio=8.0, + drop=0., + ): + super().__init__() + drop_probs = to_2tuple(drop) + + self.Conv = 
SepConv_Spike(dim=dim) + + self.mlp_ratio = mlp_ratio + + self.spike1 = MultiSpike() + self.fc1 = nn.Linear( + dim, dim * mlp_ratio, bias=False + ) + self.drop1 = nn.Dropout(drop_probs[0]) + self.spike2 = MultiSpike() + self.conv2 = nn.Conv2d( + dim * mlp_ratio, dim, kernel_size=3, padding=1, groups=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(dim) + + def forward(self, x): + B, C, H, W = x.shape + + x = self.Conv(x) + x + x_feat = x + x = self.spike1(x) + x = self.drop1(self.fc1(x.reshape(B, H * W, C))).reshape(B, self.mlp_ratio * C, H, W) + x = self.spike2(x) + x = self.bn2(self.conv2(x)).reshape(B, C, H, W) + x = x_feat + x + + return x + +class MS_MLP(nn.Module): + def __init__( + self, in_features, hidden_features=None, out_features=None, drop=0.0, layer=0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1) + self.fc1_bn = nn.BatchNorm1d(hidden_features) + self.fc1_spike = MultiSpike() + + self.fc2_conv = nn.Conv1d( + hidden_features, out_features, kernel_size=1, stride=1 + ) + self.fc2_bn = nn.BatchNorm1d(out_features) + self.fc2_spike = MultiSpike() + + self.c_hidden = hidden_features + self.c_output = out_features + + def forward(self, x): + B, C, H, W = x.shape + N = H * W + x = x.flatten(2) + x = self.fc1_spike(x) + x = self.fc1_conv(x) + x = self.fc1_bn(x).reshape(B, self.c_hidden, N).contiguous() + x = self.fc2_spike(x) + x = self.fc2_conv(x) + x = self.fc2_bn(x).reshape(B, C, H, W).contiguous() + + return x + + + +class MS_Attention_RepConv_qkv_id(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + sr_ratio=1, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
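+        # The forward pass below uses the "linear" ordering of spike-driven self-attention:
+        # it first forms k^T @ v (a d x d matrix per head, d = dim // num_heads) and only then
+        # multiplies by q, i.e. out = (q @ (k^T v)) * scale. This avoids materialising the
+        # N x N attention map, so the cost is O(N * d^2) rather than O(N^2 * d):
+        #   q, k, v : (B, heads, N, d)  ->  k^T v : (B, heads, d, d)  ->  q (k^T v) : (B, heads, N, d)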
+ self.dim = dim + self.num_heads = num_heads + self.scale = (dim//num_heads) ** -0.5 + + self.head_spike = MultiSpike() + + self.q_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim)) + + self.k_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim)) + + self.v_conv = nn.Sequential(RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim)) + + self.q_spike = MultiSpike() + + self.k_spike = MultiSpike() + + self.v_spike = MultiSpike() + + self.attn_spike = MultiSpike() + + self.proj_conv = nn.Sequential( + RepConv(dim, dim, bias=False), nn.BatchNorm2d(dim) + ) + + # self.proj_conv = nn.Sequential( + # nn.Conv2d(dim, dim, 1, 1, bias=False), nn.BatchNorm2d(dim) + # ) + + + def forward(self, x): + B, C, H, W = x.shape + N = H * W + + x = self.head_spike(x) + + q = self.q_conv(x) + k = self.k_conv(x) + v = self.v_conv(x) + + q = self.q_spike(q) + q = q.flatten(2) + q = ( + q.transpose(-1, -2) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + k = self.k_spike(k) + k = k.flatten(2) + k = ( + k.transpose(-1, -2) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + v = self.v_spike(v) + v = v.flatten(2) + v = ( + v.transpose(-1, -2) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + x = k.transpose(-2, -1) @ v + x = (q @ x) * self.scale + + x = x.transpose(2, 3).reshape(B, C, N).contiguous() + x = self.attn_spike(x) + x = x.reshape(B, C, H, W) + x = self.proj_conv(x).reshape(B, C, H, W) + + return x + +class MS_Attention_linear(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + sr_ratio=1, + lamda_ratio=1, + ): + super().__init__() + assert ( + dim % num_heads == 0 + ), f"dim {dim} should be divided by num_heads {num_heads}." 
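+        # Variant of the attention above in which only the value path is widened: v (and the
+        # attention output) carries lamda_ratio * dim channels, and proj_conv maps it back to dim.
+        # Note that the forward pass here computes (q @ k^T) @ v, scaled by 2 * head_dim^-0.5
+        # (self.scale * 2), with spike-quantized q, k and v:
+        #   q, k : (B, heads, N, d)   v : (B, heads, N, lamda_ratio * d)
+        #   q @ k^T : (B, heads, N, N)   output : (B, lamda_ratio * C, H, W) -> proj -> (B, C, H, W)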
+ self.dim = dim + self.num_heads = num_heads + self.scale = (dim//num_heads) ** -0.5 + self.lamda_ratio = lamda_ratio + + self.head_spike = MultiSpike() + + self.q_conv = nn.Sequential(nn.Conv2d(dim, dim, 1, 1, bias=False), nn.BatchNorm2d(dim)) + + self.q_spike = MultiSpike() + + self.k_conv = nn.Sequential(nn.Conv2d(dim, dim, 1, 1, bias=False), nn.BatchNorm2d(dim)) + + self.k_spike = MultiSpike() + + self.v_conv = nn.Sequential(nn.Conv2d(dim, int(dim*lamda_ratio), 1, 1, bias=False), nn.BatchNorm2d(int(dim*lamda_ratio))) + + self.v_spike = MultiSpike() + + self.attn_spike = MultiSpike() + + + self.proj_conv = nn.Sequential( + nn.Conv2d(dim*lamda_ratio, dim, 1, 1, bias=False), nn.BatchNorm2d(dim) + ) + + + def forward(self, x): + B, C, H, W = x.shape + N = H * W + C_v = int(C*self.lamda_ratio) + + x = self.head_spike(x) + + q = self.q_conv(x) + k = self.k_conv(x) + v = self.v_conv(x) + + q = self.q_spike(q) + q = q.flatten(2) + q = ( + q.transpose(-1, -2) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + k = self.k_spike(k) + k = k.flatten(2) + k = ( + k.transpose(-1, -2) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + v = self.v_spike(v) + v = v.flatten(2) + v = ( + v.transpose(-1, -2) + .reshape(B, N, self.num_heads, C_v // self.num_heads) + .permute(0, 2, 1, 3) + .contiguous() + ) + + x = q @ k.transpose(-2, -1) + x = (x @ v) * (self.scale*2) + + x = x.transpose(2, 3).reshape(B, C_v, N).contiguous() + x = self.attn_spike(x) + x = x.reshape(B, C_v, H, W) + x = self.proj_conv(x).reshape(B, C, H, W) + + return x + + + + +class MS_Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + sr_ratio=1, + ): + super().__init__() + + self.attn = MS_Attention_RepConv_qkv_id( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) + + def forward(self, x): + x = x + self.attn(x) + x = x + self.mlp(x) + + return x + +class MS_Block_Spike_SepConv(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + sr_ratio=1, + init_values = 1e-6 + ): + super().__init__() + + self.conv = SepConv_Spike(dim=dim, kernel_size=3, padding=1) + + self.attn = MS_Attention_linear( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + sr_ratio=sr_ratio, + lamda_ratio=4, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop) + + + def forward(self, x): + x = x + self.conv(x) + x = x + self.attn(x) + x = x + self.mlp(x) + + return x + + +class MS_DownSampling(nn.Module): + def __init__( + self, + in_channels=2, + embed_dims=256, + kernel_size=3, + stride=2, + padding=1, + first_layer=True, + T=None, + ): + super().__init__() + + self.encode_conv = nn.Conv2d( + in_channels, + embed_dims, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + self.encode_bn = 
nn.BatchNorm2d(embed_dims) + self.first_layer = first_layer + if not first_layer: + self.encode_spike = MultiSpike() + + def forward(self, x): + + if hasattr(self, "encode_spike"): + x = self.encode_spike(x) + x = self.encode_conv(x) + x = self.encode_bn(x) + + return x + + + +class Spiking_vit_MetaFormer_Spike_SepConv(nn.Module): + def __init__( + self, + img_size_h=128, + img_size_w=128, + patch_size=16, + in_channels=2, + num_classes=11, + embed_dim=[64, 128, 256], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[6, 8, 6], + sr_ratios=[8, 4, 2], + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + + self.downsample1_1 = MS_DownSampling( + in_channels=in_channels, + embed_dims=embed_dim[0] // 2, + kernel_size=7, + stride=2, + padding=3, + first_layer=True, + + ) + + self.ConvBlock1_1 = nn.ModuleList( + [MS_ConvBlock_spike_SepConv(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)] + ) + + self.downsample1_2 = MS_DownSampling( + in_channels=embed_dim[0] // 2, + embed_dims=embed_dim[0], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock1_2 = nn.ModuleList( + [MS_ConvBlock_spike_SepConv(dim=embed_dim[0], mlp_ratio=mlp_ratios)] + ) + + self.downsample2 = MS_DownSampling( + in_channels=embed_dim[0], + embed_dims=embed_dim[1], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock2_1 = nn.ModuleList( + [MS_ConvBlock_spike_SepConv(dim=embed_dim[1], mlp_ratio=mlp_ratios)] + ) + + self.ConvBlock2_2 = nn.ModuleList( + [MS_ConvBlock_spike_SepConv(dim=embed_dim[1], mlp_ratio=mlp_ratios)] + ) + + self.downsample3 = MS_DownSampling( + in_channels=embed_dim[1], + embed_dims=embed_dim[2], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.block3 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[2], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(6) + ] + ) + + self.downsample4 = MS_DownSampling( + in_channels=embed_dim[2], + embed_dims=embed_dim[3], + kernel_size=3, + stride=1, + padding=1, + first_layer=False, + + ) + + self.block4 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[3], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(2) + ] + ) + + self.head = ( + nn.Linear(embed_dim[3], num_classes) if num_classes > 0 else nn.Identity() + ) + self.spike = MultiSpike(Norm=1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + x = self.downsample1_1(x) + for blk in self.ConvBlock1_1: + x = blk(x) + x = self.downsample1_2(x) + for blk in self.ConvBlock1_2: + x = blk(x) + + x = self.downsample2(x) + for blk in self.ConvBlock2_1: + x = blk(x) + for blk in self.ConvBlock2_2: + x = blk(x) + 
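+        # Feature map here is (B, embed_dim[1], H/8, W/8). downsample3 halves the resolution once
+        # more (H/16 x W/16, embed_dim[2] channels) for the 6 attention blocks in block3, and
+        # downsample4 keeps that resolution (stride 1) while widening to embed_dim[3] for the
+        # final 2 blocks in block4.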
+ x = self.downsample3(x) + for blk in self.block3: + x = blk(x) + + x = self.downsample4(x) + for blk in self.block4: + x = blk(x) + + return x # T,B,C,N + + def forward(self, x): + x = self.forward_features(x) # B,C,H,W + x = x.flatten(2).mean(2) + x = self.spike(x) + x = self.head(x) + return x + +class Spiking_vit_MetaFormer_Spike_SepConv_ChannelMLP(nn.Module): + def __init__( + self, + img_size_h=128, + img_size_w=128, + in_channels=1, + num_classes=11, + embed_dim=[64, 128, 256], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.2, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[6, 8, 6], + sr_ratios=[8, 4, 2], + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + + self.downsample1_1 = MS_DownSampling( + in_channels=in_channels, + embed_dims=embed_dim[0] // 2, + kernel_size=7, + stride=2, + padding=3, + first_layer=True, + + ) + + self.ConvBlock1_1 = nn.ModuleList( + [MS_ConvBlock_spike_MLP( + dim=embed_dim[0] // 2, + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample1_2 = MS_DownSampling( + in_channels=embed_dim[0] // 2, + embed_dims=embed_dim[0], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock1_2 = nn.ModuleList( + [MS_ConvBlock_spike_MLP( + dim=embed_dim[0], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample2 = MS_DownSampling( + in_channels=embed_dim[0], + embed_dims=embed_dim[1], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock2_1 = nn.ModuleList( + [MS_ConvBlock_spike_MLP( + dim=embed_dim[1], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.ConvBlock2_2 = nn.ModuleList( + [MS_ConvBlock_spike_MLP( + dim=embed_dim[1], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample3 = MS_DownSampling( + in_channels=embed_dim[1], + embed_dims=embed_dim[2], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.block3 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[2], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(6) + ] + ) + + self.downsample4 = MS_DownSampling( + in_channels=embed_dim[2], + embed_dims=embed_dim[3], + kernel_size=3, + stride=1, + padding=1, + first_layer=False, + + ) + + self.block4 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[3], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(2) + ] + ) + + self.head = ( + nn.Linear(embed_dim[3], num_classes) if num_classes > 0 else nn.Identity() + ) + self.spike = MultiSpike(Norm=1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + x = self.downsample1_1(x) + for blk in self.ConvBlock1_1: + x = blk(x) + x = self.downsample1_2(x) + for blk in self.ConvBlock1_2: + x = blk(x) + + x = 
self.downsample2(x) + for blk in self.ConvBlock2_1: + x = blk(x) + for blk in self.ConvBlock2_2: + x = blk(x) + + x = self.downsample3(x) + for blk in self.block3: + x = blk(x) + + x = self.downsample4(x) + for blk in self.block4: + x = blk(x) + + return x # T,B,C,N + + def forward(self, x): + x = self.forward_features(x) # B,C,H,W + x = x.flatten(2).mean(2) + x = self.spike(x) + x = self.head(x) + return x + + +class Spiking_vit_MetaFormer_Spike_SepConv_splash(nn.Module): + def __init__( + self, + img_size_h=128, + img_size_w=128, + in_channels=1, + num_classes=11, + embed_dim=[64, 128, 256], + num_heads=[1, 2, 4], + mlp_ratios=[4, 4, 4], + qkv_bias=False, + qk_scale=None, + drop_rate=0.2, + attn_drop_rate=0.0, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + depths=[6, 8, 6], + sr_ratios=[8, 4, 2], + ): + super().__init__() + self.num_classes = num_classes + self.depths = depths + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depths) + ] # stochastic depth decay rule + + self.downsample1_1 = MS_DownSampling( + in_channels=in_channels, + embed_dims=embed_dim[0] // 2, + kernel_size=7, + stride=2, + padding=3, + first_layer=True, + + ) + + self.ConvBlock1_1 = nn.ModuleList( + [MS_ConvBlock_spike_splash( + dim=embed_dim[0] // 2, + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample1_2 = MS_DownSampling( + in_channels=embed_dim[0] // 2, + embed_dims=embed_dim[0], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock1_2 = nn.ModuleList( + [MS_ConvBlock_spike_splash( + dim=embed_dim[0], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample2 = MS_DownSampling( + in_channels=embed_dim[0], + embed_dims=embed_dim[1], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.ConvBlock2_1 = nn.ModuleList( + [MS_ConvBlock_spike_splash( + dim=embed_dim[1], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.ConvBlock2_2 = nn.ModuleList( + [MS_ConvBlock_spike_splash( + dim=embed_dim[1], + mlp_ratio=mlp_ratios, + drop=drop_rate, + )] + ) + + self.downsample3 = MS_DownSampling( + in_channels=embed_dim[1], + embed_dims=embed_dim[2], + kernel_size=3, + stride=2, + padding=1, + first_layer=False, + ) + + self.block3 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[2], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(6) + ] + ) + + self.downsample4 = MS_DownSampling( + in_channels=embed_dim[2], + embed_dims=embed_dim[3], + kernel_size=3, + stride=1, + padding=1, + first_layer=False, + + ) + + self.block4 = nn.ModuleList( + [ + MS_Block_Spike_SepConv( + dim=embed_dim[3], + num_heads=num_heads, + mlp_ratio=mlp_ratios, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[j], + norm_layer=norm_layer, + sr_ratio=sr_ratios, + + ) + for j in range(2) + ] + ) + + self.head = ( + nn.Linear(embed_dim[3], num_classes) if num_classes > 0 else nn.Identity() + ) + self.spike = MultiSpike(Norm=1) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward_features(self, x): + x = self.downsample1_1(x) + for 
blk in self.ConvBlock1_1: + x = blk(x) + x = self.downsample1_2(x) + for blk in self.ConvBlock1_2: + x = blk(x) + + x = self.downsample2(x) + for blk in self.ConvBlock2_1: + x = blk(x) + for blk in self.ConvBlock2_2: + x = blk(x) + + x = self.downsample3(x) + for blk in self.block3: + x = blk(x) + + x = self.downsample4(x) + for blk in self.block4: + x = blk(x) + + return x # T,B,C,N + + def forward(self, x): + x = self.forward_features(x) # B,C,H,W + x = x.flatten(2).mean(2) + x = self.spike(x) + x = self.head(x) + return x + +def Efficient_Spiking_Transformer_l(**kwargs): + #19.0M + model = Spiking_vit_MetaFormer_Spike_SepConv( + img_size_h=224, + img_size_w=224, + patch_size=16, + embed_dim=[64, 128, 256, 360], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def Efficient_Spiking_Transformer_m(**kwargs): + #10.0M + model = Spiking_vit_MetaFormer_Spike_SepConv( + img_size_h=224, + img_size_w=224, + patch_size=16, + embed_dim=[48, 96, 192, 240], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + + +def Efficient_Spiking_Transformer_s(**kwargs): + #5.1M + model = Spiking_vit_MetaFormer_Spike_SepConv( + img_size_h=224, + img_size_w=224, + patch_size=16, + embed_dim=[32, 64, 128, 192], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def Efficient_Spiking_Transformer_t(**kwargs): + model = Spiking_vit_MetaFormer_Spike_SepConv( + img_size_h=224, + img_size_w=224, + patch_size=16, + embed_dim=[24, 48, 96, 128], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_s_channelmlp(**kwargs): + #5.1M + model = Spiking_vit_MetaFormer_Spike_SepConv_ChannelMLP( + img_size_h=224, + img_size_w=224, + embed_dim=[32, 64, 128, 192], + num_heads=8, + mlp_ratios=6, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_l_channelmlp(**kwargs): + # 16.56M + model = Spiking_vit_MetaFormer_Spike_SepConv_ChannelMLP( + img_size_h=224, + img_size_w=224, + embed_dim=[64, 128, 256, 360], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_m_channelmlp(**kwargs): + # 10M? 
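+    # The parameter counts in these factory comments are rough estimates; they can be verified the
+    # same way as in the __main__ check at the bottom of this file, e.g.:
+    #   m = sdtv3_m_channelmlp()
+    #   print(sum(p.numel() for p in m.parameters() if p.requires_grad))
+    #   torchinfo.summary(m, (1, 3, 224, 224), device="cpu")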
+ model = Spiking_vit_MetaFormer_Spike_SepConv_ChannelMLP( + img_size_h=224, + img_size_w=224, + embed_dim=[48, 96, 192, 240], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_s_splash(**kwargs): + #4.8M + model = Spiking_vit_MetaFormer_Spike_SepConv_splash( + img_size_h=224, + img_size_w=224, + embed_dim=[32, 64, 128, 192], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_m_splash(**kwargs): + #4.8M + model = Spiking_vit_MetaFormer_Spike_SepConv_splash( + img_size_h=224, + img_size_w=224, + embed_dim=[48, 96, 192, 240], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + +def sdtv3_l_splash(**kwargs): + #4.8M + model = Spiking_vit_MetaFormer_Spike_SepConv_splash( + img_size_h=224, + img_size_w=224, + embed_dim=[64, 128, 256, 360], + num_heads=8, + mlp_ratios=4, + in_channels=3, + num_classes=1000, + qkv_bias=False, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=8, + sr_ratios=1, + **kwargs, + ) + return model + + +from timm.models import create_model +# os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +if __name__ == "__main__": + model = sdtv3_s_channelmlp() + print(model) + x = torch.randn(1,3,224,224) + y = model(x) + torchinfo.summary(model, (1, 3, 224, 224),device='cpu') diff --git a/models/util/__init__.py b/models/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/util/__pycache__/__init__.cpython-311.pyc b/models/util/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33d06d8091a1b3051531f2248a1526ccba29126f Binary files /dev/null and b/models/util/__pycache__/__init__.cpython-311.pyc differ diff --git a/models/util/__pycache__/__init__.cpython-312.pyc b/models/util/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27978a831f56eea3a24b67ada20b092043ede9d7 Binary files /dev/null and b/models/util/__pycache__/__init__.cpython-312.pyc differ diff --git a/models/util/__pycache__/pos_embed.cpython-310.pyc b/models/util/__pycache__/pos_embed.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..368995825f79ecce73ea53270a08a5dbb2d1bbc1 Binary files /dev/null and b/models/util/__pycache__/pos_embed.cpython-310.pyc differ diff --git a/models/util/__pycache__/pos_embed.cpython-311.pyc b/models/util/__pycache__/pos_embed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a2690b1476a448053882427c521d732b5c61a10 Binary files /dev/null and b/models/util/__pycache__/pos_embed.cpython-311.pyc differ diff --git a/models/util/__pycache__/pos_embed.cpython-312.pyc b/models/util/__pycache__/pos_embed.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4168fbb910e5af1e60bac5c509c47d8c727450d5 Binary files /dev/null and b/models/util/__pycache__/pos_embed.cpython-312.pyc differ diff --git a/models/util/crop.py b/models/util/crop.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb26125cca771791b0c5eea2f1c1fabcca0348b --- /dev/null +++ 
b/models/util/crop.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch + +from torchvision import transforms +from torchvision.transforms import functional as F + + +class RandomResizedCrop(transforms.RandomResizedCrop): + """ + RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. + This may lead to results different with torchvision's version. + Following BYOL's TF code: + https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 + """ + @staticmethod + def get_params(img, scale, ratio): + width, height = F._get_image_size(img) + area = height * width + + target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() + log_ratio = torch.log(torch.tensor(ratio)) + aspect_ratio = torch.exp( + torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + w = min(w, width) + h = min(h, height) + + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + + return i, j, h, w \ No newline at end of file diff --git a/models/util/datasets.py b/models/util/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..179d31280cb9ed7e79f07957e2477611b586b6f0 --- /dev/null +++ b/models/util/datasets.py @@ -0,0 +1,137 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +import os +import PIL + +from torchvision import datasets, transforms +import torch +from timm.data import create_transform +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + +def build_dataset(is_train, args): + transform = build_transform(is_train, args) + root = os.path.join(args.data_path, 'train' if is_train else 'val') + dataset = datasets.ImageFolder(root, transform=transform) + return dataset + +def build_dataset_full(is_train,args): + transform = build_transform(is_train, args) + + train_set = datasets.ImageFolder(root=os.path.join(args.data_path, 'train'), + transform=transform) + valid_set = datasets.ImageFolder(root=os.path.join(args.data_path, 'val'), + transform=transform) + + full_set = torch.utils.data.ConcatDataset([train_set, valid_set]) + print(full_set) + return full_set + +def build_transform(is_train, args): + mean = IMAGENET_DEFAULT_MEAN + std = IMAGENET_DEFAULT_STD + # train transform + if is_train: + # this should always dispatch to transforms_imagenet_train + transform = create_transform( + input_size=args.input_size, + is_training=True, + color_jitter=args.color_jitter, + auto_augment=args.aa, + interpolation='bicubic', + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + mean=mean, + std=std, + ) + # if args.three_aug: + + # secondary_tfl = transforms.RandomChoice([gray_scale(p=1.0), + # Solarization(p=1.0), + # GaussianBlur(p=1.0)]) + # transform = transforms.Compose([transform,secondary_tfl]) + return transform + + # eval transform + t = [] + if args.input_size <= 224: + crop_pct = 224 / 232 + else: + crop_pct = 1.0 + size = int(args.input_size / crop_pct) + t.append( + transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(args.input_size)) + + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(mean, std)) + return transforms.Compose(t) +import torch +from torchvision import transforms + +import numpy as np +from torchvision import datasets, transforms +import random + + + +from PIL import ImageFilter, ImageOps +import torchvision.transforms.functional as TF + + +class GaussianBlur(object): + """ + Apply Gaussian Blur to the PIL image. + """ + def __init__(self, p=0.1, radius_min=0.1, radius_max=2.): + self.prob = p + self.radius_min = radius_min + self.radius_max = radius_max + + def __call__(self, img): + do_it = random.random() <= self.prob + if not do_it: + return img + + img = img.filter( + ImageFilter.GaussianBlur( + radius=random.uniform(self.radius_min, self.radius_max) + ) + ) + return img + +class Solarization(object): + """ + Apply Solarization to the PIL image. + """ + def __init__(self, p=0.2): + self.p = p + + def __call__(self, img): + if random.random() < self.p: + return ImageOps.solarize(img) + else: + return img + +class gray_scale(object): + """ + Apply Solarization to the PIL image. 
+ """ + def __init__(self, p=0.2): + self.p = p + self.transf = transforms.Grayscale(3) + + def __call__(self, img): + if random.random() < self.p: + return self.transf(img) + else: + return img diff --git a/models/util/kd_loss.py b/models/util/kd_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e8db8f6688bbf0f4d4cc65e4ff58f94cbac0862d --- /dev/null +++ b/models/util/kd_loss.py @@ -0,0 +1,85 @@ +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. +""" +Implements the knowledge distillation loss +""" +import torch +from torch.nn import functional as F + + +class DistillationLoss(torch.nn.Module): + """ + This module wraps a standard criterion and adds an extra knowledge distillation loss by + taking a teacher model prediction and using it as additional supervision. + """ + + def __init__( + self, + base_criterion: torch.nn.Module, + teacher_model: torch.nn.Module, + distillation_type: str, + alpha: float, + tau: float, + ): + super().__init__() + self.base_criterion = base_criterion + self.teacher_model = teacher_model + assert distillation_type in ["none", "soft", "hard"] + self.distillation_type = distillation_type + self.alpha = alpha + self.tau = tau + + def forward(self, inputs, outputs, labels): + """ + Args: + inputs: The original inputs that are feed to the teacher model + outputs: the outputs of the model to be trained. It is expected to be + either a Tensor, or a Tuple[Tensor, Tensor], with the original output + in the first position and the distillation predictions as the second output + labels: the labels for the base criterion + """ + outputs_kd = None + if not isinstance(outputs, torch.Tensor): + # assume that the model outputs a tuple of [outputs, outputs_kd] + outputs, outputs_kd = outputs + base_loss = self.base_criterion(outputs, labels) + if self.distillation_type == "none": + return base_loss + + if outputs_kd is None: + raise ValueError( + "When knowledge distillation is enabled, the model is " + "expected to return a Tuple[Tensor, Tensor] with the output of the " + "class_token and the dist_token" + ) + # don't backprop throught the teacher + with torch.no_grad(): + teacher_outputs = self.teacher_model(inputs) + + if self.distillation_type == "soft": + T = self.tau + # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 + # with slight modifications + distillation_loss = ( + F.kl_div( + F.log_softmax(outputs_kd / T, dim=1), + # We provide the teacher's targets in log probability because we use log_target=True + # (as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719) + # but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both. + F.log_softmax(teacher_outputs / T, dim=1), + reduction="sum", + log_target=True, + ) + * (T * T) + / outputs_kd.numel() + ) + # We divide by outputs_kd.numel() to have the legacy PyTorch behavior. 
+ # But we also experiments output_kd.size(0) + # see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details + elif self.distillation_type == "hard": + distillation_loss = F.cross_entropy( + outputs_kd, teacher_outputs.argmax(dim=1) + ) + + loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha + return loss diff --git a/models/util/lars.py b/models/util/lars.py new file mode 100644 index 0000000000000000000000000000000000000000..509c5f65b7f68423343121d5676d05ce32d5a6c0 --- /dev/null +++ b/models/util/lars.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# LARS optimizer, implementation from MoCo v3: +# https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- + +import torch + + +class LARS(torch.optim.Optimizer): + """ + LARS optimizer, no rate scaling or weight decay for parameters <= 1D. + """ + def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): + defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self): + for g in self.param_groups: + for p in g['params']: + dp = p.grad + + if dp is None: + continue + + if p.ndim > 1: # if not normalization gamma/beta or bias + dp = dp.add(p, alpha=g['weight_decay']) + param_norm = torch.norm(p) + update_norm = torch.norm(dp) + one = torch.ones_like(param_norm) + q = torch.where(param_norm > 0., + torch.where(update_norm > 0, + (g['trust_coefficient'] * param_norm / update_norm), one), + one) + dp = dp.mul(q) + + param_state = self.state[p] + if 'mu' not in param_state: + param_state['mu'] = torch.zeros_like(p) + mu = param_state['mu'] + mu.mul_(g['momentum']).add_(dp) + p.add_(mu, alpha=-g['lr']) \ No newline at end of file diff --git a/models/util/loss.py b/models/util/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0669d1a6d2a822a97045e71acafa741f9ca0105e --- /dev/null +++ b/models/util/loss.py @@ -0,0 +1,20 @@ +import torch +import torch.nn as nn +import random +import os +import numpy as np +import logging +def TET_loss(outputs, labels, criterion, means, lamb): + print('using TET') + T = outputs.size(1) + Loss_es = 0 + for t in range(T): + Loss_es += criterion(outputs[t, ...], labels) + Loss_es = Loss_es / T # L_TET + if lamb != 0: + MMDLoss = torch.nn.MSELoss() + y = torch.zeros_like(outputs).fill_(means) + Loss_mmd = MMDLoss(outputs, y) # L_mse + else: + Loss_mmd = 0 + return (1 - lamb) * Loss_es + lamb * Loss_mmd # L_Total \ No newline at end of file diff --git a/models/util/lr_decay.py b/models/util/lr_decay.py new file mode 100644 index 0000000000000000000000000000000000000000..615b2fe35a6e2109c053b4cb133632f89886909f --- /dev/null +++ b/models/util/lr_decay.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# ELECTRA https://github.com/google-research/electra +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import json + + +def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): + """ + Parameter groups for layer-wise lr decay + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + param_group_names = {} + param_groups = {} + + num_layers = len(model.blocks) + 1 + + layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) + + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if p.ndim == 1 or n in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0. + else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = get_layer_id_for_vit(n, num_layers) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_group_names: + this_scale = layer_scales[layer_id] + + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["params"].append(n) + param_groups[group_name]["params"].append(p) + + # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) + + +def get_layer_id_for_vit(name, num_layers): + """ + Assign a parameter with its layer id + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + """ + if name in ['cls_token', 'pos_embed']: + return 0 + elif name.startswith('patch_embed'): + return 0 + elif name.startswith('blocks'): + return int(name.split('.')[1]) + 1 + else: + return num_layers diff --git a/models/util/lr_decay_spikformer.py b/models/util/lr_decay_spikformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a3626f481ddba4435aa0129f2d9f08647e95dec6 --- /dev/null +++ b/models/util/lr_decay_spikformer.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# ELECTRA https://github.com/google-research/electra +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import json + +import json +import pdb + + + +def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): + """ + Parameter groups for layer-wise lr decay + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 + """ + param_group_names = {} + param_groups = {} + # num_layers = len(model.block3) + 1 + num_layers = model.depths +1 + layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) + print(layer_scales,len(layer_scales)) + for n, p in model.named_parameters(): + if not p.requires_grad: + continue + + # no decay: all 1D parameters and model specific ones + if p.ndim == 1 or n in no_weight_decay_list: + g_decay = "no_decay" + this_decay = 0. 
+ else: + g_decay = "decay" + this_decay = weight_decay + + layer_id = get_layer_id_for_vit(n, num_layers) + group_name = "layer_%d_%s" % (layer_id, g_decay) + + if group_name not in param_group_names: + this_scale = layer_scales[layer_id] + + param_group_names[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + param_groups[group_name] = { + "lr_scale": this_scale, + "weight_decay": this_decay, + "params": [], + } + + param_group_names[group_name]["params"].append(n) + param_groups[group_name]["params"].append(p) + + print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) + + return list(param_groups.values()) + + +def get_layer_id_for_vit(name, num_layers): + """ + Assign a parameter with its layer id + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + """ + if name in ['cls_token', 'pos_embed']: + return 0 + elif name.startswith('downsample1_1'): + return 0 + elif name.startswith('ConvBlock1_1'): + return 0 + elif name.startswith('downsample1_2'): + return 0 + elif name.startswith('ConvBlock1_2'): + return 0 + elif name.startswith('downsample2'): + return 0 + elif name.startswith('ConvBlock2_1'): + return 0 + elif name.startswith('ConvBlock2_2'): + return 0 + elif name.startswith('downsample3'): + return 0 + elif name.startswith('block3'): + return int(name.split('.')[1]) + 1 + elif name.startswith('block4'): + return int(name.split('.')[1]) + 1 + else: + return num_layers + + diff --git a/models/util/lr_sched.py b/models/util/lr_sched.py new file mode 100644 index 0000000000000000000000000000000000000000..4cb682bebbce25ea1df70119928faa5fc9a6ab02 --- /dev/null +++ b/models/util/lr_sched.py @@ -0,0 +1,21 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ + (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr diff --git a/models/util/misc.py b/models/util/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..36efc0786d9e3e8a2785c342feaf724d839bafaa --- /dev/null +++ b/models/util/misc.py @@ -0,0 +1,349 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist +from torch import inf + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. 
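`adjust_learning_rate` above implements linear warmup followed by a half-cycle cosine decay, and when the optimizer groups were built with `param_groups_lrd` the per-group `lr_scale` is multiplied into the scheduled value. A small sketch of the schedule itself; the `args` namespace is a hypothetical stand-in for the training script's arguments.

```python
# Sketch: warmup + half-cycle cosine schedule from models/util/lr_sched.py.
# The args fields are assumed to mirror the training script (lr, min_lr,
# warmup_epochs, epochs); the optimizer and parameter are dummies.
from types import SimpleNamespace
import torch
from models.util.lr_sched import adjust_learning_rate

args = SimpleNamespace(lr=1e-3, min_lr=1e-5, warmup_epochs=5, epochs=100)
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

for epoch in [0, 1, 5, 25, 50, 99]:
    lr = adjust_learning_rate(opt, epoch, args)
    print(f"epoch {epoch:3d}: lr = {lr:.6f}")
# epoch 0 -> 0 (start of warmup), epoch 5 -> peak lr, epoch 99 -> close to min_lr
```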
+ """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time 
/ len(iterable))) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) 
+ device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] + for checkpoint_path in checkpoint_paths: + to_save = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'scaler': loss_scaler.state_dict(), + 'args': args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {'epoch': epoch} + model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + print("Resume checkpoint %s" % args.resume) + if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): + optimizer.load_state_dict(checkpoint['optimizer']) + args.start_epoch = checkpoint['epoch'] + 1 + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + print("With optim & sched!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + maxk = min(max(topk), output.size()[1]) + batch_size = target.size(0) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk] diff --git a/models/util/pos_embed.py b/models/util/pos_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..cc6ed041bc028063e43e36aa68473d0cddbae20c --- /dev/null +++ b/models/util/pos_embed.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
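`NativeScalerWithGradNormCount` wraps `torch.cuda.amp.GradScaler` and, when it actually steps, returns the (pre-clip) gradient norm so callers can log it. A hedged sketch of one mixed-precision step; the model, batch, and clip value are hypothetical, and a CUDA device is assumed since `autocast` is used as in the original training loops.

```python
# Sketch: one AMP training step using NativeScalerWithGradNormCount from
# models/util/misc.py. Model and data are hypothetical; requires a CUDA device.
import torch
from models.util.misc import NativeScalerWithGradNormCount

device = "cuda"
model = torch.nn.Linear(128, 10).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_scaler = NativeScalerWithGradNormCount()

x = torch.randn(32, 128, device=device)
y = torch.randint(0, 10, (32,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = torch.nn.functional.cross_entropy(model(x), y)

# scales + backward, unscales, clips, steps; returns the gradient norm before clipping
grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
                        parameters=model.parameters(), update_grad=True)
print(f"loss {loss.item():.4f}, grad norm {float(grad_norm):.4f}")
```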
+# -------------------------------------------------------- +# Position embedding utils +# -------------------------------------------------------- + +import numpy as np + +import torch + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed diff --git a/models/vit.py b/models/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..358d17d17370f90ade32b0de979ac9f1e61a752a --- /dev/null +++ b/models/vit.py @@ -0,0 +1,4512 @@ +""" Vision Transformer (ViT) in PyTorch + +A PyTorch implement of Vision Transformers as described in: + +'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' + - https://arxiv.org/abs/2010.11929 + +`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers` + - https://arxiv.org/abs/2106.10270 + +`FlexiViT: One Model for All Patch Sizes` + - https://arxiv.org/abs/2212.08013 + +The official jax code is released and available at + * https://github.com/google-research/vision_transformer + * https://github.com/google-research/big_vision + +Acknowledgments: + * The paper authors for releasing code and weights, thanks! 
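`get_2d_sincos_pos_embed` builds a fixed (non-learned) embedding by encoding the row and column indices separately with the 1D sin/cos formula, concatenating the two halves, and optionally prepending an all-zero row for the class token. A quick shape check; the numbers follow directly from the code above (ViT-B/16 at 224x224 gives a 14x14 patch grid).

```python
# Sketch: shapes produced by the 2D sin-cos position embedding helpers above.
import numpy as np
from models.util.pos_embed import get_2d_sincos_pos_embed

embed_dim, grid_size = 768, 14  # e.g. ViT-B/16 at 224x224 -> 14x14 patches
pe = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
print(pe.shape)               # (1 + 14*14, 768) = (197, 768)
print(np.allclose(pe[0], 0))  # True: the cls-token slot is all zeros
```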
+ * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch + * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT + * Bert reference code checks against Huggingface Transformers and Tensorflow Bert + +Hacked together by / Copyright 2020, Ross Wightman +""" +import copy +import logging +import math +from collections import OrderedDict +from functools import partial +from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.jit import Final + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \ + OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, SwiGLU, \ + trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \ + get_act_layer, get_norm_layer, LayerType +from timm.models._builder import build_model_with_cfg +from timm.models._features import feature_take_indices +from timm.models._manipulate import named_apply, checkpoint_seq, adapt_input_conv +from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations + +__all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this + + +_logger = logging.getLogger(__name__) + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: Type[nn.Module] = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + 
proj_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + mlp_layer: Type[nn.Module] = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + bias=proj_bias, + drop=proj_drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ResPostBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + proj_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + mlp_layer: Type[nn.Module] = Mlp, + ) -> None: + super().__init__() + self.init_values = init_values + + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + bias=proj_bias, + drop=proj_drop, + ) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.init_weights() + + def init_weights(self) -> None: + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.norm1(self.attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + +class ParallelScalingBlock(nn.Module): + """ Parallel ViT block (MLP & Attention in parallel) + Based on: + 'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442 + """ + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + proj_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + mlp_layer: Optional[Type[nn.Module]] = None, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() + mlp_hidden_dim = int(mlp_ratio * dim) + in_proj_out_dim = mlp_hidden_dim + 3 * dim + + self.in_norm = norm_layer(dim) + self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias) + self.in_split = [mlp_hidden_dim] + [dim] * 3 + if qkv_bias: + self.register_buffer('qkv_bias', None) + self.register_parameter('mlp_bias', None) + else: + self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False) + self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim)) + + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias) + + self.mlp_drop = nn.Dropout(proj_drop) + self.mlp_act = act_layer() + self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias) + + self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + + # Combined MLP fc1 & qkv projections + y = self.in_norm(x) + if self.mlp_bias is not None: + # Concat constant zero-bias for qkv w/ trainable mlp_bias. 
+ # Appears faster than adding to x_mlp separately + y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias))) + else: + y = self.in_proj(y) + x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1) + + # Dot product attention w/ qk norm + q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2) + k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2) + v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2) + if self.fused_attn: + x_attn = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x_attn = attn @ v + x_attn = x_attn.transpose(1, 2).reshape(B, N, C) + x_attn = self.attn_out_proj(x_attn) + + # MLP activation, dropout, fc2 + x_mlp = self.mlp_act(x_mlp) + x_mlp = self.mlp_drop(x_mlp) + x_mlp = self.mlp_out_proj(x_mlp) + + # Add residual w/ drop path & layer scale applied + y = self.drop_path(self.ls(x_attn + x_mlp)) + x = x + y + return x + + +class ParallelThingsBlock(nn.Module): + """ Parallel ViT block (N parallel attention followed by N parallel MLP) + Based on: + `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + """ + def __init__( + self, + dim: int, + num_heads: int, + num_parallel: int = 2, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + proj_bias: bool = True, + init_values: Optional[float] = None, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + mlp_layer: Type[nn.Module] = Mlp, + ) -> None: + super().__init__() + self.num_parallel = num_parallel + self.attns = nn.ModuleList() + self.ffns = nn.ModuleList() + for _ in range(num_parallel): + self.attns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('attn', Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + )), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) + ]))) + self.ffns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('mlp', mlp_layer( + dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + bias=proj_bias, + drop=proj_drop, + )), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. 
else nn.Identity()) + ]))) + + def _forward_jit(self, x: torch.Tensor) -> torch.Tensor: + x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) + x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) + return x + + @torch.jit.ignore + def _forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + sum(attn(x) for attn in self.attns) + x = x + sum(ffn(x) for ffn in self.ffns) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return self._forward_jit(x) + else: + return self._forward(x) + + +def global_pool_nlc( + x: torch.Tensor, + pool_type: str = 'token', + num_prefix_tokens: int = 1, + reduce_include_prefix: bool = False, +): + if not pool_type: + return x + + if pool_type == 'token': + x = x[:, 0] # class token + else: + x = x if reduce_include_prefix else x[:, num_prefix_tokens:] + if pool_type == 'avg': + x = x.mean(dim=1) + elif pool_type == 'avgmax': + x = 0.5 * (x.amax(dim=1) + x.mean(dim=1)) + elif pool_type == 'max': + x = x.amax(dim=1) + else: + assert not pool_type, f'Unknown pool type {pool_type}' + + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_norm: bool = False, + proj_bias: bool = True, + init_values: Optional[float] = None, + class_token: bool = True, + pos_embed: str = 'learn', + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + final_norm: bool = True, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0., + pos_drop_rate: float = 0., + patch_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '', + fix_init: bool = False, + embed_layer: Callable = PatchEmbed, + embed_norm_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Number of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT). + final_norm: Enable norm after transformer blocks, before head (standard in most ViT). 
+ fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + fix_init: Apply weight initialization fix (scaling w/ layer index). + embed_layer: Patch embedding layer. + embed_norm_layer: Normalization layer to use / override in patch embed module. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. + """ + super().__init__() + assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') + assert class_token or global_pool != 'token' + assert pos_embed in ('', 'none', 'learn') + use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm + norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + embed_norm_layer = get_norm_layer(embed_norm_layer) + act_layer = get_act_layer(act_layer) or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) + if embed_norm_layer is not None: + embed_args['norm_layer'] = embed_norm_layer + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + if not pos_embed or pos_embed == 'none': + self.pos_embed = None + else: + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_bias=proj_bias, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)] + self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity() + + # Classifier Head + if global_pool == 'map': + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + act_layer=act_layer, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + if fix_init: + self.fix_init_weight() + + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def init_weights(self, mode: str = '') -> None: + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
+ if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + if self.reg_token is not None: + nn.init.normal_(self.reg_token, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) + + def _init_weights(self, m: nn.Module) -> None: + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None: + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {'pos_embed', 'cls_token', 'dist_token'} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + if hasattr(self.patch_embed, 'set_grad_checkpointing'): + self.patch_embed.set_grad_checkpointing(enable) + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') + if global_pool == 'map' and self.attn_pool is None: + assert False, "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != 'map' and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def set_input_size( + self, + img_size: Optional[Tuple[int, int]] = None, + patch_size: Optional[Tuple[int, int]] = None, + ): + """Method updates the input image resolution, patch size + + Args: + img_size: New input resolution, if None current resolution is used + patch_size: New patch size, if None existing patch size is used + """ + prev_grid_size = self.patch_embed.grid_size + self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size) + if self.pos_embed is not None: + num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens + num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens + if num_new_tokens != self.pos_embed.shape[1]: + self.pos_embed = nn.Parameter(resample_abs_pos_embed( + self.pos_embed, + new_size=self.patch_embed.grid_size, + old_size=prev_grid_size, + num_prefix_tokens=num_prefix_tokens, + verbose=True, + )) + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.pos_embed is None: + return x.view(x.shape[0], -1, x.shape[-1]) + + if self.dynamic_img_size: + B, H, W, C = x.shape + prev_grid_size = self.patch_embed.grid_size + pos_embed = resample_abs_pos_embed( + self.pos_embed, + new_size=(H, W), + old_size=prev_grid_size, + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], 
dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + # forward pass + B, _, height, width = x.shape + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if self.num_prefix_tokens: + # split prefix (e.g. class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] + intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] + else: + prefix_tokens = None + + if reshape: + # reshape to BCHW output format + H, W = self.patch_embed.dynamic_feat_size((height, width)) + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.fc_norm = nn.Identity() + self.reset_classifier(0, '') + return take_indices + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, List[int], Tuple[int]] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> List[torch.Tensor]: + """ Intermediate layer accessor inspired by DINO / DINOv2 interface. + NOTE: This API is for backwards compat, favour using forward_intermediates() directly. 
+ """ + return self.forward_intermediates( + x, n, + return_prefix_tokens=return_prefix_tokens, + norm=norm, + output_fmt='NCHW' if reshape else 'NLC', + intermediates_only=True, + ) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor: + if self.attn_pool is not None: + x = self.attn_pool(x) + return x + pool_type = self.global_pool if pool_type is None else pool_type + x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = self.pool(x) + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + x = self.forward_head(x) + return x + +class VisionTransformer_attn(nn.Module): + """ Vision Transformer attention blocks for testing + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token', + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = True, + qk_norm: bool = False, + proj_bias: bool = True, + init_values: Optional[float] = None, + class_token: bool = True, + pos_embed: str = 'learn', + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + final_norm: bool = True, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0., + pos_drop_rate: float = 0., + patch_drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '', + fix_init: bool = False, + embed_layer: Callable = PatchEmbed, + embed_norm_layer: Optional[LayerType] = None, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Number of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT). + final_norm: Enable norm after transformer blocks, before head (standard in most ViT). 
+ fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + fix_init: Apply weight initialization fix (scaling w/ layer index). + embed_layer: Patch embedding layer. + embed_norm_layer: Normalization layer to use / override in patch embed module. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. + """ + super().__init__() + assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') + assert class_token or global_pool != 'token' + assert pos_embed in ('', 'none', 'learn') + use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm + norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + embed_norm_layer = get_norm_layer(embed_norm_layer) + act_layer = get_act_layer(act_layer) or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = no_embed_class # don't embed prefix positions (includes reg) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt='NHWC')) + if embed_norm_layer is not None: + embed_args['norm_layer'] = embed_norm_layer + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + if not pos_embed or pos_embed == 'none': + self.pos_embed = None + else: + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.Sequential(*[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + proj_bias=proj_bias, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth)]) + self.feature_info = [ + dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)] + self.norm = norm_layer(embed_dim) if final_norm and not use_fc_norm else nn.Identity() + + # Classifier Head + if global_pool == 'map': + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + act_layer=act_layer, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if final_norm and use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + if fix_init: + self.fix_init_weight() + + def fix_init_weight(self): + def rescale(param, _layer_id): + param.div_(math.sqrt(2.0 * _layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def init_weights(self, mode: str = '') -> None: + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
+ if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + if self.reg_token is not None: + nn.init.normal_(self.reg_token, std=1e-6) + named_apply(get_init_weights_vit(mode, head_bias), self) + + def _init_weights(self, m: nn.Module) -> None: + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None: + _load_weights(self, checkpoint_path, prefix) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {'pos_embed', 'cls_token', 'dist_token'} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + if hasattr(self.patch_embed, 'set_grad_checkpointing'): + self.patch_embed.set_grad_checkpointing(enable) + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map') + if global_pool == 'map' and self.attn_pool is None: + assert False, "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != 'map' and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def set_input_size( + self, + img_size: Optional[Tuple[int, int]] = None, + patch_size: Optional[Tuple[int, int]] = None, + ): + """Method updates the input image resolution, patch size + + Args: + img_size: New input resolution, if None current resolution is used + patch_size: New patch size, if None existing patch size is used + """ + prev_grid_size = self.patch_embed.grid_size + self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size) + if self.pos_embed is not None: + num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens + num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens + if num_new_tokens != self.pos_embed.shape[1]: + self.pos_embed = nn.Parameter(resample_abs_pos_embed( + self.pos_embed, + new_size=self.patch_embed.grid_size, + old_size=prev_grid_size, + num_prefix_tokens=num_prefix_tokens, + verbose=True, + )) + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.pos_embed is None: + return x.view(x.shape[0], -1, x.shape[-1]) + + if self.dynamic_img_size: + B, H, W, C = x.shape + prev_grid_size = self.patch_embed.grid_size + pos_embed = resample_abs_pos_embed( + self.pos_embed, + new_size=(H, W), + old_size=prev_grid_size, + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], 
dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def forward_intermediates( + self, + x: torch.Tensor, + indices: Optional[Union[int, List[int]]] = None, + return_prefix_tokens: bool = False, + norm: bool = False, + stop_early: bool = False, + output_fmt: str = 'NCHW', + intermediates_only: bool = False, + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: + """ Forward features that returns intermediates. + + Args: + x: Input image tensor + indices: Take last n blocks if int, all if None, select matching indices if sequence + return_prefix_tokens: Return both prefix and spatial intermediate tokens + norm: Apply norm layer to all intermediates + stop_early: Stop iterating over blocks when last desired intermediate hit + output_fmt: Shape of intermediate feature outputs + intermediates_only: Only return intermediate features + Returns: + + """ + assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.' + reshape = output_fmt == 'NCHW' + intermediates = [] + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + + # forward pass + B, _, height, width = x.shape + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript + blocks = self.blocks + else: + blocks = self.blocks[:max_index + 1] + for i, blk in enumerate(blocks): + x = blk(x) + if i in take_indices: + # normalize intermediates with final norm layer if enabled + intermediates.append(self.norm(x) if norm else x) + + # process intermediates + if self.num_prefix_tokens: + # split prefix (e.g. class, distill) and spatial feature tokens + prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates] + intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates] + else: + prefix_tokens = None + + if reshape: + # reshape to BCHW output format + H, W = self.patch_embed.dynamic_feat_size((height, width)) + intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates] + if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None: + # return_prefix not support in torchscript due to poor type handling + intermediates = list(zip(intermediates, prefix_tokens)) + + if intermediates_only: + return intermediates + + x = self.norm(x) + + return x, intermediates + + def prune_intermediate_layers( + self, + indices: Union[int, List[int]] = 1, + prune_norm: bool = False, + prune_head: bool = True, + ): + """ Prune layers not required for specified intermediates. + """ + take_indices, max_index = feature_take_indices(len(self.blocks), indices) + self.blocks = self.blocks[:max_index + 1] # truncate blocks + if prune_norm: + self.norm = nn.Identity() + if prune_head: + self.fc_norm = nn.Identity() + self.reset_classifier(0, '') + return take_indices + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, List[int], Tuple[int]] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> List[torch.Tensor]: + """ Intermediate layer accessor inspired by DINO / DINOv2 interface. + NOTE: This API is for backwards compat, favour using forward_intermediates() directly. 
+ """ + return self.forward_intermediates( + x, n, + return_prefix_tokens=return_prefix_tokens, + norm=norm, + output_fmt='NCHW' if reshape else 'NLC', + intermediates_only=True, + ) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor: + if self.attn_pool is not None: + x = self.attn_pool(x) + return x + pool_type = self.global_pool if pool_type is None else pool_type + x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + x = self.pool(x) + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + # x = self.forward_head(x) + return x + +def init_weights_vit_timm(module: nn.Module, name: str = '') -> None: + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.0) -> None: + """ ViT weight initialization, matching JAX (Flax) impl """ + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def init_weights_vit_moco(module: nn.Module, name: str = '') -> None: + """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """ + if isinstance(module, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1])) + nn.init.uniform_(module.weight, -val, val) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable: + if 'jax' in mode: + return partial(init_weights_vit_jax, head_bias=head_bias) + elif 'moco' in mode: + return init_weights_vit_moco + else: + return init_weights_vit_timm + + +def resize_pos_embed( + posemb: torch.Tensor, + posemb_new: torch.Tensor, + num_prefix_tokens: int = 1, + gs_new: Tuple[int, int] = (), + interpolation: str = 'bicubic', + antialias: bool = False, +) -> torch.Tensor: + """ Rescale the grid of position embeddings when loading from state_dict. 
+ *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed + """ + ntok_new = posemb_new.shape[1] - num_prefix_tokens + ntok_old = posemb.shape[1] - num_prefix_tokens + gs_old = [int(math.sqrt(ntok_old))] * 2 + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + return resample_abs_pos_embed( + posemb, gs_new, gs_old, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '', load_bfloat16: bool = False) -> None: + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + if load_bfloat16: + import jax.numpy as jnp + import ml_dtypes + + def _n2p(_w, t=True, idx=None): + if idx is not None: + _w = _w[idx] + + if load_bfloat16: + _w = _w.view(ml_dtypes.bfloat16).astype(jnp.float32) + _w = np.array(_w) + + if _w.ndim == 4 and _w.shape[0] == _w.shape[1] == _w.shape[2] == 1: + _w = _w.flatten() + if t: + if _w.ndim == 4: + _w = _w.transpose([3, 2, 0, 1]) + elif _w.ndim == 3: + _w = _w.transpose([2, 0, 1]) + elif _w.ndim == 2: + _w = _w.transpose([1, 0]) + + _w = torch.from_numpy(_w) + return _w + + if load_bfloat16: + w = jnp.load(checkpoint_path) + else: + w = np.load(checkpoint_path) + + interpolation = 'bilinear' + antialias = False + big_vision = False + if not prefix: + if 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + elif 'params/embedding/kernel' in w: + prefix = 'params/' + big_vision = True + elif 'params/img/embedding/kernel' in w: + prefix = 'params/img/' + big_vision = True + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]: + embed_conv_w = resample_patch_embed( + embed_conv_w, + model.patch_embed.proj.weight.shape[-2:], + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + if model.cls_token is not None: + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + if big_vision: + pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) + else: + pos_embed_w = 
_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) + pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if (isinstance(model.head, nn.Linear) and + f'{prefix}head/bias' in w and + model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]): + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + if model.attn_pool is not None: + block_prefix = f'{prefix}MAPHead_0/' + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' + model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) + model.attn_pool.kv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) + model.attn_pool.kv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) + model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) + model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) + model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + for r in range(2): + getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) + getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) + + mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2) + for i, block in enumerate(model.blocks.children()): + if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w: + block_prefix = f'{prefix}Transformer/encoderblock/' + idx = i + else: + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + idx = None + mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx)) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx)) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx)) + 
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx)) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx)) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_( + _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx)) + getattr(block.mlp, f'fc{r + 1}').bias.copy_( + _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx)) + + +def _convert_openai_clip( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, + prefix: str = 'visual.', +) -> Dict[str, torch.Tensor]: + out_dict = {} + swaps = [ + ('conv1', 'patch_embed.proj'), + ('positional_embedding', 'pos_embed'), + ('transformer.resblocks.', 'blocks.'), + ('ln_pre', 'norm_pre'), + ('ln_post', 'norm'), + ('ln_', 'norm'), + ('in_proj_', 'qkv.'), + ('out_proj', 'proj'), + ('mlp.c_fc', 'mlp.fc1'), + ('mlp.c_proj', 'mlp.fc2'), + ] + for k, v in state_dict.items(): + if not k.startswith(prefix): + continue + k = k.replace(prefix, '') + for sp in swaps: + k = k.replace(sp[0], sp[1]) + + if k == 'proj': + k = 'head.weight' + v = v.transpose(0, 1) + out_dict['head.bias'] = torch.zeros(v.shape[0]) + elif k == 'class_embedding': + k = 'cls_token' + v = v.unsqueeze(0).unsqueeze(1) + elif k == 'pos_embed': + v = v.unsqueeze(0) + out_dict[k] = v + return out_dict + + +def _convert_dinov2( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, +) -> Dict[str, torch.Tensor]: + import re + out_dict = {} + state_dict.pop("mask_token", None) + if 'register_tokens' in state_dict: + # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed) + out_dict['reg_token'] = state_dict.pop('register_tokens') + out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0] + out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:] + for k, v in state_dict.items(): + if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k): + out_dict[k.replace("w12", "fc1")] = v + continue + elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k): + out_dict[k.replace("w3", "fc2")] = v + continue + out_dict[k] = v + return out_dict + + +def _convert_aimv2( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, +) -> Dict[str, torch.Tensor]: + out_dict = {} + for k, v in state_dict.items(): + k = k.replace('norm_1', 'norm1') + k = k.replace('norm_2', 'norm2') + k = k.replace('preprocessor.patchifier.', 'patch_embed.') + k = k.replace('preprocessor.pos_embed', 'pos_embed') + k = k.replace('trunk.', '') + k = k.replace('post_trunk_norm.', 'norm.') + k = k.replace('mlp.fc1', 'mlp.fc1_g') + k = k.replace('mlp.fc3', 'mlp.fc1_x') + out_dict[k] = v + return out_dict + + +def checkpoint_filter_fn( + state_dict: Dict[str, torch.Tensor], + model: VisionTransformer, + adapt_layer_scale: bool = False, + interpolation: str = 'bicubic', + antialias: bool = True, +) -> Dict[str, torch.Tensor]: + """ convert patch embedding weight from manual patchify + linear proj to conv""" + import re + out_dict = {} + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('state_dict', state_dict) + prefix = '' + + if 'visual.class_embedding' in state_dict: + state_dict = _convert_openai_clip(state_dict, model) + elif 'module.visual.class_embedding' in state_dict: + state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.') + elif "mask_token" in state_dict: + state_dict = _convert_dinov2(state_dict, model) + elif "encoder" in state_dict: + # IJEPA, vit in 
an 'encoder' submodule + state_dict = state_dict['encoder'] + prefix = 'module.' + elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict: + # OpenCLIP model with timm vision encoder + prefix = 'visual.trunk.' + if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear): + # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj) + out_dict['head.weight'] = state_dict['visual.head.proj.weight'] + out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0]) + elif 'preprocessor.patchifier.proj.weight' in state_dict: + state_dict = _convert_aimv2(state_dict, model) + + if prefix: + # filter on & remove prefix string from keys + state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)} + + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + O, I, H, W = model.patch_embed.proj.weight.shape + if len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + if v.shape[-1] != W or v.shape[-2] != H: + v = resample_patch_embed( + v, + (H, W), + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: + # To resize pos embedding when using model at different size from pretrained weights + num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1) + v = resample_abs_pos_embed( + v, + new_size=model.patch_embed.grid_size, + num_prefix_tokens=num_prefix_tokens, + interpolation=interpolation, + antialias=antialias, + verbose=True, + ) + elif adapt_layer_scale and 'gamma_' in k: + # remap layer-scale gamma into sub-module (deit3 models) + k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) + elif 'pre_logits' in k: + # NOTE representation layer removed as not used in latest 21k/1k pretrained weights + continue + out_dict[k] = v + return out_dict + + +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + return { + 'url': url, + 'num_classes': 1000, + 'input_size': (3, 224, 224), + 'pool_size': None, + 'crop_pct': 0.9, + 'interpolation': 'bicubic', + 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, + 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', + 'classifier': 'head', + **kwargs, + } + +default_cfgs = { + + # re-finetuned augreg 21k FT on in1k weights + 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(), + 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg( + hf_hub_id='timm/'), + + # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k + 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', 
+ hf_hub_id='timm/', + custom_load=True), + 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + + # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k + 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + hf_hub_id='timm/'), + 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + 
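+    # NOTE: keys follow timm's '<model_name>.<pretrained_tag>' naming; the tag after the '.'
+    # selects a weight source (e.g. augreg_in21k_ft_in1k, orig_in21k, mae). Entries marked
+    # custom_load=True point at .npz checkpoints that are read via load_pretrained()/_load_weights()
+    # above rather than a standard torch state_dict.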
+ # How to train your ViT (augreg) weights trained on in1k only + 'vit_small_patch16_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_small_patch16_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch32_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch32_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_224.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz', + hf_hub_id='timm/', + custom_load=True), + 'vit_base_patch16_384.augreg_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', + hf_hub_id='timm/', + custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), + + 'vit_large_patch14_224.untrained': _cfg(url=''), + 'vit_huge_patch14_224.untrained': _cfg(url=''), + 'vit_giant_patch14_224.untrained': _cfg(url=''), + 'vit_gigantic_patch14_224.untrained': _cfg(url=''), + + # patch models, imagenet21k (weights from official Google JAX impl), classifier not valid + 'vit_base_patch32_224.orig_in21k': _cfg( + #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth', + hf_hub_id='timm/', + num_classes=0), + 'vit_base_patch16_224.orig_in21k': _cfg( + #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth', + hf_hub_id='timm/', + num_classes=0), + 'vit_large_patch32_224.orig_in21k': _cfg( + #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', + hf_hub_id='timm/', + num_classes=0), + 'vit_large_patch16_224.orig_in21k': _cfg( + #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth', + hf_hub_id='timm/', + num_classes=0), + 'vit_huge_patch14_224.orig_in21k': _cfg( + hf_hub_id='timm/', + num_classes=0), + + # How to train your ViT (augreg) weights, pretrained on in21k + 'vit_tiny_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_small_patch32_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_small_patch16_224.augreg_in21k': _cfg( + 
url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_base_patch32_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_base_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_base_patch8_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + 'vit_large_patch16_224.augreg_in21k': _cfg( + url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', + hf_hub_id='timm/', + custom_load=True, num_classes=21843), + + # SAM trained models (https://arxiv.org/abs/2106.01548) + 'vit_base_patch32_224.sam_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True, + hf_hub_id='timm/'), + 'vit_base_patch16_224.sam_in1k': _cfg( + url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True, + hf_hub_id='timm/'), + + # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only) + 'vit_small_patch16_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_small_patch8_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch16_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_base_patch8_224.dino': _cfg( + url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only) + 'vit_small_patch14_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_base_patch14_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_large_patch14_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_giant_patch14_dinov2.lvd142m': _cfg( + 
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + + # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only) + 'vit_small_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_base_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_large_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth', + hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + + # ViT ImageNet-21K-P pretraining by MILL + 'vit_base_patch16_224_miil.in21k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', + hf_hub_id='timm/', + mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221), + 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth', + hf_hub_id='timm/', + mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), + + # Custom timm variants + 'vit_base_patch16_rpn_224.sw_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth', + hf_hub_id='timm/'), + 'vit_medium_patch16_gap_240.sw_in12k': _cfg( + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821), + 'vit_medium_patch16_gap_256.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_medium_patch16_gap_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'), + 'vit_base_patch16_gap_224': _cfg(), + + # CLIP pretrained image tower and related fine-tuned weights + 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)), + 'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)), + 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), + 
'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + + 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg( + # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', # FIXME weight exists, need to push + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), + 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + + 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg( + hf_hub_id='', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + + 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch16_clip_224.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg( + hf_hub_id='timm/', + 
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + + 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821), + 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), + + 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), + + 'vit_base_patch32_clip_224.laion2b': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_base_patch16_clip_224.laion2b': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.laion2b': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768), + 'vit_huge_patch14_clip_224.laion2b': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_giant_patch14_clip_224.laion2b': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_gigantic_patch14_clip_224.laion2b': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280), + + 'vit_base_patch32_clip_224.laion400m_e32': _cfg( + hf_hub_id='timm/', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_base_patch16_clip_224.laion400m_e32': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_base_patch16_plus_clip_240.laion400m_e32': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 240, 240), crop_pct=1.0, num_classes=640), + 'vit_large_patch14_clip_224.laion400m_e32': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + + 'vit_base_patch32_clip_224.datacompxl': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_base_patch32_clip_256.datacompxl': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 256, 256), num_classes=512), + 'vit_base_patch16_clip_224.datacompxl': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.datacompxl': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + + 'vit_base_patch16_clip_224.dfn2b': _cfg( + hf_hub_id='timm/', + license='apple-ascl', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.dfn2b_s39b': _cfg( + hf_hub_id='timm/', + license='apple-ascl', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + 'vit_large_patch14_clip_224.dfn2b': _cfg( + hf_hub_id='timm/', + license='apple-ascl', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + 
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + 'vit_huge_patch14_clip_224.dfn5b': _cfg( + hf_hub_id='timm/', + license='apple-ascl', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_huge_patch14_clip_378.dfn5b': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + license='apple-ascl', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024), + + 'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + 'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_huge_patch14_clip_224.metaclip_altogether': _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_gigantic_patch14_clip_224.metaclip_2pt5b': _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280), + 'vit_base_patch32_clip_224.metaclip_400m': _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_base_patch16_clip_224.metaclip_400m': _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.metaclip_400m': _cfg( + hf_hub_id='timm/', + license='cc-by-nc-4.0', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + + 'vit_base_patch32_clip_224.openai': _cfg( + hf_hub_id='timm/', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_base_patch16_clip_224.openai': _cfg( + hf_hub_id='timm/', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_large_patch14_clip_224.openai': _cfg( + hf_hub_id='timm/', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + 
'vit_large_patch14_clip_336.openai': _cfg( + hf_hub_id='timm/', + notes=('natively QuickGELU, use quickgelu model variant for original results',), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), num_classes=768), + + # experimental (may be removed) + 'vit_base_patch32_plus_256.untrained': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), + 'vit_small_patch16_36x1_224.untrained': _cfg(url=''), + 'vit_small_patch16_18x2_224.untrained': _cfg(url=''), + 'vit_base_patch16_18x2_224.untrained': _cfg(url=''), + + # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain + # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip + 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt', + hf_hub_id='timm/', license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt', + hf_hub_id='timm/', license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + 'eva_large_patch14_196.in22k_ft_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt', + hf_hub_id='timm/', license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 196, 196), crop_pct=1.0), + 'eva_large_patch14_336.in22k_ft_in1k': _cfg( + # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt', + hf_hub_id='timm/', license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'), + + 'flexivit_small.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_small.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_small.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + + 'flexivit_base.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_base.1000ep_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + 'flexivit_base.300ep_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + 
input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + + 'flexivit_large.1200ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_large.600ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + 'flexivit_large.300ep_in1k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95), + + 'flexivit_base.patch16_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + 'flexivit_base.patch30_in21k': _cfg( + url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True, + hf_hub_id='timm/', + input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843), + + 'vit_base_patch16_xp_224.untrained': _cfg(url=''), + 'vit_large_patch14_xp_224.untrained': _cfg(url=''), + 'vit_huge_patch14_xp_224.untrained': _cfg(url=''), + + 'vit_base_patch16_224.mae': _cfg( + url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth', + hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_large_patch16_224.mae': _cfg( + url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth', + hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_huge_patch14_224.mae': _cfg( + url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth', + hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + 'vit_huge_patch14_gap_224.in1k_ijepa': _cfg( + url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar', + # hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_huge_patch14_gap_224.in22k_ijepa': _cfg( + url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar', + # hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_huge_patch16_gap_448.in1k_ijepa': _cfg( + url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar', + # hf_hub_id='timm/', + license='cc-by-nc-4.0', + input_size=(3, 448, 448), crop_pct=1.0, + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + 'vit_giant_patch16_gap_224.in22k_ijepa': _cfg( + url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar', + # hf_hub_id='timm/', + license='cc-by-nc-4.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), + + 'vit_base_patch32_siglip_256.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_base_patch16_siglip_224.v2_webli': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_base_patch16_siglip_224.webli': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_base_patch16_siglip_256.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_base_patch16_siglip_256.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 
'vit_base_patch16_siglip_256.webli_i18n': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_base_patch16_siglip_384.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_base_patch16_siglip_384.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_base_patch16_siglip_512.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 512, 512), + num_classes=0), + 'vit_base_patch16_siglip_512.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 512, 512), + num_classes=0), + 'vit_large_patch16_siglip_256.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_large_patch16_siglip_256.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_large_patch16_siglip_384.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_large_patch16_siglip_384.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_large_patch16_siglip_512.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 512, 512), + num_classes=0), + 'vit_so400m_patch14_siglip_224.v2_webli': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_so400m_patch14_siglip_224.webli': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_so400m_patch14_siglip_378.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 378, 378), + num_classes=0), + 'vit_so400m_patch14_siglip_378.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 378, 378), + num_classes=0), + 'vit_so400m_patch14_siglip_384.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_so400m_patch16_siglip_256.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_so400m_patch16_siglip_256.webli_i18n': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_so400m_patch16_siglip_384.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_so400m_patch16_siglip_512.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 512, 512), + num_classes=0), + 'vit_giantopt_patch16_siglip_256.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_giantopt_patch16_siglip_384.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + + 'vit_base_patch32_siglip_gap_256.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_base_patch16_siglip_gap_224.v2_webli': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_base_patch16_siglip_gap_224.webli': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_base_patch16_siglip_gap_256.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_base_patch16_siglip_gap_256.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_base_patch16_siglip_gap_256.webli_i18n': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_base_patch16_siglip_gap_384.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_base_patch16_siglip_gap_384.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_base_patch16_siglip_gap_512.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 512, 512), + num_classes=0), + 'vit_base_patch16_siglip_gap_512.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 512, 512), + num_classes=0), + 'vit_large_patch16_siglip_gap_256.v2_webli': _cfg( + 
hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_large_patch16_siglip_gap_256.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_large_patch16_siglip_gap_384.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_large_patch16_siglip_gap_384.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_large_patch16_siglip_gap_512.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 512, 512), + num_classes=0), + 'vit_so400m_patch14_siglip_gap_224.v2_webli': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_so400m_patch14_siglip_gap_224.webli': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_so400m_patch14_siglip_gap_224.pali2_3b_pt': _cfg( + hf_hub_id='timm/', + num_classes=0), + 'vit_so400m_patch14_siglip_gap_224.pali2_10b_pt': _cfg( + hf_hub_id='timm/', + num_classes=0), + # 'vit_so400m_patch14_siglip_gap_224.pali2_28b_pt': _cfg( + # hf_hub_id='google/paligemma2-28b-pt-224-jax', + # hf_hub_filename='pt_27b_224.npz', + # custom_load='hf', + # num_classes=0), + 'vit_so400m_patch14_siglip_gap_378.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 378, 378), + num_classes=0), + 'vit_so400m_patch14_siglip_gap_378.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 378, 378), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_384.webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_448.pali_refcoco_seg': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_448.pali_ocrvqa': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_448.pali2_3b_pt': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_448.pali2_10b_pt': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), crop_pct=1.0, + num_classes=0), + # 'vit_so400m_patch14_siglip_gap_448.pali2_28b_pt': _cfg( + # hf_hub_id='google/paligemma2-28b-pt-448-jax', + # hf_hub_filename='pt_27b_448.npz', + # custom_load='hf', + # input_size=(3, 448, 448), crop_pct=1.0, + # num_classes=0), + 'vit_so400m_patch14_siglip_gap_448.pali2_3b_docci': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_448.pali2_10b_docci': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg( + hf_hub_id='timm/', + input_size=(3, 896, 896), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_896.pali_refcoco_seg': _cfg( + hf_hub_id='timm/', + input_size=(3, 896, 896), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_896.pali_ocrvqa': _cfg( + hf_hub_id='timm/', + input_size=(3, 896, 896), crop_pct=1.0, + num_classes=0), + 'vit_so400m_patch14_siglip_gap_896.pali2_3b_pt': _cfg( + hf_hub_id='timm/', + input_size=(3, 896, 896), crop_pct=1.0, 
+ num_classes=0), + 'vit_so400m_patch14_siglip_gap_896.pali2_10b_pt': _cfg( + hf_hub_id='timm/', + input_size=(3, 896, 896), crop_pct=1.0, + num_classes=0), + # 'vit_so400m_patch14_siglip_gap_896.pali2_28b_pt': _cfg( + # hf_hub_id='google/paligemma2-28b-pt-896-jax', + # hf_hub_filename='pt_27b_896.npz', + # custom_load='hf', + # input_size=(3, 896, 896), crop_pct=1.0, + # num_classes=0), + 'vit_so400m_patch16_siglip_gap_256.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_so400m_patch16_siglip_gap_384.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + 'vit_so400m_patch16_siglip_gap_512.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 512, 512), + num_classes=0), + 'vit_giantopt_patch16_siglip_gap_256.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), + num_classes=0), + 'vit_giantopt_patch16_siglip_gap_384.v2_webli': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), + num_classes=0), + + 'vit_so400m_patch14_siglip_378.webli_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash', + ), + 'vit_so400m_patch14_siglip_gap_378.webli_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash', + ), + + 'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg( + hf_hub_id='timm/', + license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg( + hf_hub_id='timm/', + license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg( + hf_hub_id='timm/', + license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg( + hf_hub_id='timm/', + license='mit', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + + 'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_little_patch16_reg1_gap_256.sbb_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_medium_patch16_reg4_gap_256.sbb_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 
256), crop_pct=0.95), + 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_mediumd_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_betwixt_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_base_patch16_reg4_gap_256.untrained': _cfg( + input_size=(3, 256, 256)), + + 'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=0.95), + 'vit_so150m_patch16_reg4_gap_384.sbb_e250_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_so150m_patch16_reg4_map_256.untrained': _cfg( + input_size=(3, 256, 256)), + 'vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), crop_pct=1.0), + 'vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + input_size=(3, 256, 256), crop_pct=1.0), + 'vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), crop_pct=1.0), + 'vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'), + + 'vit_intern300m_patch14_448.ogvl_dist': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 448, 448), crop_pct=1.0, num_classes=0, + ), + 'vit_intern300m_patch14_448.ogvl_2pt5': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + input_size=(3, 448, 448), crop_pct=1.0, num_classes=0, + ), + + 'aimv2_large_patch14_224.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + crop_pct=1.0, num_classes=0), + 'aimv2_large_patch14_224.apple_pt_dist': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + crop_pct=1.0, num_classes=0), + 'aimv2_huge_patch14_224.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + crop_pct=1.0, num_classes=0), + 'aimv2_1b_patch14_224.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + crop_pct=1.0, num_classes=0), + 'aimv2_3b_patch14_224.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + crop_pct=1.0, num_classes=0), + 
'aimv2_large_patch14_336.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 336, 336), crop_pct=1.0, num_classes=0), + 'aimv2_large_patch14_336.apple_pt_dist': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 336, 336), crop_pct=1.0, num_classes=0), + 'aimv2_huge_patch14_336.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 336, 336), crop_pct=1.0, num_classes=0), + 'aimv2_1b_patch14_336.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 336, 336), crop_pct=1.0, num_classes=0), + 'aimv2_3b_patch14_336.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 336, 336), crop_pct=1.0, num_classes=0), + 'aimv2_large_patch14_448.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 448, 448), crop_pct=1.0, num_classes=0), + 'aimv2_huge_patch14_448.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 448, 448), crop_pct=1.0, num_classes=0), + 'aimv2_1b_patch14_448.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 448, 448), crop_pct=1.0, num_classes=0), + 'aimv2_3b_patch14_448.apple_pt': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl', + input_size=(3, 448, 448), crop_pct=1.0, num_classes=0), + + 'test_vit.r160_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 160, 160), crop_pct=0.95), + 'test_vit2.r160_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 160, 160), crop_pct=0.95), + 'test_vit3.r160_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 160, 160), crop_pct=0.95), + 'test_vit4.r160_in1k': _cfg( + input_size=(3, 160, 160), crop_pct=0.95), +} + +_quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]] +for n in _quick_gelu_cfgs: + # generate quickgelu default cfgs based on contents of notes field + c = copy.deepcopy(default_cfgs[n]) + if c['hf_hub_id'] == 'timm/': + c['hf_hub_id'] = 'timm/' + n # need to use non-quickgelu model name for hub id + default_cfgs[n.replace('_clip_', '_clip_quickgelu_')] = c +default_cfgs = generate_default_cfgs(default_cfgs) + + +def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformer: + out_indices = kwargs.pop('out_indices', 3) + if 'flexi' in variant: + # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed + # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation. + _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False) + else: + _filter_fn = checkpoint_filter_fn + + # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln? 
+ strict = kwargs.pop('pretrained_strict', True) + if 'siglip' in variant and kwargs.get('global_pool', None) != 'map': + strict = False + + return build_model_with_cfg( + VisionTransformer, + variant, + pretrained, + pretrained_filter_fn=_filter_fn, + pretrained_strict=strict, + feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), + **kwargs, + ) + + +@register_model +def vit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Tiny (Vit-Ti/16) + """ + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vit_tiny_patch16_224_relu(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Tiny (Vit-Ti/16) + """ + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, act_layer='relu') + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def vit_tiny_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Tiny (Vit-Ti/16) @ 384x384. + """ + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small (ViT-S/32) + """ + model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small (ViT-S/32) at 384x384. + """ + model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small (ViT-S/16) + """ + model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small (ViT-S/16) + """ + model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small (ViT-S/8) + """ + model_args = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. 
+ """ + model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. + """ + model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
+ """ + model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. + """ + model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) + """ + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). + """ + model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16) + model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ + model_args = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16) + model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_gigantic_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + """ + model_args = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16) + model = _create_vision_transformer( + 'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_224_miil(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False) + model = _create_vision_transformer( + 'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_gap_240(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240 + """ + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256 + """ + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384 + """ + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_betwixt_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Betwixt (ViT-b/16) w/o class token, w/ avg-pool @ 256x256 + """ + model_args = dict( + patch_size=16, embed_dim=640, depth=12, num_heads=10, class_token=False, + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224 + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool + """ + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch16_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448 + """ + model_args = dict( + patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_huge_patch16_gap_448', pretrained=pretrained, 
**dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool + """ + model_args = dict( + patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11, + class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + # TinyCLIP 8M + model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_xsmall_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + # TinyCLIP 40M + model_args = dict( + patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_medium_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + # TinyCLIP 39M + model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_medium_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + # TinyCLIP 61M + model_args = dict( + patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_betwixt_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/32 CLIP image tower @ 224x224 + """ + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_clip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/32 CLIP image tower @ 256x256 + """ + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/32 CLIP image tower @ 384x384 + """ + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/32 CLIP image tower @ 448x448 + """ + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = 
_create_vision_transformer( + 'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/16 CLIP image tower + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/16 CLIP image tower @ 384x384 + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_plus_clip_240(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16+) CLIP image tower @ 240x240 + """ + model_args = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch16_plus_clip_240', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) CLIP image tower + """ + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 + """ + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) CLIP image tower. 
+ """ + model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336 + """ + model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 + """ + model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower. + """ + model_args = dict( + patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_gigantic_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower. 
+ """ + model_args = dict( + patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/32 CLIP image tower @ 224x224 + """ + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_base_patch32_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/16 CLIP image tower w/ QuickGELU act + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_large_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_clip_quickgelu_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_large_patch14_clip_quickgelu_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act. 
+ """ + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_huge_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_clip_quickgelu_378(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act + """ + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_huge_patch14_clip_quickgelu_378', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_gigantic_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-bigG model (ViT-G/14) w/ QuickGELU act + """ + model_args = dict( + patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, + norm_layer=nn.LayerNorm, act_layer='quick_gelu') + model = _create_vision_transformer( + 'vit_gigantic_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +# Experimental models below + +@register_model +def vit_base_patch32_plus_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/32+) + """ + model_args = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) + model = _create_vision_transformer( + 'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_plus_240(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16+) + """ + model_args = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) + model = _create_vision_transformer( + 'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_rpn_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base (ViT-B/16) w/ residual post-norm + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, + class_token=False, block_fn=ResPostBlock, global_pool='avg') + model = _create_vision_transformer( + 'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_36x1_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. + """ + model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5) + model = _create_vision_transformer( + 'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. 
+ """ + model_args = dict( + patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock) + model = _create_vision_transformer( + 'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock) + model = _create_vision_transformer( + 'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def eva_large_patch14_196(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg') + model = _create_vision_transformer( + 'eva_large_patch14_196', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def eva_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg') + model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def flexivit_small(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ FlexiViT-Small + """ + model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True) + model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def flexivit_base(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ FlexiViT-Base + """ + model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True) + model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def flexivit_large(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ FlexiViT-Large + """ + model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True) + model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled. + """ + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True, + norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, + ) + model = _create_vision_transformer( + 'vit_base_patch16_xp_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled. 
+ """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True, + norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, + ) + model = _create_vision_transformer( + 'vit_large_patch14_xp_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled. + """ + model_args = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True, + norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True, + ) + model = _create_vision_transformer( + 'vit_huge_patch14_xp_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-S/14 for DINOv2 + """ + model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5) + model = _create_vision_transformer( + 'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/14 for DINOv2 + """ + model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5) + model = _create_vision_transformer( + 'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-L/14 for DINOv2 + """ + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5) + model = _create_vision_transformer( + 'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-G/14 for DINOv2 + """ + # The hidden_features of SwiGLU is calculated by: + # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + # When embed_dim=1536, hidden_features=4096 + # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 + model_args = dict( + patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, + mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU + ) + model = _create_vision_transformer( + 'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_small_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-S/14 for DINOv2 w/ 4 registers + """ + model_args = dict( + patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, + reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-B/14 for DINOv2 w/ 4 registers + """ + model_args = dict( + patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, + reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def 
vit_large_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-L/14 for DINOv2 w/ 4 registers + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, + reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT-G/14 for DINOv2 + """ + # The hidden_features of SwiGLU is calculated by: + # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + # When embed_dim=1536, hidden_features=4096 + # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 + model_args = dict( + patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2, + mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + act_layer='gelu_tanh', + ) + model = _create_vision_transformer( + 'vit_base_patch32_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_large_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1024, 
depth=24, num_heads=16, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_large_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map', + act_layer='gelu_tanh' + ) + model = _create_vision_transformer( + 'vit_large_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_378(pretrained: bool = False, **kwargs) -> VisionTransformer: + # this is a corrected variant of the 384 with a res properly divisible by patch size (no padding/truncation) + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_378', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', + act_layer='gelu_tanh', + ) + model = _create_vision_transformer( + 'vit_so400m_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', + act_layer='gelu_tanh', + ) + model = _create_vision_transformer( + 'vit_so400m_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', + act_layer='gelu_tanh', + ) + model = _create_vision_transformer( + 'vit_so400m_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giantopt_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, global_pool='map', + act_layer='gelu_tanh', + ) + model = _create_vision_transformer( + 'vit_giantopt_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + 
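+# Usage sketch (assumes timm is installed and this module has been imported so the
+# @register_model registry above is populated); a variant defined in this file can
+# then be built by name through timm's factory, roughly:
+#
+#   import torch
+#   import timm
+#   m = timm.create_model('vit_giantopt_patch16_siglip_256', pretrained=False)
+#   emb = m(torch.randn(1, 3, 256, 256))  # SigLIP cfgs above use num_classes=0, so this yields an embedding
+#
+# Extra kwargs passed to create_model (e.g. num_classes, img_size) are forwarded
+# into the entrypoints' model_args via **kwargs.
+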
+@register_model +def vit_giantopt_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, global_pool='map', + act_layer='gelu_tanh', + ) + model = _create_vision_transformer( + 'vit_giantopt_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch32_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False, + act_layer='gelu_tanh', + ) + model = _create_vision_transformer( + 'vit_base_patch32_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_base_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_large_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + 
patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_large_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, + global_pool='avg', fc_norm=False, act_layer='gelu_tanh' + ) + model = _create_vision_transformer( + 'vit_large_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, + class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_gap_378(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, + class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_gap_378', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, + class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, + class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, + class_token=False, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP).""" + model_args = dict( + patch_size=16, 
embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, + class_token=False, global_pool='avg', fc_norm=False, act_layer='gelu_tanh', + ) + model = _create_vision_transformer( + 'vit_so400m_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, + global_pool='avg', fc_norm=False, act_layer='gelu_tanh' + ) + model = _create_vision_transformer( + 'vit_so400m_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so400m_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, + global_pool='avg', fc_norm=False, act_layer='gelu_tanh' + ) + model = _create_vision_transformer( + 'vit_so400m_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giantopt_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, + global_pool='avg', fc_norm=False, act_layer='gelu_tanh' + ) + model = _create_vision_transformer( + 'vit_giantopt_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giantopt_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, + global_pool='avg', fc_norm=False, act_layer='gelu_tanh' + ) + model = _create_vision_transformer( + 'vit_giantopt_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + + +@register_model +def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5, + class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5, + class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock, + ) + model = _create_vision_transformer( + 'vit_pwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6, + class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_little_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=320, 
depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_medium_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_mediumd_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_mediumd_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_betwixt_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_betwixt_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_betwixt_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_betwixt_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5, + class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_betwixt_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, + no_embed_class=True, 
global_pool='avg', reg_tokens=4, + ) + model = _create_vision_transformer( + 'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """ + model_args = dict( + patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572, + class_token=False, reg_tokens=4, global_pool='map', + ) + model = _create_vision_transformer( + 'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """ + model_args = dict( + patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572, + class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """ + model_args = dict( + patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572, + class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False, + ) + model = _create_vision_transformer( + 'vit_so150m_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m2_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """ + model_args = dict( + patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5, + qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_so150m2_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m2_patch16_reg1_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """ + model_args = dict( + patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5, + qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_so150m2_patch16_reg1_gap_384', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_so150m2_patch16_reg1_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """ + model_args = dict( + patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5, + qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg', + ) + model = _create_vision_transformer( + 'vit_so150m2_patch16_reg1_gap_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, + init_values=0.1, final_norm=False, dynamic_img_size=True, + ) + model = 
_create_vision_transformer( + 'vit_intern300m_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Large AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Huge AIM-v2 model + """ + + model_args = dict( + patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_1b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 1B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_1b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_3b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 3B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_3b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Large AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_huge_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Huge AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 
'aimv2_huge_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_1b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 1B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_1b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_3b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 3B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_3b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_large_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Large AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_huge_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Huge AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_huge_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_1b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 1B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False, + mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_1b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def aimv2_3b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT 3B AIM-v2 model + """ + model_args = dict( + patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False, + mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu', + norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU, + ) + model = _create_vision_transformer( + 'aimv2_3b_patch14_448', pretrained=pretrained, **dict(model_args, 
**kwargs)) + return model + + +@register_model +def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Test + """ + model_args = dict(patch_size=16, embed_dim=64, depth=6, num_heads=2, mlp_ratio=3, dynamic_img_size=True) + model = _create_vision_transformer('test_vit', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def test_vit2(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Test + """ + model_args = dict( + patch_size=16, embed_dim=64, depth=8, num_heads=2, mlp_ratio=3, + class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True) + model = _create_vision_transformer('test_vit2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Test + """ + model_args = dict( + patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=2, + class_token=False, reg_tokens=1, global_pool='map', init_values=1e-5) + model = _create_vision_transformer('test_vit3', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer: + """ ViT Test + """ + model_args = dict( + patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=3, + class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True, + norm_layer='rmsnorm', + ) + model = _create_vision_transformer('test_vit4', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def test_vit_attention(pretrained: bool = False, **kwargs) -> VisionTransformer_attn: + """ ViT Test + """ + model_args = dict( + patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=3, + class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True, + norm_layer='rmsnorm', + ) + model = _create_vision_transformer('test_vit_attention', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +register_model_deprecations(__name__, { + 'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k', + 'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k', + 'vit_small_patch16_224_in21k': 'vit_small_patch16_224.augreg_in21k', + 'vit_base_patch32_224_in21k': 'vit_base_patch32_224.augreg_in21k', + 'vit_base_patch16_224_in21k': 'vit_base_patch16_224.augreg_in21k', + 'vit_base_patch8_224_in21k': 'vit_base_patch8_224.augreg_in21k', + 'vit_large_patch32_224_in21k': 'vit_large_patch32_224.orig_in21k', + 'vit_large_patch16_224_in21k': 'vit_large_patch16_224.augreg_in21k', + 'vit_huge_patch14_224_in21k': 'vit_huge_patch14_224.orig_in21k', + 'vit_base_patch32_224_sam': 'vit_base_patch32_224.sam', + 'vit_base_patch16_224_sam': 'vit_base_patch16_224.sam', + 'vit_small_patch16_224_dino': 'vit_small_patch16_224.dino', + 'vit_small_patch8_224_dino': 'vit_small_patch8_224.dino', + 'vit_base_patch16_224_dino': 'vit_base_patch16_224.dino', + 'vit_base_patch8_224_dino': 'vit_base_patch8_224.dino', + 'vit_base_patch16_224_miil_in21k': 'vit_base_patch16_224_miil.in21k', + 'vit_base_patch32_224_clip_laion2b': 'vit_base_patch32_clip_224.laion2b', + 'vit_large_patch14_224_clip_laion2b': 'vit_large_patch14_clip_224.laion2b', + 'vit_huge_patch14_224_clip_laion2b': 'vit_huge_patch14_clip_224.laion2b', + 'vit_giant_patch14_224_clip_laion2b': 'vit_giant_patch14_clip_224.laion2b', +}) + +if __name__ == "__main__": + model = test_vit_attention() + 
print(model) \ No newline at end of file
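+
+    # --- Hedged usage sketch (illustrative addition; the lines below are not part of the original file) ---
+    # Each @register_model function above only assembles a config dict and defers to
+    # _create_vision_transformer(), so calling one directly returns a ready nn.Module.
+    # The input resolution and num_classes below are assumptions chosen for a quick
+    # smoke test; test_vit2 is registered with dynamic_img_size=True, so a
+    # non-default image size is accepted at forward time.
+    import torch
+
+    tiny = test_vit2(num_classes=10)
+    x = torch.randn(1, 3, 160, 160)
+    print(tiny(x).shape)  # expected: torch.Size([1, 10])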