|
|
| import torch |
| import torchvision |
| import torch.nn as nn |
| import torchvision.models as models |
| from PIL import Image |
| from torchvision import transforms |
| import torch.nn.functional as F |
| import torch.optim as optim |
| import matplotlib.pyplot as plt |
| import torchvision.transforms as transforms |
| import copy |
| import torchvision.models as models |
| from PIL import Image |
| import numpy as np |
|
|
| |
|
|
|
|
| |
class ContentLoss(nn.Module):
    """Pass-through layer that measures distance to a fixed content target.

    On every forward call the mean squared error between the incoming tensor
    and ``target`` is stored in ``self.loss``; the input itself is returned
    unchanged, so the module can be inserted between layers of a sequential
    model without altering its output.
    """

    def __init__(self, target):
        super().__init__()
        # The target is a fixed reference value, not something to optimize:
        # detach it so no gradient flows back into whatever produced it.
        self.target = target.detach()

    def forward(self, input):
        # Record the loss as a side effect and hand the input straight on.
        distance = F.mse_loss(input, self.target)
        self.loss = distance
        return input
|
|
| |
def gram_matrix(input):
    """Return the normalized Gram matrix of a 4-D feature tensor.

    The tensor is unpacked as ``(n, c, h, w)``; in this module it is the
    activation of a Conv2d-based model, so presumably (batch, channels,
    height, width). Each of the ``n * c`` feature maps is flattened into one
    row, the rows are multiplied with their own transpose, and the result is
    divided by the total element count ``n * c * h * w`` so its magnitude
    does not grow with the feature-map size.

    Returns:
        A ``(n * c, n * c)`` tensor of pairwise inner products between the
        flattened feature maps.
    """
    n, c, h, w = input.size()
    rows = input.view(n * c, h * w)
    gram = rows @ rows.t()
    return gram / (n * c * h * w)
|
|
class StyleLoss(nn.Module):
    """Pass-through layer that measures distance to a fixed style target.

    The target is the Gram matrix of ``target_feature`` (via
    ``gram_matrix``), computed once at construction and detached. Each
    forward call stores the mean squared error between the input's Gram
    matrix and that target in ``self.loss`` and returns the input unchanged.
    """

    def __init__(self, target_feature):
        super().__init__()
        target_gram = gram_matrix(target_feature)
        # Fixed reference value: keep it out of the autograd graph.
        self.target = target_gram.detach()

    def forward(self, input):
        input_gram = gram_matrix(input)
        self.loss = F.mse_loss(input_gram, self.target)
        return input
|
|
| |
| |
# Preprocessing pipeline applied by image_transform(): resize every image to a
# fixed 128x128 (height, width) and convert it to a float tensor in [0, 1].
transform = transforms.Compose(
    [
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
    ]
)
|
|
def image_transform(image):
    """Convert an image file, PIL image or pixel array into a batched tensor.

    Args:
        image: One of
            * a path to an image file (``str`` or any ``os.PathLike``);
            * a ``PIL.Image.Image``;
            * an array-like object with an ``astype`` method (e.g. a NumPy
              array) of pixel values. It is cast to ``uint8``, so values are
              expected in ``[0, 255]``; grayscale ``(H, W)``, RGB
              ``(H, W, 3)`` and RGBA ``(H, W, 4)`` arrays are accepted.
            Every input is converted to RGB before preprocessing.

    Returns:
        The output of the module-level ``transform`` pipeline with a leading
        batch dimension added by ``unsqueeze(0)``. With the pipeline defined
        in this module that is a float tensor of shape ``(1, 3, 128, 128)``.

    Raises:
        TypeError: If ``image`` is ``None``.
    """
    import os  # local import: only needed for the path-type check

    if image is None:
        # Fail early with a clear message instead of deep inside torchvision.
        raise TypeError(
            'image_transform() expected a path, PIL image or array, got None')

    if isinstance(image, (str, os.PathLike)):
        # The context manager closes the file handle; convert() loads the
        # pixels into a new, independent image before the file is closed.
        with Image.open(image) as opened:
            pil_image = opened.convert('RGB')
    elif isinstance(image, Image.Image):
        pil_image = image.convert('RGB')
    else:
        # Let Pillow infer the mode from the array shape, then convert, so
        # grayscale and RGBA arrays are handled like file inputs. Forcing
        # mode='RGB' rejects or misreads arrays without exactly 3 channels.
        pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')

    return transform(pil_image).unsqueeze(0)
