| ''' |
| Code Reference: |
| https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py |
| https://github.com/G-U-N/PyCIL/blob/master/convs/resnet.py |
| ''' |
|
|
| import copy |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from torch.nn.parameter import Parameter |
|
|
# Public factories and classifier heads exported by ``from <module> import *``.
__all__ = [
    'resnet18', 'resnet34', 'resnet50',
    'cifar_resnet20', 'cifar_resnet32', 'cifar_resnet32_V2',
    'resnet32_V2', 'resnet18_AML',
    'CosineLinear', 'SplitCosineLinear',
]
|
|
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d``.

    Padding is set equal to ``dilation``, so with ``stride=1`` the output keeps
    the input's spatial size.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=dilation,
                       groups=groups, bias=False, dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
|
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (pointwise channel projection)."""
    conv_kwargs = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
|
class BasicBlock(nn.Module):
    """Torchvision-style basic residual block: two 3x3 convolutions plus a shortcut.

    Only ``groups=1``, ``base_width=64`` and ``dilation=1`` are accepted.
    When given, ``downsample`` is applied to the block input to build the
    shortcut; otherwise the input itself is added back.
    """
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        # Creation order is kept stable: it fixes both the state_dict layout and
        # the order in which random initialisers consume the RNG.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += x if self.downsample is None else self.downsample(x)
        return self.relu(out)
|
|
class Bottleneck(nn.Module):
    """Torchvision-style bottleneck block: 1x1 reduce, 3x3, 1x1 expand.

    The inner width is ``int(planes * base_width / 64) * groups`` and the output
    has ``planes * expansion`` channels. When given, ``downsample`` is applied
    to the block input to build the shortcut.
    """
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.)) * groups
        outer = planes * self.expansion

        # Creation order is kept stable (state_dict layout and RNG consumption).
        self.conv1 = conv1x1(inplanes, inner)
        self.bn1 = norm(inner)
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
        self.bn2 = norm(inner)
        self.conv3 = conv1x1(inner, outer)
        self.bn3 = norm(outer)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
|
|
class ResNet(nn.Module):
    """Torchvision-style four-stage ResNet backbone with a dataset-dependent stem.

    ``forward`` returns a dict with the four stage outputs (``'fmaps'``) and the
    globally pooled, flattened features (``'features'``). ``self.fc`` is created
    but not used by ``forward``.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: accepted for API compatibility; not used.
        zero_init_residual: if true, the last BN weight of every residual block
            is set to zero, so each residual branch initially outputs zeros.
        groups, width_per_group: forwarded to ``block``.
        replace_stride_with_dilation: three flags for stages 2-4; a true flag
            replaces that stage's stride-2 downsampling with dilation.
        norm_layer: normalisation layer factory, default ``nn.BatchNorm2d``.
        args: dict-like config. ``args["dataset"]`` selects the stem; for
            ImageNet-like datasets ``args["init_cls_num"]`` and
            ``args["inc_cls_num"]`` are read as well.

    Raises:
        AssertionError: if ``args`` is None.
        ValueError: if ``replace_stride_with_dilation`` does not have three
            entries, or ``args["dataset"]`` matches no supported stem.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, args=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # One flag each for layer2, layer3 and layer4.
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        assert args is not None, "you should pass args to resnet"
        dataset = args["dataset"]
        if 'cifar' in dataset or '5-datasets' in dataset:
            # Small-image stem: stride-1 3x3 conv and no max-pooling.
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False),
                norm_layer(self.inplanes),
                nn.ReLU(inplace=True),
            )
        elif 'imagenet' in dataset:
            if args["init_cls_num"] == args["inc_cls_num"]:
                # Standard torchvision stem: 7x7 stride-2 conv (+ max-pool below).
                stem_conv = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
            else:
                # NOTE(review): a stride-1 3x3 conv is used when the first task's
                # class count differs from later tasks; the rationale is not
                # visible here.
                stem_conv = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
            self.conv1 = nn.Sequential(
                stem_conv,
                norm_layer(self.inplanes),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            )
        else:
            # Previously the stem was silently left undefined and the model only
            # failed later inside forward(); fail fast with a clear message.
            raise ValueError("Unsupported dataset {!r}: expected a name containing "
                             "'cifar', '5-datasets' or 'imagenet'".format(dataset))

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.out_dim = 512 * block.expansion

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

        self.neck = nn.ModuleList()
        # in_features follows the real feature size (512 for BasicBlock, 2048 for
        # Bottleneck) instead of a hard-coded 512.
        # NOTE(review): out_features is hard-coded to 20 and num_classes is
        # ignored; presumably a placeholder replaced by the caller — confirm.
        self.fc = nn.Linear(self.out_dim, 20)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of ``blocks`` residual blocks with ``planes`` width.

        Only the first block may change stride/channels; it gets a 1x1 conv +
        norm projection on its shortcut when needed. Updates ``self.inplanes``
        (and ``self.dilation`` when ``dilate`` is true).
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        """Run stem and stages; return stage outputs and pooled features."""
        x = self.conv1(x)

        x_1 = self.layer1(x)
        x_2 = self.layer2(x_1)
        x_3 = self.layer3(x_2)
        x_4 = self.layer4(x_3)

        pooled = self.avgpool(x_4)
        features = torch.flatten(pooled, 1)

        return {
            'fmaps': [x_1, x_2, x_3, x_4],
            'features': features
        }

    def forward(self, x):
        return self._forward_impl(x)

    def feature(self, x):
        """Return only the pooled, flattened features for ``x``."""
        return self._forward_impl(x)['features']

    @property
    def last_conv(self):
        """Final convolution of the last block (``conv3`` for Bottleneck, else ``conv2``)."""
        if hasattr(self.layer4[-1], 'conv3'):
            return self.layer4[-1].conv3
        else:
            return self.layer4[-1].conv2
