| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
from typing import Type, Any, Callable, Union, List, Optional, Tuple

import torch
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:  # module removed in newer torchvision; the same helper lives in torch.hub
    from torch.hub import load_state_dict_from_url
|
|
|
|
# Public names exported by `from <this module> import *`.
# NOTE(review): in the visible part of this module only `resnet50` and
# `resnext50_32x4d` are defined (plus `ResNet`, imported from torchvision).
# `import *` raises AttributeError for any listed name that is missing, so
# confirm the other builders exist elsewhere or trim this list.
__all__ = [
    "ResNet",
    # plain ResNets
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    # ResNeXt variants
    "resnext50_32x4d", "resnext101_32x8d",
    # wide ResNets
    "wide_resnet50_2", "wide_resnet101_2",
]
|
|
|
|
# Download URLs for the torchvision ImageNet checkpoints, keyed by architecture
# name. `_resnet` reads `model_urls[arch]` when `pretrained=True`.
model_urls = dict(
    resnet18="https://download.pytorch.org/models/resnet18-5c106cde.pth",
    resnet34="https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    resnet50="https://download.pytorch.org/models/resnet50-19c8e357.pth",
    resnet101="https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    resnet152="https://download.pytorch.org/models/resnet152-b121ed2d.pth",
    resnext50_32x4d="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    resnext101_32x8d="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    wide_resnet50_2="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    wide_resnet101_2="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
)
|
|
|
|
class ResNet_mine(ResNet):
    """torchvision ``ResNet`` whose forward also returns the ``layer4`` feature map.

    Args:
        block: residual block class, passed straight to ``ResNet.__init__``.
        layers: per-stage block counts, passed straight to ``ResNet.__init__``.
        classifier_run: if True (default), the first output is ``self.fc``
            applied to the pooled features. If False, ``self.fc`` is skipped
            and the first output is the flattened pooled features.
        **kwargs: forwarded unchanged to ``ResNet.__init__``.
    """

    def __init__(self, block, layers, classifier_run=True, **kwargs):
        super().__init__(block, layers, **kwargs)
        self.classifier_run = classifier_run

    def _forward_impl(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network and return ``(output, feature_map)``.

        ``feature_map`` is the raw output of ``self.layer4``, before pooling.
        ``output`` is ``torch.flatten(self.avgpool(feature_map), 1)``. When
        ``self.classifier_run`` is true, it is also passed through ``self.fc``.
        """
        # Initial conv -> bn -> relu -> maxpool.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Residual stages; keep layer4's output so callers get the unpooled features.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        feature_map = self.layer4(x)

        out = self.avgpool(feature_map)
        out = torch.flatten(out, 1)
        if self.classifier_run:
            out = self.fc(out)

        return out, feature_map

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(output, layer4_feature_map)``; see ``_forward_impl``."""
        return self._forward_impl(x)
|
|
|
|
def pnorm(weights, p):
    """Scale each slice ``weights[i]`` by ``1 / norm_i ** p``.

    The L2 norm is taken over dimension 1. For a 2-D ``(rows, cols)`` matrix,
    each row is therefore divided by its own L2 norm raised to ``p``:
    ``p = 0`` leaves values unchanged, and ``p = 1`` gives unit-norm rows.
    For higher-rank inputs, every vector along dim 1 is scaled by its own norm.

    Args:
        weights: floating-point tensor with at least 2 dimensions.
        p: exponent applied to the norm (Python number or scalar tensor).

    Returns:
        A new tensor with the same shape as ``weights``; the input is not
        modified. For ``p > 0``, a vector that is entirely zero becomes ``nan``
        (0 / 0), exactly as with the previous per-row loop.
    """
    # keepdim=True keeps a size-1 axis at dim 1, so the division broadcasts one
    # scale factor per vector. This replaces the former Python loop over rows.
    norms = torch.norm(weights, p=2, dim=1, keepdim=True)
    return weights / torch.pow(norms, p)
|
|
|
|
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    """Build a ``ResNet_mine`` and optionally load the matching ImageNet weights.

    Args:
        arch: key into ``model_urls``; it is only read when ``pretrained`` is True.
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: per-stage block counts.
        pretrained: if True, download the checkpoint for ``arch`` and load it.
        progress: forwarded to ``load_state_dict_from_url`` to control its
            download progress bar.
        **kwargs: forwarded to ``ResNet_mine`` (e.g. ``classifier_run``).

    Returns:
        The constructed ``ResNet_mine`` instance.

    Raises:
        KeyError: if ``pretrained`` is True and ``arch`` is not in ``model_urls``.

    NOTE(review): ``load_state_dict`` is called with its default (strict)
    matching, so kwargs that change parameter shapes will likely make the
    pretrained load fail.
    """
    net = ResNet_mine(block, layers, **kwargs)
    if not pretrained:
        return net

    print("Inside resnet function, using ImageNet pretrained from model url!")
    checkpoint_url = model_urls[arch]
    weights = load_state_dict_from_url(checkpoint_url, progress=progress)
    net.load_state_dict(weights)
    return net
|
|
|
|
def resnext50_32x4d(
    pretrained: bool = False, progress: bool = True, **kwargs: Any
) -> ResNet:
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Uses ``Bottleneck`` blocks with ``[3, 4, 6, 3]`` blocks per stage,
    32 groups and a width of 4 per group. Any ``groups`` or
    ``width_per_group`` passed in ``kwargs`` is overwritten with those values.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``_resnet`` and from there to ``ResNet_mine``.
    """
    stage_depths = [3, 4, 6, 3]
    kwargs.update(groups=32, width_per_group=4)
    return _resnet(
        "resnext50_32x4d", Bottleneck, stage_depths, pretrained, progress, **kwargs
    )
|
|
|
|
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Uses ``Bottleneck`` blocks with ``[3, 4, 6, 3]`` blocks per stage.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        **kwargs: forwarded to ``_resnet`` and from there to ``ResNet_mine``.
    """
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet(
        "resnet50",
        Bottleneck,
        blocks_per_stage,
        pretrained,
        progress,
        **kwargs
    )
|
|