Chen Zheng committed on
Commit
497a97b
·
1 Parent(s): 3167f6c

add PerceptualLoss

Browse files

Former-commit-id: 1a3bf6cbc3bf49fa0cc6818d027d1317c0d674fb

basicsr/archs/vgg_arch.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from torch import nn as nn
5
+ from torchvision.models import vgg as vgg
6
+
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+
9
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'


def _vgg_layer_names(convs_per_stage):
    """Build the ordered layer names of one VGG variant.

    Each stage contributes ``conv<stage>_<i>`` / ``relu<stage>_<i>`` pairs
    for ``i = 1..num_convs`` and is closed by ``pool<stage>``.

    Args:
        convs_per_stage (tuple[int]): Number of conv layers in each stage,
            in order. Example: (1, 1, 2, 2, 2).

    Returns:
        list[str]: Layer names, e.g. ['conv1_1', 'relu1_1', 'pool1', ...].
    """
    layer_names = []
    for stage, num_convs in enumerate(convs_per_stage, start=1):
        for pos in range(1, num_convs + 1):
            layer_names.append('conv{}_{}'.format(stage, pos))
            layer_names.append('relu{}_{}'.format(stage, pos))
        layer_names.append('pool{}'.format(stage))
    return layer_names


# Layer names for each plain (non-bn) VGG variant, keyed by variant name.
# The tuples give the number of conv layers in stages 1..5.
NAMES = {
    'vgg11': _vgg_layer_names((1, 1, 2, 2, 2)),
    'vgg13': _vgg_layer_names((2, 2, 2, 2, 2)),
    'vgg16': _vgg_layer_names((2, 2, 3, 3, 3)),
    'vgg19': _vgg_layer_names((2, 2, 4, 4, 4)),
}
34
+
35
+
36
def insert_bn(names):
    """Return the layer names with a bn entry added after every conv entry.

    The bn name is the conv name with 'conv' stripped and 'bn' prefixed,
    e.g. 'conv1_1' is followed by 'bn1_1'. The input list is not modified.

    Args:
        names (list): The list of layer names.

    Returns:
        list: A new list of layer names that includes the bn layers.
    """
    with_bn = []
    for layer in names:
        if 'conv' in layer:
            with_bn.extend((layer, 'bn' + layer.replace('conv', '')))
        else:
            with_bn.append(layer)
    return with_bn
52
+
53
+
54
@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether to normalize the
    input and which type of vgg network to use. Note that the pretrained path
    must fit the vgg type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: ['relu1_1', 'relu2_1', 'relu3_1'].
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must be in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If true, the parameters of VGG network will be
            optimized. Default: False.
        remove_pooling (bool): If true, the max pooling operations in VGG net
            will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.

    Raises:
        ValueError: If a name in ``layer_name_list`` is not a layer name of
            the selected ``vgg_type``.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = 0
        for name in layer_name_list:
            try:
                idx = self.names.index(name)
            except ValueError:
                # keep the exception type of list.index, but say what went wrong
                raise ValueError('Unknown layer name {!r} for {}. Valid names are: {}'.format(
                    name, vgg_type, self.names)) from None
            max_idx = max(max_idx, idx)

        # NOTE(review): the local checkpoint is used for every vgg_type even
        # though its file name says vgg19; the caller is expected to make the
        # path fit the vgg type (see class docstring).
        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for name, layer in zip(self.names, features):
            if 'pool' not in name:
                modified_net[name] = layer
            elif not remove_pooling:
                # rebuild the pooling layer so that the stride can differ from
                # the default one
                modified_net[name] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            # else: remove_pooling is true, so the pooling layer is dropped

        self.vgg_net = nn.Sequential(modified_net)

        # frozen extractor -> eval mode; trainable extractor -> train mode
        trainable = bool(requires_grad)
        self.vgg_net.train(trainable)
        for param in self.parameters():
            param.requires_grad = trainable

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            dict[str, Tensor]: The features of the layers listed in
                ``layer_name_list``, keyed by layer name.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        output = {}
        for key, layer in self.vgg_net.named_children():
            x = layer(x)
            if key in self.layer_name_list:
                # store a copy; presumably this guards the saved feature
                # against in-place layers that follow - TODO confirm
                output[key] = x.clone()

        return output
basicsr/losses/losses.py CHANGED
@@ -4,7 +4,7 @@ from torch import autograd as autograd
4
  from torch import nn as nn
5
  from torch.nn import functional as F
6
 
7
- # from basicsr.archs.vgg_arch import VGGFeatureExtractor
8
  from basicsr.utils.registry import LOSS_REGISTRY
9
  from .loss_util import weighted_loss
10
 
 
4
  from torch import nn as nn
5
  from torch.nn import functional as F
6
 
7
+ from basicsr.archs.vgg_arch import VGGFeatureExtractor
8
  from basicsr.utils.registry import LOSS_REGISTRY
9
  from .loss_util import weighted_loss
10