|
|
| def _resnet(arch, block, layers, pretrained, progress, **kwargs): |
| model = ResNet(block, layers, **kwargs) |
| if pretrained: |
| raise NotImplementedError |
|
|
| if 'cosine_fc' in kwargs['args'].keys() and kwargs['args']['cosine_fc']: |
| in_features = model.fc.in_features |
| out_features = model.fc.out_features |
| model.fc = CosineLinear(in_features, out_features) |
| return model |
|
|
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-18: four stages of two ``BasicBlock``s each.

    Reference: "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>

    Args:
        pretrained (bool): pretrained weights are not implemented; ``_resnet``
            raises ``NotImplementedError`` when this is true.
        progress (bool): accepted for torchvision API parity; unused.
        **kwargs: forwarded to ``ResNet`` (must include ``args``).
    """
    stage_depths = [2, 2, 2, 2]
    return _resnet('resnet18', BasicBlock, stage_depths, pretrained, progress, **kwargs)
|
|
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-34: ``BasicBlock`` stages of depth 3, 4, 6 and 3.

    Reference: "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>

    Args:
        pretrained (bool): pretrained weights are not implemented; ``_resnet``
            raises ``NotImplementedError`` when this is true.
        progress (bool): accepted for torchvision API parity; unused.
        **kwargs: forwarded to ``ResNet`` (must include ``args``).
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet34', BasicBlock, stage_depths, pretrained, progress, **kwargs)
|
|
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""Build ResNet-50: ``Bottleneck`` stages of depth 3, 4, 6 and 3.

    Reference: "Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>

    Args:
        pretrained (bool): pretrained weights are not implemented; ``_resnet``
            raises ``NotImplementedError`` when this is true.
        progress (bool): accepted for torchvision API parity; unused.
        **kwargs: forwarded to ``ResNet`` (must include ``args``).
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet('resnet50', Bottleneck, stage_depths, pretrained, progress, **kwargs)
|
|
class ResNetBasicblock(nn.Module):
    """Two-conv residual block used by ``CifarResNet``.

    ``conv_a`` carries the stride; ``conv_b`` keeps the resolution. When given,
    ``downsample`` maps the input onto the shortcut's shape.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResNetBasicblock, self).__init__()
        conv_opts = dict(kernel_size=3, padding=1, bias=False)
        # Creation order is kept stable (state_dict layout and RNG consumption).
        self.conv_a = nn.Conv2d(inplanes, planes, stride=stride, **conv_opts)
        self.bn_a = nn.BatchNorm2d(planes)
        self.conv_b = nn.Conv2d(planes, planes, stride=1, **conv_opts)
        self.bn_b = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        branch = F.relu(self.bn_a(self.conv_a(x)), inplace=True)
        branch = self.bn_b(self.conv_b(branch))
        shortcut = self.downsample(x) if self.downsample is not None else x
        return F.relu(shortcut + branch, inplace=True)
