| import torch |
| import torchvision |
| import torch.nn as nn |
|
|
|
|
class resnet18(nn.Module):
    """Image feature extractor built from torchvision's ``resnet18``.

    The backbone is every child module of ``torchvision.models.resnet18``
    except the last one; its output is flattened to one feature vector per
    image and optionally L2-normalized.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet18`` as its
            ``pretrained`` argument.
        output_dim: Accepted for interface compatibility. NOTE(review): it is
            not read anywhere in this class, so the feature width is whatever
            the backbone produces -- confirm whether callers rely on it.
        unit_norm: When true, ``forward`` scales every feature vector to
            unit L2 norm.
    """

    def __init__(
        self,
        pretrained: bool = True,
        output_dim: int = 512,
        unit_norm: bool = False,
    ):
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)
        # Drop the final child module so the network emits features rather
        # than the output of its last layer.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.flatten = nn.Flatten()
        self.pretrained = pretrained
        # Per-channel normalization applied to every input before the
        # backbone (values look like the commonly used ImageNet statistics).
        self.normalize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.unit_norm = unit_norm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode images into feature vectors.

        Args:
            x: Tensor whose last three dimensions describe one image.
                A 3-D input is treated as a single image, a 4-D input as a
                batch, and an input with more than four dimensions has all of
                its leading dimensions folded into one batch axis.

        Returns:
            Features with the same leading dimensions as ``x`` and a single
            trailing feature axis: ``(F,)`` for 3-D input, ``(B, F)`` for 4-D
            input, ``(*leading, F)`` otherwise.
        """
        orig_shape = x.shape
        dims = len(orig_shape)
        if dims == 3:
            # Single image: add a temporary batch axis.
            x = x.unsqueeze(0)
        elif dims > 4:
            # Collapse every leading dimension into one batch axis.
            x = x.reshape(-1, *orig_shape[-3:])

        out = self.flatten(self.resnet(self.normalize(x)))
        if self.unit_norm:
            out = torch.nn.functional.normalize(out, p=2, dim=-1)

        if dims == 3:
            out = out.squeeze(0)
        elif dims > 4:
            # Pass the feature width explicitly instead of -1: reshape cannot
            # infer -1 when one of the leading dimensions is 0.
            out = out.reshape(*orig_shape[:-3], out.shape[-1])
        return out
|
|