import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import ImageClassifierOutput
|
|
|
|
class PrunedResNetConfig(PretrainedConfig):
    """Configuration for a channel-pruned ResNet-50 image classifier.

    Args:
        channel_config: Mapping from a layer key (e.g. ``"conv1"``,
            ``"layer1.0.conv2"``, ``"layer2.0.downsample.0"``) to its pruned
            output-channel count. ``None`` is accepted so the config can be
            deserialized before the mapping is attached.
        num_classes: Size of the classification head.
    """

    model_type = "resnet"

    def __init__(
        self,
        channel_config: dict[str, int] | None = None,
        num_classes: int = 1000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Per-layer pruned widths; consumed by PrunedResNet50.__init__.
        self.channel_config = channel_config
        self.num_classes = num_classes
|
|
|
|
class PrunedResNet50(PreTrainedModel):
    """ResNet-50 whose per-layer channel widths are read from the config.

    The module/attribute names (``conv1``, ``bn1``, ``layer1`` … ``fc``)
    mirror torchvision's ResNet so a pruned checkpoint keyed the same way
    loads directly via ``from_pretrained``.
    """

    config_class = PrunedResNetConfig
    _tied_weights_keys = []

    def __init__(self, config: PrunedResNetConfig):
        super().__init__(config)
        self.config = config
        # c maps layer keys (e.g. "layer1.0.conv1") to pruned channel counts.
        c = config.channel_config
        self.conv1 = nn.Conv2d(
            3, c["conv1"], kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(c["conv1"])
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Standard ResNet-50 stage layout: 3/4/6/3 bottleneck blocks.
        self.layer1 = self._make_layer(c, stage_idx=1, layers=3, stride=1)
        self.layer2 = self._make_layer(c, stage_idx=2, layers=4, stride=2)
        self.layer3 = self._make_layer(c, stage_idx=3, layers=6, stride=2)
        self.layer4 = self._make_layer(c, stage_idx=4, layers=3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The classifier input width is the last block's expansion width.
        last_channel = c["layer4.2.conv3"]
        self.fc = nn.Linear(last_channel, config.num_classes)
        self.post_init()

    def _make_layer(self, c, stage_idx, layers, stride):
        """Build one ResNet stage from the pruned-channel mapping.

        Args:
            c: Channel-config mapping (see class docstring).
            stage_idx: 1-based stage number, used to form the lookup keys.
            layers: Number of bottleneck blocks in the stage.
            stride: Stride of the first block (spatial downsampling).

        Returns:
            nn.Sequential of ``layers`` Bottleneck modules.
        """
        blocks = []

        # First block may downsample spatially and always carries the
        # projection shortcut (its key is present in the channel config).
        blocks.append(
            Bottleneck(
                inplanes=c[f"layer{stage_idx}.0.in"],
                planes=[
                    c[f"layer{stage_idx}.0.conv1"],
                    c[f"layer{stage_idx}.0.conv2"],
                    c[f"layer{stage_idx}.0.conv3"],
                ],
                stride=stride,
                downsample_planes=c.get(f"layer{stage_idx}.0.downsample.0", None),
            )
        )

        # Remaining blocks keep stride 1 and use identity shortcuts.
        for i in range(1, layers):
            blocks.append(
                Bottleneck(
                    inplanes=c[f"layer{stage_idx}.{i}.in"],
                    planes=[
                        c[f"layer{stage_idx}.{i}.conv1"],
                        c[f"layer{stage_idx}.{i}.conv2"],
                        c[f"layer{stage_idx}.{i}.conv3"],
                    ],
                )
            )

        return nn.Sequential(*blocks)

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """Run the backbone and classifier head.

        Args:
            pixel_values: Float tensor of images; the stem expects 3 input
                channels (presumably ``(batch, 3, H, W)`` — standard for
                image classifiers).
            labels: Optional class indices; when given, a cross-entropy
                loss is computed and returned in the output.

        Returns:
            ImageClassifierOutput with ``logits`` and optional ``loss``.

        Raises:
            ValueError: If ``pixel_values`` is not provided.
        """
        # Fail fast with a clear message instead of an opaque TypeError
        # from conv1 when the input is omitted.
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        x = pixel_values
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        logits = self.fc(x)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
        return ImageClassifierOutput(logits=logits, loss=loss)
|
|
|
|
class Bottleneck(nn.Module):
    """Pruned ResNet bottleneck: 1x1 reduce, 3x3 spatial, 1x1 expand, residual add.

    Submodule names (``conv1``/``bn1`` … ``downsample``) match torchvision's
    Bottleneck so pruned checkpoints load by key.

    Args:
        inplanes: Input channel count of the block.
        planes: Three output widths ``[conv1, conv2, conv3]``.
        stride: Stride applied by the 3x3 conv (and the shortcut projection).
        downsample_planes: Output width of the projection shortcut, or
            ``None`` for an identity shortcut.
    """

    def __init__(self, inplanes, planes, stride=1, downsample_planes=None):
        super().__init__()
        width1, width2, width_out = planes

        # 1x1 channel reduction.
        self.conv1 = nn.Conv2d(inplanes, width1, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width1)

        # 3x3 spatial conv; carries the block's stride.
        self.conv2 = nn.Conv2d(
            width1, width2, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(width2)

        # 1x1 channel expansion.
        self.conv3 = nn.Conv2d(width2, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut only when the residual path changes shape.
        if downsample_planes is None:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    downsample_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(downsample_planes),
            )

    def forward(self, x):
        # Identity or projected shortcut, depending on block configuration.
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.relu(y)
|
|