|
|
| """
|
| Backbone modules.
|
| """
|
| from collections import OrderedDict
|
|
|
| import os
|
| import torch
|
| import torch.nn.functional as F
|
| import torchvision
|
| from torch import nn
|
| from torchvision.models._utils import IntermediateLayerGetter
|
| from typing import Dict, List
|
|
|
| from util.misc import NestedTensor, is_main_process, clean_state_dict
|
|
|
| from .position_encoding import build_position_encoding
|
| from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
| from .swin_transformer import build_swin_transformer
|
|
|
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rqsrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super().__init__()
        # All four tensors are buffers, not parameters: they are never updated
        # by the optimizer, but still move with the module across devices.
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Checkpoints saved from a regular BatchNorm2d carry a
        # num_batches_tracked entry that this frozen variant has no buffer
        # for; drop it so loading does not report an unexpected key.
        tracked_key = prefix + 'num_batches_tracked'
        if tracked_key in state_dict:
            del state_dict[tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # Reshape the buffers to (1, C, 1, 1) so the affine transform
        # broadcasts over the batch and spatial dimensions.
        weight = self.weight.reshape(1, -1, 1, 1)
        shift = self.bias.reshape(1, -1, 1, 1)
        var = self.running_var.reshape(1, -1, 1, 1)
        mean = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5  # added before rsqrt to avoid NaNs (see class docstring)
        scale = weight * (var + eps).rsqrt()
        offset = shift - mean * scale
        return x * scale + offset
|
|
|
|
|
class BackboneBase(nn.Module):
    """Wraps a torchvision backbone and exposes intermediate feature maps
    together with the (resized) padding mask of the input NestedTensor."""

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        # Freeze everything except layer2/3/4 when training the backbone;
        # freeze all parameters when the backbone is not trained at all.
        trainable = ('layer2', 'layer3', 'layer4')
        for name, parameter in backbone.named_parameters():
            if not train_backbone or not any(kw in name for kw in trainable):
                parameter.requires_grad_(False)

        # Always tap stage indices 1-3, i.e. layer2 -> "1", layer3 -> "2",
        # layer4 -> "3" (the first tapped layer is 5 - len(indices)).
        return_interm_indices = [1, 2, 3]
        first_layer = 5 - len(return_interm_indices)
        return_layers = {
            "layer{}".format(first_layer + idx): str(layer_index)
            for idx, layer_index in enumerate(return_interm_indices)
        }

        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor_list: NestedTensor):
        features = self.body(tensor_list.tensors)
        out: Dict[str, NestedTensor] = {}
        for name, feat in features.items():
            mask = tensor_list.mask
            assert mask is not None
            # Downsample the padding mask to each feature map's spatial size.
            resized = F.interpolate(mask[None].float(), size=feat.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(feat, resized)
        return out
|
|
|
|
|
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm.

    Args:
        name: torchvision resnet variant; one of
            'resnet18', 'resnet34', 'resnet50', 'resnet101'.
        train_backbone: if True, layer2-4 stay trainable (see BackboneBase).
        return_interm_layers: forwarded to BackboneBase.
        dilation: replace stride with dilation in the last stage.
        batch_norm: norm-layer class used when building the resnet. Defaults
            to FrozenBatchNorm2d, matching the previously hard-coded value;
            accepting it as a keyword matches how build_backbone calls this
            class.

    Raises:
        NotImplementedError: for an unsupported `name`. Previously an unknown
            name fell through and crashed below with a NameError on the
            unbound `backbone` variable.
    """
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool,
                 batch_norm=FrozenBatchNorm2d):
        if name not in ('resnet18', 'resnet34', 'resnet50', 'resnet101'):
            # Fail fast with a clear message instead of a NameError.
            raise NotImplementedError("Unknown backbone {}".format(name))
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=batch_norm)
        # resnet18/34 end with 512 channels; resnet50/101 with 2048.
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
|
|
|
|
class Joiner(nn.Sequential):
    """Runs the backbone, then the position encoding, on every feature level.

    self[0] is the backbone; self[1] is the position-embedding module.
    """

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        features = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for _, feature in features.items():
            out.append(feature)
            # Cast the position encoding to the feature dtype (e.g. fp16).
            pos.append(self[1](feature).to(feature.tensors.dtype))
        return out, pos
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_backbone(args):
    """
    Build the full backbone: feature extractor joined with the position
    encoding. The returned Joiner gets a `num_channels` attribute listing the
    channel count of each returned feature level.

    Useful args:
        - backbone: backbone name
        - lr_backbone: backbone learning rate; must be > 0
        - dilation
        - return_interm_indices: available: [0,1,2,3], [1,2,3], [3]
        - backbone_freeze_keywords:
        - use_checkpoint: for swin only for now

    Raises:
        ValueError: if args.lr_backbone <= 0.
        NotImplementedError: for an unknown backbone name.
    """
    backbone_freeze_keywords = None
    use_checkpoint = True
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    if not train_backbone:
        raise ValueError("Please set lr_backbone > 0")
    return_interm_indices = [1, 2, 3]
    assert return_interm_indices in [[0, 1, 2, 3], [1, 2, 3], [3]]

    if args.backbone in ['resnet50', 'resnet101']:
        # Backbone.__init__ takes (name, train_backbone, return_interm_layers,
        # dilation); the previous call passed the arguments out of order and an
        # unsupported `batch_norm` keyword, raising a TypeError at runtime.
        backbone = Backbone(args.backbone, train_backbone,
                            return_interm_layers=True,
                            dilation=args.dilation)
        # Per-level channel counts of the tapped resnet50/101 stages
        # (layer1..layer4 -> 256/512/1024/2048). Backbone.num_channels is a
        # single int and would break the length check below.
        bb_num_channels = [256, 512, 1024, 2048][4 - len(return_interm_indices):]
    elif args.backbone in ['swin_T_224_1k', 'swin_B_224_22k', 'swin_B_384_22k', 'swin_L_224_22k', 'swin_L_384_22k']:
        pretrain_img_size = int(args.backbone.split('_')[-2])
        backbone = build_swin_transformer(args.backbone,
                                          pretrain_img_size=pretrain_img_size,
                                          out_indices=tuple(return_interm_indices),
                                          dilation=args.dilation,
                                          use_checkpoint=use_checkpoint)

        if backbone_freeze_keywords is not None:
            for name, parameter in backbone.named_parameters():
                for keyword in backbone_freeze_keywords:
                    if keyword in name:
                        parameter.requires_grad_(False)
                        break
        if hasattr(args, "backbone_dir"):
            pretrained_dir = args.backbone_dir
            PTDICT = {
                'swin_T_224_1k': 'swin_tiny_patch4_window7_224.pth',
                'swin_B_224_22k': 'swin_base_patch4_window7_224_22kto1k.pth',
                'swin_B_384_22k': 'swin_base_patch4_window12_384.pth',
                # Was missing although 'swin_L_224_22k' is an accepted backbone
                # above, which made this path raise a KeyError for it.
                'swin_L_224_22k': 'swin_large_patch4_window7_224_22k.pth',
                'swin_L_384_22k': 'swin_large_patch4_window12_384_22k.pth',
            }
            pretrainedpath = os.path.join(pretrained_dir, PTDICT[args.backbone])
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(pretrainedpath, map_location='cpu', weights_only=False)['model']

            def key_select_function(keyname):
                # Drop the classification head; drop the last stage when
                # dilation changes its geometry.
                if 'head' in keyname:
                    return False
                if args.dilation and 'layers.3' in keyname:
                    return False
                return True

            _tmp_st = OrderedDict((k, v) for k, v in clean_state_dict(checkpoint).items()
                                  if key_select_function(k))
            _tmp_st_output = backbone.load_state_dict(_tmp_st, strict=False)
            print(str(_tmp_st_output))
        # One channel count per returned level; list() guards the isinstance
        # check below against num_features being a tuple.
        bb_num_channels = list(backbone.num_features[4 - len(return_interm_indices):])
    else:
        raise NotImplementedError("Unknown backbone {}".format(args.backbone))

    assert len(bb_num_channels) == len(return_interm_indices), \
        f"len(bb_num_channels) {len(bb_num_channels)} != len(return_interm_indices) {len(return_interm_indices)}"

    model = Joiner(backbone, position_embedding)
    model.num_channels = bb_num_channels
    # isinstance against typing.List is deprecated; check the builtin instead.
    assert isinstance(bb_num_channels, list), \
        "bb_num_channels is expected to be a List but {}".format(type(bb_num_channels))
    return model