|
|
'''
Code Reference:
https://github.com/G-U-N/PyCIL/blob/master/convs/resnet.py

We keep this version of ResNet because it achieves better accuracy.
'''
class CifarResNet(nn.Module):
    """Three-stage CIFAR ResNet from He et al., https://arxiv.org/abs/1512.03385.

    ``depth`` must equal ``6n + 2``. Each stage has ``n`` blocks, with widths
    16, 32 and 64 and strides 1, 2 and 2. ``forward`` returns a dict holding the
    three stage outputs (``'fmaps'``) and the pooled, flattened features
    (``'features'``).

    NOTE(review): ``AvgPool2d(8)`` assumes stage_3 yields an 8x8 map, i.e.
    32x32 inputs; other input sizes give more than ``out_dim`` features.
    """

    def __init__(self, block, depth, channels=3):
        super(CifarResNet, self).__init__()

        assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110'
        blocks_per_stage = (depth - 2) // 6

        self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn_1 = nn.BatchNorm2d(16)

        self.inplanes = 16
        # Builds stage_1, stage_2, stage_3 in order (setattr registers submodules).
        for idx, (width, stride) in enumerate(((16, 1), (32, 2), (64, 2)), start=1):
            setattr(self, 'stage_%d' % idx,
                    self._make_layer(block, width, blocks_per_stage, stride))
        self.avgpool = nn.AvgPool2d(8)
        self.out_dim = 64 * block.expansion

        # He-normal init over fan-out (k*k*out_channels); BN starts as identity.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                module.bias.data.zero_()

        self.neck = nn.ModuleList()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; only its first block may change stride/channels."""
        out_planes = planes * block.expansion
        needs_projection = stride != 1 or self.inplanes != out_planes
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        ) if needs_projection else None

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def _extract(self, x):
        """Return (list of stage outputs, pooled features flattened per sample)."""
        out = F.relu(self.bn_1(self.conv_1_3x3(x)), inplace=True)
        fmaps = []
        for stage in (self.stage_1, self.stage_2, self.stage_3):
            out = stage(out)
            fmaps.append(out)
        pooled = self.avgpool(out)
        return fmaps, pooled.view(pooled.size(0), -1)

    def forward(self, x):
        fmaps, features = self._extract(x)
        return {'fmaps': fmaps, 'features': features}

    def feature(self, x):
        """Return only the pooled, flattened features for ``x``."""
        return self._extract(x)[1]

    @property
    def last_conv(self):
        """Second convolution of the final block in stage_3."""
        return self.stage_3[-1].conv_b
|
|
| ''' |
| Code Reference: |
| https://github.com/hshustc/CVPR19_Incremental_Learning/blob/master/cifar100-class-incremental/modified_linear.py |
| ''' |
class CosineLinear(nn.Module):
    """Linear layer producing cosine similarities, optionally scaled by ``sigma``.

    Input rows and weight rows are both L2-normalised before the product, so
    each unscaled output lies in [-1, 1]. With ``sigma=True`` a learnable
    scalar (initialised to 1) multiplies the result.
    """

    def __init__(self, in_features, out_features, sigma=True):
        super(CosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if not sigma:
            self.register_parameter('sigma', None)
        else:
            self.sigma = Parameter(torch.Tensor(1))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weights from U(-1/sqrt(in_features), 1/sqrt(in_features)); set sigma to 1."""
        bound = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -bound, bound)
        if self.sigma is not None:
            nn.init.ones_(self.sigma)

    def forward(self, input):
        unit_input = F.normalize(input, p=2, dim=1)
        unit_weight = F.normalize(self.weight, p=2, dim=1)
        cosine = F.linear(unit_input, unit_weight)
        return cosine if self.sigma is None else self.sigma * cosine