|
|
|
|
| |
| |
class Normalization(nn.Module):
    """Channel-wise normalization layer computing ``(img - mean) / std``.

    ``mean`` and ``std`` may be Python sequences, NumPy arrays or tensors
    holding one value per channel. They are copied and reshaped to
    ``(C, 1, 1)``, so in :meth:`forward` they broadcast over the trailing
    spatial dimensions of a ``(N, C, H, W)`` or ``(C, H, W)`` image.
    """

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # as_tensor(...).clone().detach() takes a private copy of the input
        # without the UserWarning torch.tensor() emits when it is handed an
        # existing tensor.
        # Registering the statistics as non-persistent buffers makes
        # ``module.to(device / dtype)`` move them together with the rest of
        # the model, while leaving ``state_dict()`` unchanged.
        self.register_buffer(
            'mean', torch.as_tensor(mean).clone().detach().view(-1, 1, 1),
            persistent=False)
        self.register_buffer(
            'std', torch.as_tensor(std).clone().detach().view(-1, 1, 1),
            persistent=False)

    def forward(self, img):
        """Return ``img`` normalized with the stored per-channel statistics."""
        return (img - self.mean) / self.std
|
|
| |
| |
|
|
|
|
# Default layer names (as assigned by get_style_model_and_losses) after which
# ContentLoss / StyleLoss modules are inserted.
content_layers_default = ['conv_4']
style_layers_default = [
    'conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5',
]
|
|
def _loss_module_name(model, prefix, index, layer_name):
    """Return a child name for a new loss module that does not clobber one."""
    candidate = '{}_{}'.format(prefix, index)
    if candidate in dict(model.named_children()):
        # Two losses of the same kind inside one conv block (e.g. after conv_1
        # and relu_1) would share this name, and add_module() would silently
        # replace the earlier loss module in place.
        candidate = '{}_{}'.format(prefix, layer_name)
    return candidate


def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
                               style_img, content_img,
                               content_layers=content_layers_default,
                               style_layers=style_layers_default):
    """Rebuild ``cnn`` as a sequential model with content/style loss probes.

    The returned model starts with a ``Normalization`` layer and then copies
    the direct children of ``cnn`` in order, naming them ``conv_<k>``,
    ``relu_<k>``, ``pool_<k>`` and ``bn_<k>``, where ``k`` is the number of
    ``Conv2d`` layers seen so far. Right after every layer whose name is in
    ``content_layers`` a ``ContentLoss`` is inserted, and after every layer
    whose name is in ``style_layers`` a ``StyleLoss``. Their targets are the
    activations of ``content_img`` / ``style_img`` at that point of the
    partially built model. Layers after the last loss module are dropped.

    Args:
        cnn: Module whose direct children are ``Conv2d``, ``ReLU``,
            ``MaxPool2d`` or ``BatchNorm2d`` layers only.
        normalization_mean: Per-channel mean passed to ``Normalization``.
        normalization_std: Per-channel std passed to ``Normalization``.
        style_img: Image tensor whose activations define the style targets.
        content_img: Image tensor whose activations define the content targets.
        content_layers: Layer names followed by a ``ContentLoss`` (only read).
        style_layers: Layer names followed by a ``StyleLoss`` (only read).

    Returns:
        ``(model, style_losses, content_losses)``: the truncated
        ``nn.Sequential`` and the inserted loss modules in insertion order.
        Each loss module stores its value in ``.loss`` when the model runs.

    Raises:
        RuntimeError: If ``cnn`` has a child of any other type.
    """
    normalization = Normalization(normalization_mean, normalization_std)

    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)

    i = 0  # number of Conv2d layers seen so far; used as the name suffix
    for layer in cnn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # An in-place ReLU would overwrite the tensor that a preceding
            # loss module just saved for backward, which autograd rejects;
            # use an out-of-place ReLU instead.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        if name in content_layers:
            # Targets are fixed tensors: no_grad avoids building an autograd
            # graph through the whole partial model only to detach it.
            with torch.no_grad():
                target = model(content_img).detach()
            content_loss = ContentLoss(target)
            model.add_module(
                _loss_module_name(model, 'content_loss', i, name), content_loss)
            content_losses.append(content_loss)

        if name in style_layers:
            with torch.no_grad():
                target_feature = model(style_img).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module(
                _loss_module_name(model, 'style_loss', i, name), style_loss)
            style_losses.append(style_loss)

    # Nothing after the last loss module contributes to any loss, so cut the
    # model right after it. With no loss module at all, only the
    # normalization layer (index 0) is kept.
    last_loss_index = 0
    for index in range(len(model) - 1, -1, -1):
        if isinstance(model[index], (ContentLoss, StyleLoss)):
            last_loss_index = index
            break
    model = model[:last_loss_index + 1]

    return model, style_losses, content_losses
| |
def get_input_optimizer(input_img):
    """Create the L-BFGS optimizer that updates ``input_img`` directly.

    The image tensor is the only parameter handed to ``optim.LBFGS`` (with
    its default settings), so no network weights are optimized.
    """
    params = [input_img]
    return optim.LBFGS(params)
|
|