| import torch |
| import torch.nn.functional as F |
| from huggingface_hub import PyTorchModelHubMixin |
| from torch import nn |
| from torchvision import models |
|
|
|
|
class ICN(nn.Module, PyTorchModelHubMixin):
    """Two-stream ResNet-50 network over a channel-stacked pair of RGB images.

    The input packs two 3-channel images along the channel axis ("real" in
    channels 0-2, "fake" in channels 3 onward). Both halves go through the
    same (weight-shared) ResNet-50 stem. The two stem outputs are
    concatenated and fused by a 3x3 convolution. The rest of ResNet-50 then
    turns the fused map into a 2048-channel feature map, which is flattened
    and fed to a shared 256-d embedding followed by two linear heads:

    * ``fc2``    -> 49 raw outputs, returned as ``grid`` (presumably a 7x7
      grid of scores -- confirm against the training code);
    * ``cls_fc`` -> 3 raw logits, returned as ``cl``.

    NOTE: attribute names and the order of modules inside the
    ``nn.Sequential`` containers determine the ``state_dict`` keys; keep them
    stable so existing checkpoints still load.
    """

    def __init__(self):
        super().__init__()

        # Randomly initialised backbone whose modules are re-arranged below.
        # NOTE(review): ``pretrained`` is deprecated since torchvision 0.13 in
        # favour of ``weights=None``; kept as-is for older torchvision releases.
        cnn = models.resnet50(pretrained=False)

        # Per-image stem: conv1 / bn1 / relu / maxpool of ResNet-50, followed by
        # the first half of layer1's first bottleneck (conv1, bn1, conv2, bn2).
        # Produces 64 channels; the trailing ReLU is applied in ``forward``.
        # NOTE(review): the bottleneck's ReLU between bn1 and conv2 is not
        # included here, so conv1 -> bn1 -> conv2 is purely linear. Left as-is
        # because changing it would alter the outputs of trained checkpoints.
        first_block = cnn.layer1[0]
        self.cnn_head = nn.Sequential(
            cnn.conv1,
            cnn.bn1,
            cnn.relu,
            cnn.maxpool,
            first_block.conv1,
            first_block.bn1,
            first_block.conv2,
            first_block.bn2,
        )

        # Remaining ResNet-50 body: the last two bottlenecks of layer1, then
        # layer2-layer4 (avgpool and fc are dropped). Expects 256 input channels.
        self.cnn_tail = nn.Sequential(
            *cnn.layer1[1:],
            cnn.layer2,
            cnn.layer3,
            cnn.layer4,
        )

        # Fuses the two concatenated 64-channel stem outputs (128 channels)
        # into the 256 channels that layer1's second bottleneck expects.
        # This replaces the dropped downsample branch of the first bottleneck.
        self.conv1 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=256)

        # Shared embedding. in_features assumes a 2048 x 7 x 7 final feature
        # map (ResNet-50 downsamples by 32, so presumably 224x224 inputs --
        # confirm).
        self.fc1 = nn.Linear(2048 * 7 * 7, 256)
        # Head producing 7 * 7 = 49 raw outputs, returned as ``grid``.
        self.fc2 = nn.Linear(256, 7 * 7)
        # Head producing 3 raw logits, returned as ``cl``.
        self.cls_fc = nn.Linear(256, 3)

        # Not used in ``forward``; presumably consumed by training code elsewhere.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run both image streams, fuse them and return the two head outputs.

        Args:
            x: tensor whose channel axis (dim 1) stacks the two images.
                Channels ``0:3`` are the "real" image and channels ``3:`` the
                "fake" one. Assumed to be ``(batch, 6, H, W)`` with H and W
                such that the final feature map is 7x7 -- TODO confirm.

        Returns:
            ``(grid, cl)``:
                ``grid`` has shape ``(batch, 49)`` and comes from ``fc2``.
                ``cl`` has shape ``(batch, 3)`` and comes from ``cls_fc``.
                Both are raw (un-normalised) outputs.
        """
        real = x[:, :3, :, :]
        fake = x[:, 3:, :, :]

        # The same stem weights are applied to both images.
        real_features = F.relu(self.cnn_head(real))
        fake_features = F.relu(self.cnn_head(fake))

        # Stack along channels: 64 + 64 = 128.
        combined = torch.cat((real_features, fake_features), 1)

        x = F.relu(self.bn1(self.conv1(combined)))
        x = self.cnn_tail(x)

        # Flatten per sample, keeping the batch axis. If the final map is not
        # 7x7, fc1 then raises a shape error instead of extra spatial positions
        # being silently folded into the batch dimension.
        x = torch.flatten(x, 1)

        d = F.relu(self.fc1(x))
        grid = self.fc2(d)
        cl = self.cls_fc(d)
        return grid, cl
|
|