|
|
class SplitCosineLinear(nn.Module):
    """Two cosine heads whose outputs are concatenated along dim 1.

    ``fc1``'s outputs fill the first ``out_features1`` columns and ``fc2``'s the
    next ``out_features2``. Neither head has its own scale; a single shared
    ``sigma`` (initialised to 1, if enabled) multiplies the concatenation.
    """

    def __init__(self, in_features, out_features1, out_features2, sigma=True):
        super(SplitCosineLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features1 + out_features2
        self.fc1 = CosineLinear(in_features, out_features1, False)
        self.fc2 = CosineLinear(in_features, out_features2, False)
        if not sigma:
            self.register_parameter('sigma', None)
            return
        self.sigma = Parameter(torch.Tensor(1))
        self.sigma.data.fill_(1)

    def forward(self, x):
        logits = torch.cat([self.fc1(x), self.fc2(x)], dim=1)
        return logits if self.sigma is None else self.sigma * logits
|
|
|
|
'''
Code Reference:
https://github.com/hshustc/CVPR19_Incremental_Learning

This is the ResNet variant used in the official LUCIR repository; replacing it with a different variant leads to lower performance.
'''
|
|
class modified_BasicBlock(nn.Module):
    """LUCIR basic block; with ``last=True`` the ReLU after the residual add is skipped.

    Skipping that ReLU means the final block of a network can emit negative
    values. When given, ``downsample`` maps the input onto the shortcut's shape.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, last=False):
        super(modified_BasicBlock, self).__init__()
        # Creation order is kept stable (state_dict layout and RNG consumption).
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.last = last

    def forward(self, x):
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        out += x if self.downsample is None else self.downsample(x)
        return out if self.last else self.relu(out)
|
|
class modified_ResNet(nn.Module):
    """LUCIR three-stage CIFAR ResNet whose very last block skips its final ReLU.

    Args:
        block: residual block class; must accept ``last=True``
            (e.g. ``modified_BasicBlock``).
        layers: number of blocks in each of the three stages.
        num_classes: accepted for API compatibility; unused (no classifier
            is built here).

    NOTE(review): ``AvgPool2d(8, stride=1)`` assumes layer3 yields an 8x8 map,
    i.e. 32x32 inputs (strides 1, 2, 2); confirm for other sizes.
    """

    def __init__(self, block, layers, num_classes=10):
        # nn.Module must be initialised before any attribute is set on it.
        super(modified_ResNet, self).__init__()
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2, last_phase=True)
        self.avgpool = nn.AvgPool2d(8, stride=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, last_phase=False):
        """Build one stage of ``blocks`` blocks.

        With ``last_phase`` the final block is created with ``last=True``. The
        first block carries the stride and, when needed, a 1x1 conv + BN
        projection on its shortcut.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        # A single-block final stage must itself be the `last` block. Previously
        # an extra block was appended in that case, giving two blocks.
        if last_phase and blocks <= 1:
            layers.append(block(self.inplanes, planes, stride, downsample, last=True))
        else:
            layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            if last_phase and i == blocks - 1:
                layers.append(block(self.inplanes, planes, last=True))
            else:
                layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _extract(self, x):
        """Run the network and return pooled features flattened per sample."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer3(self.layer2(self.layer1(x)))
        x = self.avgpool(x)
        return x.view(x.size(0), -1)

    def forward(self, x):
        return {"features": self._extract(x)}

    def feature(self, x):
        """Return only the pooled, flattened features for ``x``."""
        return self._extract(x)
|
|
| |
class BiasLayer(nn.Module):
    """Learnable scalar affine map ``alpha * x + beta``, initialised to the identity."""

    def __init__(self):
        super().__init__()
        # nn.Parameter already sets requires_grad=True.
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        scaled = self.alpha * x
        return scaled + self.beta
| |
class BasicBlock2(nn.Module):
    """Pre-activation residual block: (BN -> ReLU -> conv) twice, then add the shortcut.

    No activation follows the addition. When given, ``downsample`` maps the
    input onto the shortcut's shape.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Creation order is kept stable (state_dict layout and RNG consumption).
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        h = x
        for norm, conv in ((self.bn1, self.conv1), (self.bn2, self.conv2)):
            h = conv(self.relu(norm(h)))
        shortcut = x if self.downsample is None else self.downsample(x)
        h += shortcut
        return h
|
|
class ResNet_BIC(nn.Module):
    """CIFAR-style ResNet for the BiC method, built from ``BasicBlock2``.

    The three stages (widths 16, 32, 64) are followed by a final BN + ReLU,
    average pooling and flattening; ``forward`` returns those features.

    Args:
        depth: total depth; must satisfy ``depth = 6n + 2``.
        block_name: only ``'BasicBlock2'`` (case-insensitive) is supported.

    Raises:
        ValueError: for ``'bottleneck'`` or any other unsupported block name.
        AssertionError: if ``depth`` is not of the form ``6n + 2``.
    """

    def __init__(self, depth, block_name='BasicBlock2'):
        super().__init__()

        kind = block_name.lower()
        if kind == 'bottleneck':
            # Previously `assert 0`, which `python -O` strips, letting construction
            # fall through to an unsupported block. Raise explicitly, matching the
            # ValueError used for other unsupported names.
            raise ValueError('bottleneck is called, should not happen in method : BIC')
        if kind != 'basicblock2':
            raise ValueError("block_name should be 'BasicBlock2'")
        assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
        n = (depth - 2) // 6
        block = BasicBlock2

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 16, n)
        self.layer2 = self._make_layer(block, 32, n, stride=2)
        self.layer3 = self._make_layer(block, 64, n, stride=2)
        # Pre-activation blocks end without BN/ReLU, so they are applied once here.
        self.bn = nn.BatchNorm2d(64 * block.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)

        # NOTE(review): feat_dim is hard-coded to 256, but forward returns
        # 64 * block.expansion (= 64) features for 32x32 inputs; confirm what
        # consumers of feat_dim expect.
        self.feat_dim = 256

        # He-normal init over fan-out; BN starts as the identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage; the first block gets a 1x1 conv shortcut when shapes change."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.bn(x)
        x = self.relu(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x
|
|
| |
class BasicBlock_AML(nn.Module):
    """Basic residual block for ``ResNet_AML``.

    The shortcut is an identity (empty ``nn.Sequential``) unless the stride or
    channel count changes, in which case it is a 1x1 conv + BN projection.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        self.activation = nn.ReLU()

    def forward(self, x):
        branch = self.bn2(self.conv2(self.activation(self.bn1(self.conv1(x)))))
        return self.activation(branch + self.shortcut(x))
|
|
class ResNet_AML(nn.Module):
    """Four-stage ResNet with widths nf, 2nf, 4nf and 8nf, used for AML.

    ``out_dim`` is measured at construction time by passing a zero tensor of
    shape ``(1, *input_size)`` through ``forward``. ``num_classes`` is accepted
    but not used here.
    """

    def __init__(self, block, num_blocks, num_classes, nf=20, input_size=(3, 32, 32)):
        super().__init__()
        self.in_planes = nf
        self.input_size = input_size

        self.conv1 = conv3x3(input_size[0], nf * 1)
        self.bn1 = nn.BatchNorm2d(nf * 1)
        # (width multiplier, first-block stride) for layer1..layer4.
        stage_specs = ((1, 1), (2, 2), (4, 2), (8, 2))
        for idx, (mult, stride) in enumerate(stage_specs):
            setattr(self, 'layer%d' % (idx + 1),
                    self._make_layer(block, nf * mult, num_blocks[idx], stride=stride))

        self.activation = nn.ReLU()

        with torch.no_grad():
            probe = torch.zeros(1, *self.input_size)
            # forward returns (1, D) for a single sample, so numel() == D.
            self.out_dim = self.forward(probe).numel()

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        # NOTE(review): a fixed 4x4 pooling window presumes a 4x4 map after
        # layer4 (32x32 inputs); larger inputs yield a larger out_dim.
        out = F.avg_pool2d(out, 4)
        return out.view(out.size(0), -1)
|
|
|
|
def cifar_resnet20(pretrained=False, **kwargs):
    """Build a 20-layer ``CifarResNet`` (three stages of three ``ResNetBasicblock``s).

    ``pretrained`` and ``**kwargs`` are accepted for a uniform factory
    signature but are ignored.
    """
    # The unused local `n = 3` was removed; the depth alone sets blocks per stage.
    return CifarResNet(ResNetBasicblock, 20)
|
|
def cifar_resnet32(pretrained=False, **kwargs):
    """Build a 32-layer ``CifarResNet`` (three stages of five ``ResNetBasicblock``s).

    ``pretrained`` and ``**kwargs`` are accepted for a uniform factory
    signature but are ignored.
    """
    depth = 32
    return CifarResNet(ResNetBasicblock, depth)
|
|
def cifar_resnet32_V2(pretrained=False, **kwargs):
    """Build the 32-layer ``ResNet_BIC`` backbone.

    ``pretrained`` and ``**kwargs`` are accepted for a uniform factory
    signature but are ignored.
    """
    model = ResNet_BIC(depth=32)
    return model
|
|
def resnet32_V2(pretrained=False, **kwargs):
    """Build the LUCIR-style ``modified_ResNet`` with five blocks per stage.

    ``pretrained`` and ``**kwargs`` are accepted for a uniform factory
    signature but are ignored; ``num_classes=50`` is passed through unchanged.
    """
    blocks_per_stage = 5
    return modified_ResNet(modified_BasicBlock, [blocks_per_stage] * 3, num_classes=50)
|
|
def resnet18_AML(pretrained=False, **kwargs):
    """Build ``ResNet_AML`` with two ``BasicBlock_AML``s per stage.

    ``kwargs['num_classes']`` is required. ``kwargs['input_size']`` defaults to
    ``[3, 32, 32]``. ``pretrained`` is ignored.
    """
    kwargs.setdefault('input_size', [3, 32, 32])
    return ResNet_AML(BasicBlock_AML, [2] * 4, kwargs['num_classes'],
                      input_size=kwargs['input_size'])
|
|