code
sssefe commited on
Commit
513ffa3
·
verified ·
1 Parent(s): 3b73cd5

Upload 22 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figs/DSCF_arch.png filter=lfs diff=lfs merge=lfs -text
37
+ utils/test.bmp filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 BinRen
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
NTIRE2025-EfficientSR.log ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,3 +1,70 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!--
2
+ * @Author: Yaozzz666
3
+ * @Date: 2025-03-21 13:49:25
4
+ * @LastEditors: Yaozzz666
5
+ * @LastEditTime: 2025-03-22 11:11:04
6
+ *
7
+ * Copyright (c) 2025 by ${Yaozzz666}, All Rights Reserved.
8
+ -->
9
+ # [NTIRE 2025 Challenge on Efficient Super-Resolution](https://cvlai.net/ntire/2025/) @ [CVPR 2025](https://cvpr.thecvf.com/)
10
+
11
+ ## Distillation Supervised ConvLora Finetuning for SR
12
+
13
+ <div align=center>
14
+ <img src="https://github.com/Yaozzz666/DSCF-SR/blob/main/figs/DSCF_arch.png" width="800px"/>
15
+ </div>
16
+
17
+ - An overview of our DSCF-SR
18
+
19
+ ## The Environments
20
+
21
+ The evaluation environment we adopted is recorded in `requirements.txt`. After you have built your own basic Python setup (Python 3.9 in our setting) via either a *virtual environment* or *anaconda*, please try to keep it similar via:
22
+
23
+ - Step1: install Pytorch first:
24
+ `pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117`
25
+
26
+ - Step2: install other libs via:
27
+ ```pip install -r requirements.txt```
28
+
29
+ or take it as a reference based on your original environments.
30
+
31
+ ## How to test the model?
32
+ 1. Run the [`run.sh`](./run.sh)
33
+ ```bash
34
+ CUDA_VISIBLE_DEVICES=0 python test_demo.py --data_dir [path to your data dir] --save_dir [path to your save dir] --model_id 23
35
+ ```
36
+ - Be sure to change the directories `--data_dir` and `--save_dir`.
37
+
38
+ ## How to calculate the number of parameters, FLOPs, and activations
39
+
40
+ ```python
41
+ from utils.model_summary import get_model_flops, get_model_activation
42
+ from models.team00_EFDN import EFDN
43
+ from fvcore.nn import FlopCountAnalysis
44
+
45
+ model = EFDN()
46
+
47
+ input_dim = (3, 256, 256) # set the input dimension
48
+ activations, num_conv = get_model_activation(model, input_dim)
49
+ activations = activations / 10 ** 6
50
+ print("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
51
+ print("{:>16s} : {:<d}".format("#Conv2d", num_conv))
52
+
53
+ # The FLOPs calculation in previous NTIRE_ESR Challenge
54
+ # flops = get_model_flops(model, input_dim, False)
55
+ # flops = flops / 10 ** 9
56
+ # print("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))
57
+
58
+ # fvcore is used in NTIRE2025_ESR for FLOPs calculation
59
+ input_fake = torch.rand(1, 3, 256, 256).to(device)
60
+ flops = FlopCountAnalysis(model, input_fake).total()
61
+ flops = flops/10**9
62
+ print("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))
63
+
64
+ num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
65
+ num_parameters = num_parameters / 10 ** 6
66
+ print("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))
67
+ ```
68
+
69
+ ## License and Acknowledgement
70
+ This code repository is released under the [MIT License](LICENSE).
figs/DSCF_arch.png ADDED

Git LFS Details

  • SHA256: 2a773f0d18b5de473970820f653e49ba6bbef41a65cf3dc2210c657bda32c6db
  • Pointer size: 131 Bytes
  • Size of remote file: 577 kB
figs/logo.png ADDED
model_zoo/team00_EFDN.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71069f10ef234cd123f45ac8e69099f0a1fdc0c16afcfe2189e50071d47ce477
3
+ size 1153119
model_zoo/team23_DSCF.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:239e76bc4e4e738491c3963805c0a79f524400acf5fbde34b2cb1a855dd62cb1
3
+ size 535137
models/__pycache__/team00_EFDN.cpython-311.pyc ADDED
Binary file (25.8 kB). View file
 
models/__pycache__/team23_DSCF.cpython-311.pyc ADDED
Binary file (19.9 kB). View file
 
models/team00_EFDN.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
class ESA(nn.Module):
    """Enhanced Spatial Attention.

    Computes a sigmoid attention mask from a channel-squeezed,
    spatially-downsampled branch and rescales the input feature map by it.
    """

    def __init__(self, n_feats, conv):
        super(ESA, self).__init__()
        reduced = n_feats // 4  # channel reduction inside the attention branch
        self.conv1 = conv(n_feats, reduced, kernel_size=1)
        self.conv_f = conv(reduced, reduced, kernel_size=1)
        self.conv_max = conv(reduced, reduced, kernel_size=3, padding=1)
        self.conv2 = conv(reduced, reduced, kernel_size=3, stride=2, padding=0)
        self.conv3 = conv(reduced, reduced, kernel_size=3, padding=1)
        self.conv3_ = conv(reduced, reduced, kernel_size=3, padding=1)
        self.conv4 = conv(reduced, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        squeezed = self.conv1(x)                  # channel squeeze
        down = self.conv2(squeezed)               # spatial downsample (stride 2)
        down = F.max_pool2d(down, kernel_size=7, stride=3)
        feats = self.relu(self.conv_max(down))
        feats = self.relu(self.conv3(feats))
        feats = self.conv3_(feats)
        # restore original spatial resolution before fusing with the skip
        feats = F.interpolate(feats, (x.size(2), x.size(3)),
                              mode='bilinear', align_corners=False)
        skip = self.conv_f(squeezed)              # 1x1 skip inside the branch
        mask = self.sigmoid(self.conv4(feats + skip))
        return x * mask
34
+
35
+
36
class conv(nn.Module):
    """Channel-preserving 1x1 convolution followed by a per-channel PReLU."""

    def __init__(self, n_feats):
        super(conv, self).__init__()
        self.conv1x1 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)
        self.act = nn.PReLU(num_parameters=n_feats)

    def forward(self, x):
        out = self.conv1x1(x)
        return self.act(out)
43
+
44
+
45
class Cell(nn.Module):
    """EFDN body cell.

    1x1 conv + two fused (deploy-form) 3x3 convs, a four-branch 1x1
    projection/fusion, ESA attention, and a residual skip.

    ``dynamic``, ``deploy``, ``L`` and ``with_13`` are kept only for
    interface compatibility; this deploy-time variant does not read them.
    """

    def __init__(self, n_feats=48, dynamic=True, deploy=False, L=None, with_13=False):
        super(Cell, self).__init__()
        self.conv1 = conv(n_feats)
        self.conv2 = EDBB_deploy(n_feats, n_feats)
        self.conv3 = EDBB_deploy(n_feats, n_feats)
        self.fuse = nn.Conv2d(n_feats * 2, n_feats, 1, 1, 0)
        self.att = ESA(n_feats, nn.Conv2d)
        # one half-width 1x1 projection per fused feature (x, out1, out2, out3)
        self.branch = nn.ModuleList(
            [nn.Conv2d(n_feats, n_feats // 2, 1, 1, 0) for _ in range(4)])

    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.conv2(feat1)
        feat3 = self.conv3(feat2)

        projected = [proj(t) for proj, t in zip(self.branch, (x, feat1, feat2, feat3))]
        fused = self.fuse(torch.cat(projected, dim=1))
        fused = self.att(fused)
        return fused + x
70
+
71
+
72
class EFDN(nn.Module):
    """Efficient Feature Distillation Network for x``scale`` super-resolution.

    Pipeline: 3x3 head conv -> four Cell blocks with pairwise 1x1 local
    fusion -> global residual -> conv + pixel-shuffle tail. Output values
    are clamped to [0, 1].
    """

    def __init__(self, scale=4, in_channels=3, n_feats=48, out_channels=3):
        super(EFDN, self).__init__()
        self.head = nn.Conv2d(in_channels, n_feats, 3, 1, 1)
        # body cells
        self.cells = nn.ModuleList([Cell(n_feats) for _ in range(4)])
        # pairwise local fusion convs
        self.local_fuse = nn.ModuleList(
            [nn.Conv2d(n_feats * 2, n_feats, 1, 1, 0) for _ in range(3)])
        self.tail = nn.Sequential(
            nn.Conv2d(n_feats, out_channels * (scale ** 2), 3, 1, 1),
            nn.PixelShuffle(scale),
        )

    def forward(self, x):
        shallow = self.head(x)

        out1 = self.cells[0](shallow)
        out2 = self.cells[1](out1)
        fused2 = self.local_fuse[0](torch.cat([out1, out2], dim=1))
        out3 = self.cells[2](fused2)
        fused3 = self.local_fuse[1](torch.cat([out2, out3], dim=1))
        out4 = self.cells[3](fused3)
        # NOTE(review): fuses out2 (not out3) with out4; looks deliberate and
        # matches the released checkpoint, so kept as-is.
        fused4 = self.local_fuse[2](torch.cat([out2, out4], dim=1))

        deep = fused4 + shallow  # global residual
        return self.tail(deep).clamp(0, 1)
106
+
107
+
108
+ # -------------------------------------------------
109
+ #This part code based on DBB(https://github.com/DingXiaoH/DiverseBranchBlock) and ECB(https://github.com/xindongzhang/ECBSR)
110
def multiscale(kernel, target_kernel_size):
    """Zero-pad a conv ``kernel`` (O, I, kH, kW) up to ``target_kernel_size``.

    Used during re-parameterization to lift e.g. a 1x1 kernel to 3x3 so
    kernels of different sizes can be summed elementwise.
    """
    pad_h = (target_kernel_size - kernel.size(2)) // 2
    pad_w = (target_kernel_size - kernel.size(3)) // 2
    # F.pad pads the LAST dimension first: (left, right, top, bottom).
    # The original passed the H pads in the W slots — harmless for the
    # square kernels used in this file, wrong for non-square ones. Fixed.
    return F.pad(kernel, [pad_w, pad_w, pad_h, pad_h])
114
+
115
+
116
class SeqConv3x3(nn.Module):
    """Sequential 1x1 -> 3x3 convolution branch used for re-parameterization.

    Based on DBB (https://github.com/DingXiaoH/DiverseBranchBlock) and
    ECB (https://github.com/xindongzhang/ECBSR).

    ``seq_type`` selects the second stage:

    * ``'conv1x1-conv3x3'``: a learnable dense 3x3 conv; ``depth_multiplier``
      scales the hidden channel count.
    * ``'conv1x1-sobelx'`` / ``'conv1x1-sobely'`` / ``'conv1x1-laplacian'``:
      a fixed depthwise edge stencil, modulated by a learnable per-channel
      scale and bias (``depth_multiplier`` is ignored).

    ``rep_params()`` folds both stages into a single padded-3x3 (RK, RB).
    """

    # Fixed 3x3 depthwise stencils, one per edge-filter type.
    _STENCILS = {
        'conv1x1-sobelx': [[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]],
        'conv1x1-sobely': [[1., 2., 1.], [0., 0., 0.], [-1., -2., -1.]],
        'conv1x1-laplacian': [[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]],
    }

    def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier):
        super(SeqConv3x3, self).__init__()

        self.type = seq_type
        self.inp_planes = inp_planes
        self.out_planes = out_planes

        if self.type == 'conv1x1-conv3x3':
            self.mid_planes = int(out_planes * depth_multiplier)
            conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            conv1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3)
            self.k1 = conv1.weight
            self.b1 = conv1.bias

        elif self.type in self._STENCILS:
            # Originally three near-identical branches (sobelx/sobely/laplacian)
            # that differed only in the stencil values and inconsistently
            # wrapped tensors in legacy torch.FloatTensor(...); unified here.
            conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # learnable per-channel scale & bias applied to the fixed stencil
            self.scale = nn.Parameter(torch.randn(self.out_planes, 1, 1, 1) * 1e-3)
            self.bias = nn.Parameter(torch.randn(self.out_planes) * 1e-3)

            # fixed (non-trainable) depthwise mask, shape (out_planes, 1, 3, 3);
            # kept as a Parameter so state_dict keys match the original code
            stencil = torch.tensor(self._STENCILS[self.type], dtype=torch.float32)
            mask = stencil.expand(self.out_planes, 1, 3, 3).contiguous()
            self.mask = nn.Parameter(data=mask, requires_grad=False)
        else:
            raise ValueError('the type of seqconv is not supported!')

    def _pad_with_bias(self, y):
        """Pad ``y`` by one pixel per side, filling the border with b0 so the
        following *valid* 3x3 conv sees what a fused padded conv would see."""
        y = F.pad(y, (1, 1, 1, 1), 'constant', 0)
        b0 = self.b0.view(1, -1, 1, 1)
        y[:, :, 0:1, :] = b0
        y[:, :, -1:, :] = b0
        y[:, :, :, 0:1] = b0
        y[:, :, :, -1:] = b0
        return y

    def forward(self, x):
        # stage 1: 1x1 conv, then explicit bias-aware padding
        y0 = self._pad_with_bias(F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1))
        if self.type == 'conv1x1-conv3x3':
            # stage 2: dense 3x3 conv
            return F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        # stage 2: depthwise fixed-stencil conv (scaled per channel)
        return F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias,
                        stride=1, groups=self.out_planes)

    def rep_params(self):
        """Fold the two stages into a single kernel/bias (RK, RB) equivalent
        to one 3x3 convolution with padding=1."""
        device = self.k0.get_device()
        if device < 0:
            device = None

        if self.type == 'conv1x1-conv3x3':
            # re-param conv kernel
            RK = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias (b0 spread over a 3x3 window, pushed through k1)
            RB = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            RB = F.conv2d(input=RB, weight=self.k1).view(-1,) + self.b1
        else:
            # expand the depthwise stencil into a dense 3x3 kernel first
            tmp = self.scale * self.mask
            k1 = torch.zeros((self.out_planes, self.out_planes, 3, 3), device=device)
            for i in range(self.out_planes):
                k1[i, i, :, :] = tmp[i, 0, :, :]
            RK = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
            RB = torch.ones(1, self.out_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            RB = F.conv2d(input=RB, weight=k1).view(-1,) + self.bias
        return RK, RB
259
+
260
+
261
class EDBB(nn.Module):
    """Edge-oriented Diverse Branch Block.

    At training time runs several parallel branches (plain 3x3, 1x1,
    1x1->sobel/laplacian sequences, optional 1x1->3x3 and identity) whose
    kernels can all be folded into a single 3x3 convolution:

    * ``switch_to_gv()``  — folds the 1x1 (and optionally 1x1->3x3) branch
      into ``rep_conv``; forward then uses the lighter ``gv`` path.
    * ``switch_to_deploy()`` — folds everything into one conv and deletes
      the auxiliary branches.
    """

    def __init__(self, inp_planes, out_planes, depth_multiplier=None, act_type='prelu',
                 with_idt=False, deploy=False, with_13=False, gv=False):
        super(EDBB, self).__init__()

        self.deploy = deploy
        self.act_type = act_type
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.gv = gv

        # For mobilenet-style blocks wider internal channels work better.
        self.depth_multiplier = 1.0 if depth_multiplier is None else depth_multiplier

        if deploy:
            self.rep_conv = nn.Conv2d(in_channels=inp_planes, out_channels=out_planes,
                                      kernel_size=3, stride=1, padding=1, bias=True)
        else:
            self.with_13 = with_13
            # identity branch only valid when input/output widths match
            self.with_idt = bool(with_idt and (self.inp_planes == self.out_planes))

            self.rep_conv = nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1)
            self.conv1x1 = nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes, self.out_planes, self.depth_multiplier)
            self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes, self.out_planes, -1)
            self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes, self.out_planes, -1)
            self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes, self.out_planes, -1)

        if self.act_type == 'prelu':
            self.act = nn.PReLU(num_parameters=self.out_planes)
        elif self.act_type == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.act_type == 'rrelu':
            self.act = nn.RReLU(lower=-0.05, upper=0.05)
        elif self.act_type == 'softplus':
            self.act = nn.Softplus()
        elif self.act_type == 'linear':
            pass  # no activation module
        else:
            # fixed garbled message ("The type of activation if not support!")
            raise ValueError('The type of activation is not supported!')

    def forward(self, x):
        if self.deploy:
            # everything already folded into one conv
            y = self.rep_conv(x)
        elif self.gv:
            # 1x1 (and optional 1x1->3x3) folded in; identity added explicitly
            y = self.rep_conv(x) + \
                self.conv1x1_sbx(x) + \
                self.conv1x1_sby(x) + \
                self.conv1x1_lpl(x) + x
        else:
            y = self.rep_conv(x) + \
                self.conv1x1(x) + \
                self.conv1x1_sbx(x) + \
                self.conv1x1_sby(x) + \
                self.conv1x1_lpl(x)
            if self.with_idt:
                y += x
            if self.with_13:
                y += self.conv1x1_3x3(x)

        if self.act_type != 'linear':
            y = self.act(y)
        return y

    def switch_to_gv(self):
        """Fold conv1x1 (and, if ``with_13``, conv1x1_3x3) into ``rep_conv``."""
        if self.gv:
            return
        self.gv = True

        K0, B0 = self.rep_conv.weight, self.rep_conv.bias
        K1, B1 = self.conv1x1_3x3.rep_params()
        K5, B5 = multiscale(self.conv1x1.weight, 3), self.conv1x1.bias
        RK, RB = K0 + K5, B0 + B5
        if self.with_13:
            RK, RB = RK + K1, RB + B1

        self.rep_conv.weight.data = RK
        self.rep_conv.bias.data = RB

        for para in self.parameters():
            para.detach_()

    def switch_to_deploy(self):
        """Fold ALL branches into a single padded 3x3 conv and delete them."""
        if self.deploy:
            return
        self.deploy = True

        K0, B0 = self.rep_conv.weight, self.rep_conv.bias
        K1, B1 = self.conv1x1_3x3.rep_params()
        K2, B2 = self.conv1x1_sbx.rep_params()
        K3, B3 = self.conv1x1_sby.rep_params()
        K4, B4 = self.conv1x1_lpl.rep_params()
        K5, B5 = multiscale(self.conv1x1.weight, 3), self.conv1x1.bias
        if self.gv:
            # conv1x1 contribution was already merged by switch_to_gv()
            RK, RB = (K0 + K2 + K3 + K4), (B0 + B2 + B3 + B4)
        else:
            RK, RB = (K0 + K2 + K3 + K4 + K5), (B0 + B2 + B3 + B4 + B5)
        if self.with_13:
            RK, RB = RK + K1, RB + B1
        if self.with_idt:
            device = RK.get_device()
            if device < 0:
                device = None
            # identity as a 3x3 kernel with a centered 1 on the diagonal
            K_idt = torch.zeros(self.out_planes, self.out_planes, 3, 3, device=device)
            for i in range(self.out_planes):
                K_idt[i, i, 1, 1] = 1.0
            RK, RB = RK + K_idt, RB + 0.0

        self.rep_conv = nn.Conv2d(in_channels=self.inp_planes, out_channels=self.out_planes,
                                  kernel_size=3, stride=1, padding=1, bias=True)
        self.rep_conv.weight.data = RK
        self.rep_conv.bias.data = RB

        for para in self.parameters():
            para.detach_()

        self.__delattr__('conv1x1_3x3')
        self.__delattr__('conv1x1')
        self.__delattr__('conv1x1_sbx')
        self.__delattr__('conv1x1_sby')
        self.__delattr__('conv1x1_lpl')
394
+
395
+
396
class EDBB_deploy(nn.Module):
    """Deployment form of EDBB: one fused 3x3 conv followed by PReLU."""

    def __init__(self, inp_planes, out_planes):
        super(EDBB_deploy, self).__init__()
        self.rep_conv = nn.Conv2d(in_channels=inp_planes, out_channels=out_planes,
                                  kernel_size=3, stride=1, padding=1, bias=True)
        self.act = nn.PReLU(num_parameters=out_planes)

    def forward(self, x):
        return self.act(self.rep_conv(x))
410
+ # -------------------------------------------------
models/team23_DSCF.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import torch
3
+ from torch import nn as nn
4
+ import torch.nn.functional as F
5
+ import math
6
+ from typing import Optional, List
7
+ # from IPython import embed
8
+
9
class LoRALayer():
    """Mixin holding shared LoRA hyper-parameters and merge bookkeeping.

    Not an nn.Module itself; concrete layers inherit from an nn.Module
    subclass and this mixin together.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # dropout applies to the LoRA path only; identity when disabled
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.merged = False          # True once the delta is folded into the base weight
        self.merge_weights = merge_weights
27
+
28
class Lora_Conv2d(nn.Conv2d, LoRALayer):
    """``nn.Conv2d`` with a low-rank (LoRA) additive update on the kernel.

    When ``r > 0`` the base kernel is frozen and only ``lora_A``/``lora_B``
    train. ``train(True)`` un-merges the delta from the base weight;
    ``eval()`` (i.e. ``train(False)``) merges it in so inference runs a
    single plain convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert type(kernel_size) is int
        if r > 0:
            # factorized update: (out*k, r*k) @ (r*k, in*k), reshaped to the kernel
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # freeze the pre-trained kernel; only the LoRA factors train
            self.weight.requires_grad = False
        self.reset_parameters()

    def _delta(self):
        """Low-rank weight update, shaped like the conv kernel. Only valid
        when ``self.r > 0``."""
        return (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling

    def reset_parameters(self):
        nn.Conv2d.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # A: kaiming like nn.Linear's default; B: zeros so the initial delta is 0
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):  # True for train, False for eval
        nn.Conv2d.train(self, mode)
        # BUGFIX: the original touched lora_A/lora_B/scaling without checking
        # r > 0, so a layer built with r=0 and merge_weights=True crashed
        # with AttributeError on .eval()/.train(). Guarded here.
        if self.r == 0 or not self.merge_weights:
            return
        if mode and self.merged:
            # back to training: remove the previously folded-in delta
            self.weight.data -= self._delta()
            self.merged = False
        elif not mode and not self.merged:
            # eval: fold the delta into the base weight and mark it merged
            self.weight.data += self._delta()
            self.merged = True

    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            # un-merged training-time path: base kernel + low-rank delta
            return F.conv2d(
                x,
                self.weight + self._delta(),
                self.bias, self.stride, self.padding, self.dilation, self.groups
            )
        return nn.Conv2d.forward(self, x)
101
+ def _make_pair(value):
102
+ if isinstance(value, int):
103
+ value = (value,) * 2
104
+ return value
105
+
106
+
107
def conv_layer(in_channels,
               out_channels,
               kernel_size,
               bias=True):
    """
    Build a Conv2d with 'same'-style padding derived from the kernel size
    (keeps the spatial size for odd kernels).
    """
    if isinstance(kernel_size, int):  # inlined _make_pair
        kernel_size = (kernel_size,) * 2
    padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     padding=padding, bias=bias)
122
+
123
+
124
def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
    """
    Build an activation module.

    Parameters
    ----------
    act_type: str
        one of ['relu', 'lrelu', 'prelu'] (case-insensitive).
    inplace: bool
        whether to use the inplace operator ('relu'/'lrelu' only).
    neg_slope: float
        negative-region slope for 'lrelu'; init value for 'prelu'.
    n_prelu: int
        `num_parameters` for 'prelu'.
    ----------
    """
    kind = act_type.lower()
    if kind == 'relu':
        return nn.ReLU(inplace)
    if kind == 'lrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if kind == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    raise NotImplementedError(
        'activation layer [{:s}] is not found'.format(act_type))
150
+
151
+
152
def sequential(*args):
    """
    Flatten the given modules (expanding nn.Sequential children) into one
    nn.Sequential, in argument order.

    A single non-Sequential module argument is returned unchanged.
    OrderedDict input is rejected.

    Parameters
    ----------
    args: Definition of Modules in order.
    -------
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError(
                'sequential does not support OrderedDict input.')
        return args[0]
    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)
175
+
176
+
177
def pixelshuffle_block(in_channels,
                       out_channels,
                       upscale_factor=2,
                       kernel_size=3):
    """
    Conv + PixelShuffle upsampling: expand channels by upscale_factor**2,
    then rearrange them into space.
    """
    expand = conv_layer(in_channels,
                        out_channels * (upscale_factor ** 2),
                        kernel_size)
    return sequential(expand, nn.PixelShuffle(upscale_factor))
189
+
190
class Conv3XC(nn.Module):
    """Deploy-time collapsed 3x3 convolution (training branches pre-fused).

    ``gain1``/``gain2`` and the re-param bookkeeping fields are kept for
    checkpoint/interface compatibility; only ``eval_conv`` is used here.
    """

    def __init__(self, c_in, c_out, gain1=1, gain2=0, s=1, bias=True, relu=False):
        super(Conv3XC, self).__init__()
        self.weight_concat = None
        self.bias_concat = None
        self.update_params_flag = False
        self.stride = s
        self.has_relu = relu
        self.eval_conv = nn.Conv2d(in_channels=c_in, out_channels=c_out,
                                   kernel_size=3, padding=1, stride=s, bias=bias)

    def forward(self, x):
        y = self.eval_conv(x)
        if self.has_relu:
            y = F.leaky_relu(y, negative_slope=0.05)
        return y
207
+
208
+
209
class SPAB(nn.Module):
    """Swift Parameter-free Attention Block.

    Three Conv3XC stages with SiLU between them, then a parameter-free
    similarity attention ``sigmoid(out3) - 0.5`` applied to the residual
    sum. Returns ``(attended, out1, out2, out3)`` so callers can tap
    intermediate features.

    NOTE: ``act1`` is an inplace SiLU, so the returned ``out1``/``out2``
    tensors are the *activated* values (same aliasing as the original).
    """

    def __init__(self,
                 in_channels,
                 mid_channels=None,
                 out_channels=None,
                 bias=False):
        super(SPAB, self).__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels
        out_channels = in_channels if out_channels is None else out_channels

        self.in_channels = in_channels
        self.c1_r = Conv3XC(in_channels, mid_channels, gain1=2, s=1)
        self.c2_r = Conv3XC(mid_channels, mid_channels, gain1=2, s=1)
        self.c3_r = Conv3XC(mid_channels, out_channels, gain1=2, s=1)
        self.act1 = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        out1 = self.c1_r(x)
        out2 = self.c2_r(self.act1(out1))
        out3 = self.c3_r(self.act1(out2))

        # parameter-free attention, centered at zero
        sim_att = torch.sigmoid(out3) - 0.5
        out = (out3 + x) * sim_att
        return out, out1, out2, out3
242
+
243
+
244
+ class DSCF(nn.Module):
245
+ """
246
+ Swift Parameter-free Attention Network for Efficient Super-Resolution
247
+ """
248
+
249
+ def __init__(self,
250
+ num_in_ch,
251
+ num_out_ch,
252
+ feature_channels=26,
253
+ upscale=4,
254
+ bias=True,
255
+ img_range=255.,
256
+ rgb_mean=(0.4488, 0.4371, 0.4040)
257
+ ):
258
+ super(DSCF, self).__init__()
259
+
260
+ in_channels = num_in_ch
261
+ out_channels = num_out_ch
262
+ self.img_range = img_range
263
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
264
+
265
+ self.conv_1 = Conv3XC(in_channels, feature_channels, gain1=2, s=1)
266
+ self.block_1 = SPAB(feature_channels, bias=bias)
267
+ self.block_2 = SPAB(feature_channels, bias=bias)
268
+ self.block_3 = SPAB(feature_channels, bias=bias)
269
+ self.block_4 = SPAB(feature_channels, bias=bias)
270
+ self.block_5 = SPAB(feature_channels, bias=bias)
271
+ self.block_6 = SPAB(feature_channels, bias=bias)
272
+
273
+ self.conv_cat = conv_layer(feature_channels * 4, feature_channels, kernel_size=1, bias=True)
274
+ self.conv_2 = Conv3XC(feature_channels, feature_channels, gain1=2, s=1)
275
+
276
+ self.upsampler = pixelshuffle_block(feature_channels, out_channels, upscale_factor=upscale)
277
+
278
+ # 指定需要替换 LoRA 层的子模块名称
279
+ # desired_submodules = ["conv_1.eval_conv",
280
+ # "block_1.c1_r.eval_conv","block_1.c2_r.eval_conv","block_1.c3_r.eval_conv",
281
+ # "block_2.c1_r.eval_conv","block_2.c2_r.eval_conv","block_2.c3_r.eval_conv",
282
+ # "block_3.c1_r.eval_conv","block_3.c2_r.eval_conv","block_3.c3_r.eval_conv",
283
+ # "block_4.c1_r.eval_conv","block_4.c2_r.eval_conv","block_4.c3_r.eval_conv",
284
+ # "block_5.c1_r.eval_conv","block_5.c2_r.eval_conv","block_5.c3_r.eval_conv",
285
+ # "block_6.c1_r.eval_conv","block_6.c2_r.eval_conv","block_6.c3_r.eval_conv",
286
+ # "conv_2.eval_conv",
287
+ # "conv_cat",
288
+ # "upsampler.0"]
289
+
290
+ # desired_submodules = ["conv_2.eval_conv","upsampler.0"]
291
+ # # 替换需要 LoRA 处理的层
292
+ # self.replace_layers(desired_submodules)
293
+
294
+ # self.mark_only_lora_as_trainable(bias='none')
295
+ # 分层LoRA配置字典(模块名: (r, lora_alpha))
296
+ # self.lora_config = {
297
+ # # 高频重建核心层 (最高优先级)
298
+ # "conv_2.eval_conv": (8, 16), # 最大秩
299
+ # "upsampler.0": (8, 16), # 高秩
300
+
301
+ # # 中间处理层 (梯度传播关键路径)
302
+ # **{f"block_{i}.c{j}_r.eval_conv": (2, 4)
303
+ # for i in [2,3,4,5] # block_2到block_5
304
+ # for j in [1,2,3]}, # 每个block的三个卷积
305
+
306
+ # # 首尾层 (适度调整)
307
+ # "block_1.c1_r.eval_conv": (2, 4),
308
+ # "block_1.c2_r.eval_conv": (2, 4),
309
+ # "block_1.c3_r.eval_conv": (2, 4),
310
+ # "block_6.c1_r.eval_conv": (2, 4),
311
+ # "block_6.c2_r.eval_conv": (2, 4),
312
+ # "block_6.c3_r.eval_conv": (2, 4),
313
+ # }
314
+
315
+ # # 替换需要 LoRA 处理的层
316
+ # self.replace_layers_with_strategy()
317
+
318
+ # 冻结非LoRA参数
319
+ # self.mark_only_lora_as_trainable(bias='none')
320
+ # self.cuda()(torch.randn(1, 3, 256, 256).cuda())
321
+ # self.eval().cuda()
322
+ self.eval().cuda()
323
+ input_tensor = torch.randn(1, 3, 256, 256).cuda()
324
+ output = self(input_tensor)
325
+ # 确保 LoRA 层参数可训练
326
+ # print("可训练参数:")
327
+ # for name, param in self.named_parameters():
328
+ # if param.requires_grad:
329
+ # print(f"{name}: {param.shape}")
330
+
331
+
332
+ # def replace_layers_with_strategy(self):
333
+ # """根据分层策略替换卷积层"""
334
+ # for full_name, (r, alpha) in self.lora_config.items():
335
+ # parent, child_name = self._get_parent_and_child(full_name)
336
+ # if parent is None:
337
+ # # print(f"⚠️ Skip {full_name}: module not found")
338
+ # continue
339
+
340
+ # original_conv = getattr(parent, child_name, None)
341
+ # if not isinstance(original_conv, nn.Conv2d):
342
+ # # print(f"⚠️ {full_name} is not Conv2d (found {type(original_conv)})")
343
+ # continue
344
+
345
+ # # 动态设置参数
346
+ # new_layer = Lora_Conv2d(
347
+ # in_channels=original_conv.in_channels,
348
+ # out_channels=original_conv.out_channels,
349
+ # kernel_size=original_conv.kernel_size[0],
350
+ # stride=original_conv.stride,
351
+ # padding=original_conv.padding,
352
+ # bias=original_conv.bias is not None,
353
+ # r=r, # 动态设置秩
354
+ # lora_alpha=alpha # 动态设置缩放系数
355
+ # )
356
+
357
+ # # 继承原始权重
358
+ # with torch.no_grad():
359
+ # new_layer.weight.copy_(original_conv.weight)
360
+ # if original_conv.bias is not None:
361
+ # new_layer.bias.copy_(original_conv.bias)
362
+
363
+ # setattr(parent, child_name, new_layer)
364
+ # # print(f"✅ {full_name} => r={r}, alpha={alpha}")
365
+
366
+ # def _get_parent_and_child(self, module_name):
367
+ # """
368
+ # 获取模块的父级模块和子模块名称
369
+ # 例如:
370
+ # module_name = "block_5.c1_r.eval_conv"
371
+ # 则返回 (model.block_5.c1_r, "eval_conv")
372
+ # """
373
+ # parts = module_name.split(".")
374
+ # parent = self
375
+ # for part in parts[:-1]: # 遍历到倒数第二个
376
+ # if hasattr(parent, part):
377
+ # parent = getattr(parent, part)
378
+ # else:
379
+ # return None, None # 没找到路径
380
+ # return parent, parts[-1] # 返回父模块和子模块名称
381
+
382
+ # def replace_layers(self, desired_submodules):
383
+ # """
384
+ # 遍历模型的子模块,将符合条件的层替换为 Lora_Conv2d
385
+ # """
386
+ # # 替换conv_layer
387
+ # for name, module in self._modules.items():
388
+ # if name in desired_submodules:
389
+ # print('--------------------self._modules.items--------------------------')
390
+ # print(name)
391
+ # if isinstance(module, nn.Conv2d):
392
+ # print(f"Replacing {name} with Lora_Conv2d")
393
+ # setattr(self, name, Lora_Conv2d(
394
+ # module.in_channels,
395
+ # module.out_channels,
396
+ # kernel_size=module.kernel_size[0],
397
+ # stride=module.stride,
398
+ # padding=module.padding,
399
+ # bias=True,
400
+ # r=2,
401
+ # lora_alpha=2
402
+ # ))
403
+
404
+ # def mark_only_lora_as_trainable(self, bias: str = 'none'):
405
+ # """
406
+ # 只训练 LoRA 相关参数,而冻结所有其他参数。
407
+
408
+ # 参数:
409
+ # - bias: 'none' (不训练 bias), 'all' (训练所有 bias), 'lora_only' (只训练 LoRA 层的 bias)
410
+ # """
411
+ # # 冻结所有非 LoRA 参数
412
+ # # for n, p in self.named_parameters():
413
+ # # if 'lora_' not in n:
414
+ # # p.requires_grad = False
415
+ # for n, p in self.named_parameters():
416
+ # if 'lora_' not in n:
417
+ # p.requires_grad = False # 冻结非 LoRA 参数
418
+ # else:
419
+ # p.requires_grad = True # 解冻 LoRA 参数
420
+
421
+ # if bias == 'none':
422
+ # return
423
+ # elif bias == 'all':
424
+ # for n, p in self.named_parameters():
425
+ # if 'bias' in n:
426
+ # p.requires_grad = True
427
+ # elif bias == 'lora_only':
428
+ # for m in self.modules():
429
+ # if isinstance(m, LoRALayer) and hasattr(m, 'bias') and m.bias is not None:
430
+ # m.bias.requires_grad = True
431
+ # else:
432
+ # raise NotImplementedError(f"未知 bias 选项: {bias}")
433
+ def forward(self, x, return_features=False):
434
+ # features = []
435
+ self.mean = self.mean.type_as(x)
436
+ x = (x - self.mean) * self.img_range
437
+
438
+ out_feature = self.conv_1(x)
439
+
440
+ out_b1, out_b1_1, out_b1_2, out_b1_3 = self.block_1(out_feature)
441
+ out_b2, out_b2_1, out_b2_2, out_b2_3 = self.block_2(out_b1)
442
+ out_b3, out_b3_1, out_b3_2, out_b3_3 = self.block_3(out_b2)
443
+
444
+ out_b4, _, _, _ = self.block_4(out_b3)
445
+ out_b5, _, _, _ = self.block_5(out_b4)
446
+ out_b6, out_b5_2, _, _ = self.block_6(out_b5)
447
+
448
+ out_b6 = self.conv_2(out_b6)
449
+ out = self.conv_cat(torch.cat([out_feature, out_b6, out_b1, out_b5_2], 1))
450
+ output = self.upsampler(out)
451
+
452
+ # features.append(out_b1_1)
453
+ # features.append(out_b1_2)
454
+ # features.append(out_b1_3)
455
+ # features.append(out_b2_1)
456
+ # features.append(out_b2_2)
457
+ # features.append(out_b2_3)
458
+ # features.append(out_b3_1)
459
+ # features.append(out_b3_2)
460
+ # features.append(out_b3_3)
461
+
462
+
463
+ if return_features:
464
+ return output, features # Return output and intermediate features
465
+ return output
466
+
requirements.txt ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.0.0
2
+ anyio==4.0.0
3
+ appdirs==1.4.4
4
+ beartype==0.16.4
5
+ blessed==1.20.0
6
+ brotlipy==0.7.0
7
+ cachetools==5.3.1
8
+ certifi==2023.7.22
9
+ cffi==1.15.1
10
+ charset-normalizer==2.0.4
11
+ click==8.1.7
12
+ clip==0.2.0
13
+ cmake==3.27.7
14
+ contourpy==1.1.1
15
+ cryptography==41.0.3
16
+ cycler==0.12.1
17
+ docker-pycreds==0.4.0
18
+ dpcpp-cpp-rt==2024.0.2
19
+ einops==0.7.0
20
+ ema-pytorch==0.3.1
21
+ exceptiongroup==1.1.3
22
+ filelock==3.12.4
23
+ fonttools==4.43.1
24
+ fsspec==2023.9.2
25
+ ftfy==6.1.1
26
+ fvcore==0.1.5.post20221221
27
+ gitdb==4.0.10
28
+ GitPython==3.1.40
29
+ google-auth==2.23.3
30
+ google-auth-oauthlib==1.0.0
31
+ gpustat==1.1.1
32
+ grpcio==1.59.0
33
+ h11==0.14.0
34
+ h5py==3.10.0
35
+ huggingface-hub==0.18.0
36
+ idna==3.4
37
+ imageio==2.31.5
38
+ importlib-metadata==6.8.0
39
+ importlib-resources==6.1.0
40
+ intel-cmplr-lib-rt==2024.0.2
41
+ intel-cmplr-lic-rt==2024.0.2
42
+ intel-opencl-rt==2024.0.2
43
+ intel-openmp==2024.0.2
44
+ iopath==0.1.10
45
+ itsdangerous==2.1.2
46
+ kiwisolver==1.4.5
47
+ lit==17.0.3
48
+ lpips==0.1.4
49
+ Markdown==3.5
50
+ markdown-it-py==2.2.0
51
+ MarkupSafe==2.1.3
52
+ matplotlib==3.7.3
53
+ mdurl==0.1.2
54
+ mkl==2024.0.0
55
+ mkl-fft==1.3.6
56
+ mkl-random==1.2.2
57
+ mkl-service==2.4.0
58
+ mpmath==1.3.0
59
+ multidict==6.0.4
60
+ networkx==3.1
61
+ numpy==1.24.3
62
+ nvidia-ml-py==12.535.108
63
+ oauthlib==3.2.2
64
+ opencv-python==4.8.1.78
65
+ ordered-set==4.1.0
66
+ orjson==3.8.9
67
+ packaging==23.1
68
+ pandas==1.5.3
69
+ pathtools==0.1.2
70
+ Pillow==9.4.0
71
+ portalocker==2.8.2
72
+ protobuf==4.24.4
73
+ psutil==5.9.6
74
+ py-cpuinfo==9.0.0
75
+ pyasn1==0.5.0
76
+ pyasn1-modules==0.3.0
77
+ pycparser==2.21
78
+ pydantic==1.10.7
79
+ Pygments==2.16.1
80
+ PyJWT==2.6.0
81
+ pyOpenSSL==23.2.0
82
+ pyparsing==3.0.9
83
+ PySocks==1.7.1
84
+ python-dateutil==2.8.2
85
+ pytorch-fid==0.3.0
86
+ pytz==2023.3.post1
87
+ PyWavelets==1.4.1
88
+ PyYAML==6.0
89
+ readchar==4.0.5
90
+ regex==2023.10.3
91
+ requests==2.28.2
92
+ requests-oauthlib==1.3.1
93
+ rfc3986==1.5.0
94
+ rich==13.3.3
95
+ rsa==4.9
96
+ scikit-image==0.19.3
97
+ scikit-video==1.1.11
98
+ scipy==1.10.1
99
+ seaborn==0.12.2
100
+ sentry-sdk==1.14.0
101
+ setproctitle==1.3.2
102
+ six==1.16.0
103
+ smmap==5.0.0
104
+ sniffio==1.3.0
105
+ soupsieve==2.4
106
+ starlette==0.22.0
107
+ starsessions==1.3.0
108
+ sympy==1.11.1
109
+ tabulate==0.9.0
110
+ tbb==2021.11.0
111
+ tensorboard==2.14.0
112
+ tensorboard-data-server==0.7.1
113
+ termcolor==2.4.0
114
+ tifffile==2023.1.23.1
115
+ timm==0.6.12
116
+ torchmetrics==0.11.4
117
+ torchsummary==1.5.1
118
+ tqdm==4.66.1
119
+ triton==2.0.0
120
+ tsnecuda==3.0.1
121
+ typing_extensions==4.5.0
122
+ ujson==5.7.0
123
+ urllib3==1.26.14
124
+ uvicorn==0.21.1
125
+ uvloop==0.17.0
126
+ wandb==0.13.9
127
+ warmup-scheduler==0.3
128
+ watchfiles==0.19.0
129
+ wcwidth==0.2.8
130
+ websocket-client==1.5.1
131
+ websockets==11.0.1
132
+ Werkzeug==3.0.0
133
+ yacs==0.1.8
134
+ yapf==0.32.0
135
+ yarl==1.8.2
136
+ zipp==3.17.0
results.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ Model Val PSNR Val Time [ms] Params [M] FLOPs [G] Acts [M] Mem [M] Conv
2
+ 00_EFDN_baseline 26.93 34.31 0.276 16.70 111.12 662.89 65
3
+ 23_DSCF 26.92 8.37 0.131 8.54 38.93 728.41 22
run.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- Evaluation on LSDIR_DIV2K_valid datasets for One Method: ---
2
+ CUDA_VISIBLE_DEVICES=0 python test_demo.py \
3
+ --data_dir ../ \
4
+ --save_dir ../results \
5
+ --model_id 23
6
+
7
+
8
+ # --- When only LSDIR_DIV2K_test datasets are included (For Organizer) ---
9
+ # CUDA_VISIBLE_DEVICES=0 python test_demo.py \
10
+ # --data_dir ../ \
11
+ # --save_dir ../results \
12
+ # --include_test \
13
+ # --model_id 0
14
+
15
+ # --- Test all the methods (For Organizer) ---
16
+ #!/bin/bash
17
+ # DATA_DIR="/Your/Validate/Datasets/Path"
18
+ # SAVE_DIR="./results"
19
+ # MODEL_IDS=(
20
+ # 0 1 3 4 5 7 10 11 13 15
21
+ # 16 17 18 19 21 23 25 26
22
+ # 28 29 30 31 33 34 38 39
23
+ # 41 42 43 44 45 46 48
24
+ # )
25
+
26
+ # for model_id in "${MODEL_IDS[@]}"
27
+ # do
28
+ # CUDA_VISIBLE_DEVICES=0 python test_demo.py --data_dir "$DATA_DIR" --save_dir "$SAVE_DIR" --include_test --model_id "$model_id"
29
+ # done
test_demo.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import logging
3
+ import torch
4
+ import argparse
5
+ import json
6
+ import glob
7
+
8
+ from pprint import pprint
9
+ from fvcore.nn import FlopCountAnalysis
10
+ from utils.model_summary import get_model_activation, get_model_flops
11
+ from utils import utils_logger
12
+ from utils import utils_image as util
13
+
14
+
15
def select_model(args, device):
    """Instantiate the model selected by ``args.model_id`` and load its weights.

    Model ID is assigned according to the order of the submissions.
    Different networks are trained with input range of either [0,1] or [0,255];
    the returned ``data_range`` encodes that choice (determined manually).

    Returns:
        (model, name, data_range, tile) — the model in eval mode on ``device``
        with gradients frozen, its display name, its input range, and the tile
        size (``None`` = whole-image inference).
    """
    model_id = args.model_id

    if model_id == 0:
        # Baseline: the 1st place of the `Overall Performance` of the
        # NTIRE 2023 Efficient SR Challenge.
        # Edge-enhanced Feature Distillation Network for Efficient Super-Resolution
        # arXiv: https://arxiv.org/pdf/2204.08759
        # Original code: https://github.com/icandle/EFDN  (ckpt: EFDN_gv.pth)
        from models.team00_EFDN import EFDN
        name, data_range = f"{model_id:02}_EFDN_baseline", 1.0
        model_path = os.path.join('model_zoo', 'team00_EFDN.pth')
        model = EFDN()
        model.load_state_dict(torch.load(model_path), strict=True)
    elif model_id == 23:
        from models.team23_DSCF import DSCF

        name, data_range = f"{model_id:02}_DSCF", 1.0
        model_path = os.path.join('model_zoo', 'team23_DSCF.pth')
        model = DSCF(3, 3, feature_channels=26, upscale=4)
        # NOTE(review): strict=False silently tolerates missing/unexpected
        # keys in the checkpoint — confirm this is intentional.
        model.load_state_dict(torch.load(model_path), strict=False)
    else:
        raise NotImplementedError(f"Model {model_id} is not implemented.")

    # print(model)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    tile = None  # no tiling: the whole image is processed in one pass
    return model.to(device), name, data_range, tile
49
+
50
+
51
def select_dataset(data_dir, mode):
    """Collect (LR, HR) image-path pairs for the requested dataset split.

    Args:
        data_dir: root directory containing ``DIV2K_LSDIR_{mode}_HR``.
        mode: either "test" or "valid".

    Returns:
        A list of ``(lr_path, hr_path)`` tuples sorted by HR filename;
        LR paths are derived from HR paths ("_HR" -> "_LR", ".png" -> "x4.png").

    Raises:
        NotImplementedError: for any other ``mode``.
    """
    if mode == "test":
        # inference on the DIV2K_LSDIR_test set
        pattern = os.path.join(data_dir, "DIV2K_LSDIR_test_HR/*.png")
    elif mode == "valid":
        # inference on the DIV2K_LSDIR_valid set
        pattern = os.path.join(data_dir, "DIV2K_LSDIR_valid_HR/*.png")
    else:
        raise NotImplementedError(f"{mode} is not implemented in select_dataset")

    path = []
    for hr_path in sorted(glob.glob(pattern)):
        lr_path = hr_path.replace("_HR", "_LR").replace(".png", "x4.png")
        path.append((lr_path, hr_path))
    return path
73
+
74
+
75
def forward(img_lq, model, tile=None, tile_overlap=32, scale=4):
    """Run SR inference, either on the whole image or tile by tile.

    Args:
        img_lq: low-quality input tensor of shape (B, C, H, W).
        model: callable mapping an LR tensor to its SR counterpart.
        tile: tile side length; ``None`` processes the image in one pass.
        tile_overlap: overlap (in LR pixels) between neighbouring tiles.
        scale: super-resolution factor.

    Returns:
        The SR output tensor; overlapping tile regions are averaged.
    """
    if tile is None:
        # test the image as a whole
        return model(img_lq)

    # test the image tile by tile
    b, c, h, w = img_lq.size()
    tile = min(tile, h, w)
    sf = scale
    stride = tile - tile_overlap

    # last window is clamped so the image edge is always covered
    h_starts = list(range(0, h - tile, stride)) + [h - tile]
    w_starts = list(range(0, w - tile, stride)) + [w - tile]

    accum = torch.zeros(b, c, h * sf, w * sf).type_as(img_lq)
    weight = torch.zeros_like(accum)

    for hs in h_starts:
        for ws in w_starts:
            patch = img_lq[..., hs:hs + tile, ws:ws + tile]
            sr_patch = model(patch)
            accum[..., hs * sf:(hs + tile) * sf, ws * sf:(ws + tile) * sf].add_(sr_patch)
            weight[..., hs * sf:(hs + tile) * sf, ws * sf:(ws + tile) * sf].add_(torch.ones_like(sr_patch))

    # normalise by per-pixel coverage count (averages overlaps)
    return accum.div_(weight)
103
+
104
def run(model, model_name, data_range, tile, logger, device, args, mode="test"):
    """Run x4 SR inference over one split ("valid" or "test"), log per-image
    PSNR (and optionally SSIM) plus GPU runtime, and return a results dict.
    Requires CUDA (uses CUDA events and max_memory_allocated).
    """
    sf = 4
    border = sf  # metrics are computed with an `sf`-pixel border crop
    results = dict()
    results[f"{mode}_runtime"] = []
    results[f"{mode}_psnr"] = []
    if args.ssim:
        results[f"{mode}_ssim"] = []
    # results[f"{mode}_psnr_y"] = []
    # results[f"{mode}_ssim_y"] = []

    # --------------------------------
    # dataset path
    # --------------------------------
    data_path = select_dataset(args.data_dir, mode)
    save_path = os.path.join(args.save_dir, model_name, mode)
    util.mkdir(save_path)

    # CUDA events measure GPU-side latency in milliseconds
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for i, (img_lr, img_hr) in enumerate(data_path):

        # --------------------------------
        # (1) img_lr: load LR image and convert to a device tensor
        # --------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img_hr))
        img_lr = util.imread_uint(img_lr, n_channels=3)
        img_lr = util.uint2tensor4(img_lr, data_range)
        img_lr = img_lr.to(device)

        # --------------------------------
        # (2) img_sr: timed forward pass (whole-image or tiled)
        # --------------------------------
        start.record()
        img_sr = forward(img_lr, model, tile)
        end.record()
        torch.cuda.synchronize()
        results[f"{mode}_runtime"].append(start.elapsed_time(end))  # milliseconds
        img_sr = util.tensor2uint(img_sr, data_range)

        # --------------------------------
        # (3) img_hr: load GT and crop to a multiple of the scale factor
        # --------------------------------
        img_hr = util.imread_uint(img_hr, n_channels=3)
        img_hr = img_hr.squeeze()
        img_hr = util.modcrop(img_hr, sf)

        # --------------------------------
        # PSNR and SSIM
        # --------------------------------

        # print(img_sr.shape, img_hr.shape)
        psnr = util.calculate_psnr(img_sr, img_hr, border=border)
        results[f"{mode}_psnr"].append(psnr)

        if args.ssim:
            ssim = util.calculate_ssim(img_sr, img_hr, border=border)
            results[f"{mode}_ssim"].append(ssim)
            logger.info("{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.".format(img_name + ext, psnr, ssim))
        else:
            logger.info("{:s} - PSNR: {:.2f} dB".format(img_name + ext, psnr))

        # if np.ndim(img_hr) == 3:  # RGB image
        #     img_sr_y = util.rgb2ycbcr(img_sr, only_y=True)
        #     img_hr_y = util.rgb2ycbcr(img_hr, only_y=True)
        #     psnr_y = util.calculate_psnr(img_sr_y, img_hr_y, border=border)
        #     ssim_y = util.calculate_ssim(img_sr_y, img_hr_y, border=border)
        #     results[f"{mode}_psnr_y"].append(psnr_y)
        #     results[f"{mode}_ssim_y"].append(ssim_y)
        # print(os.path.join(save_path, img_name+ext))

        # --- Save Restored Images ---
        # util.imsave(img_sr, os.path.join(save_path, img_name+ext))

    # Aggregate statistics over the whole split
    results[f"{mode}_memory"] = torch.cuda.max_memory_allocated(torch.cuda.current_device()) / 1024 ** 2
    results[f"{mode}_ave_runtime"] = sum(results[f"{mode}_runtime"]) / len(results[f"{mode}_runtime"]) #/ 1000.0
    results[f"{mode}_ave_psnr"] = sum(results[f"{mode}_psnr"]) / len(results[f"{mode}_psnr"])
    if args.ssim:
        results[f"{mode}_ave_ssim"] = sum(results[f"{mode}_ssim"]) / len(results[f"{mode}_ssim"])
    # results[f"{mode}_ave_psnr_y"] = sum(results[f"{mode}_psnr_y"]) / len(results[f"{mode}_psnr_y"])
    # results[f"{mode}_ave_ssim_y"] = sum(results[f"{mode}_ssim_y"]) / len(results[f"{mode}_ssim_y"])
    logger.info("{:>16s} : {:<.3f} [M]".format("Max Memory", results[f"{mode}_memory"]))  # Memory
    logger.info("------> Average runtime of ({}) is : {:.6f} milliseconds".format("test" if mode == "test" else "valid", results[f"{mode}_ave_runtime"]))
    logger.info("------> Average PSNR of ({}) is : {:.6f} dB".format("test" if mode == "test" else "valid", results[f"{mode}_ave_psnr"]))

    return results
192
+
193
+
194
def main(args):
    """Entry point: evaluate one model on the valid (and optionally test)
    split, measure complexity (FLOPs/params/activations), and write both
    results.json and a tab-separated results.txt summary.
    """
    utils_logger.logger_info("NTIRE2025-EfficientSR", log_path="NTIRE2025-EfficientSR.log")
    logger = logging.getLogger("NTIRE2025-EfficientSR")

    # --------------------------------
    # basic settings
    # --------------------------------
    torch.cuda.current_device()
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # results.json accumulates results across runs/models
    json_dir = os.path.join(os.getcwd(), "results.json")
    if not os.path.exists(json_dir):
        results = dict()
    else:
        with open(json_dir, "r") as f:
            results = json.load(f)

    # --------------------------------
    # load model
    # --------------------------------
    model, model_name, data_range, tile = select_model(args, device)
    logger.info(model_name)

    # if model not in results:
    if True:  # always re-run (cached-result check disabled)
        # --------------------------------
        # restore image
        # --------------------------------

        # inference on the DIV2K_LSDIR_valid set
        valid_results = run(model, model_name, data_range, tile, logger, device, args, mode="valid")
        # record PSNR, runtime
        results[model_name] = valid_results

        # inference conducted by the Organizer on DIV2K_LSDIR_test set
        if args.include_test:
            test_results = run(model, model_name, data_range, tile, logger, device, args, mode="test")
            results[model_name].update(test_results)

        input_dim = (3, 256, 256)  # set the input dimension
        activations, num_conv = get_model_activation(model, input_dim)
        activations = activations/10**6  # report in millions
        logger.info("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
        logger.info("{:>16s} : {:<d}".format("#Conv2d", num_conv))

        # The FLOPs calculation in previous NTIRE_ESR Challenge
        # flops = get_model_flops(model, input_dim, False)
        # flops = flops/10**9
        # logger.info("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))

        # fvcore is used in NTIRE2025_ESR for FLOPs calculation
        input_fake = torch.rand(1, 3, 256, 256).to(device)
        flops = FlopCountAnalysis(model, input_fake).total()
        flops = flops/10**9  # report in GFLOPs
        logger.info("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))

        num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
        num_parameters = num_parameters/10**6  # report in millions
        logger.info("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))
        results[model_name].update({"activations": activations, "num_conv": num_conv, "flops": flops, "num_parameters": num_parameters})

        with open(json_dir, "w") as f:
            json.dump(results, f)

    # Build the tab-separated summary table over ALL recorded models
    if args.include_test:
        fmt = "{:20s}\t{:10s}\t{:10s}\t{:14s}\t{:14s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Test PSNR", "Val Time [ms]", "Test Time [ms]", "Ave Time [ms]",
                       "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    else:
        fmt = "{:20s}\t{:10s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Val Time [ms]", "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    for k, v in results.items():
        val_psnr = f"{v['valid_ave_psnr']:2.2f}"
        val_time = f"{v['valid_ave_runtime']:3.2f}"
        mem = f"{v['valid_memory']:2.2f}"

        num_param = f"{v['num_parameters']:2.3f}"
        flops = f"{v['flops']:2.2f}"
        acts = f"{v['activations']:2.2f}"
        conv = f"{v['num_conv']:4d}"
        if args.include_test:
            # from IPython import embed; embed()
            test_psnr = f"{v['test_ave_psnr']:2.2f}"
            test_time = f"{v['test_ave_runtime']:3.2f}"
            ave_time = f"{(v['valid_ave_runtime'] + v['test_ave_runtime']) / 2:3.2f}"
            s += fmt.format(k, val_psnr, test_psnr, val_time, test_time, ave_time, num_param, flops, acts, mem, conv)
        else:
            s += fmt.format(k, val_psnr, val_time, num_param, flops, acts, mem, conv)
    with open(os.path.join(os.getcwd(), 'results.txt'), "w") as f:
        f.write(s)
286
+
287
+
288
if __name__ == "__main__":
    # CLI for the NTIRE2025 Efficient SR evaluation harness.
    parser = argparse.ArgumentParser("NTIRE2025-EfficientSR")
    parser.add_argument("--data_dir", default="../", type=str)
    parser.add_argument("--save_dir", default="../results", type=str)
    parser.add_argument("--model_id", default=0, type=int)
    parser.add_argument("--include_test", action="store_true", help="Inference on the `DIV2K_LSDIR_test` set")
    parser.add_argument("--ssim", action="store_true", help="Calculate SSIM")

    args = parser.parse_args()
    pprint(args)

    main(args)
utils/__pycache__/model_summary.cpython-311.pyc ADDED
Binary file (22.3 kB). View file
 
utils/__pycache__/utils_image.cpython-311.pyc ADDED
Binary file (38.5 kB). View file
 
utils/__pycache__/utils_logger.cpython-311.pyc ADDED
Binary file (2.87 kB). View file
 
utils/model_summary.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+
5
+ '''
6
+ ---- 1) FLOPs: floating point operations
7
+ ---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
8
+ ---- 3) #Conv2d: the number of ‘Conv2d’ layers
9
+ '''
10
+
11
def get_model_flops(model, input_res, print_per_layer_stat=True,
                    input_constructor=None):
    """Return the total FLOPs of ``model`` for one input of size ``input_res``.

    The model is instrumented with flops-counting hooks, run once on a dummy
    batch (or on input built by ``input_constructor``), and then cleaned up.
    """
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()

    if input_constructor:
        kwargs = input_constructor(input_res)
        _ = flops_model(**kwargs)
    else:
        # place the dummy batch on the same device as the model parameters
        device = list(flops_model.parameters())[-1].device
        dummy = torch.FloatTensor(1, *input_res).to(device)
        _ = flops_model(dummy)

    if print_per_layer_stat:
        print_model_with_flops(flops_model)
    flops_count = flops_model.compute_average_flops_cost()
    flops_model.stop_flops_count()
    return flops_count
31
+
32
def get_model_activation(model, input_res, input_constructor=None):
    """Return (#activation elements, #Conv2d layers) for one forward pass.

    Instruments ``model`` with activation-counting hooks, runs it once on a
    dummy batch (or on input built by ``input_constructor``), then detaches
    the hooks.
    """
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    activation_model = add_activation_counting_methods(model)
    activation_model.eval().start_activation_count()

    if input_constructor:
        kwargs = input_constructor(input_res)
        _ = activation_model(**kwargs)
    else:
        # dummy batch lives on the same device as the model parameters
        device = list(activation_model.parameters())[-1].device
        dummy = torch.FloatTensor(1, *input_res).to(device)
        _ = activation_model(dummy)

    activation_count, num_conv = activation_model.compute_average_activation_cost()
    activation_model.stop_activation_count()
    return activation_count, num_conv
49
+
50
+
51
def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
                              input_constructor=None):
    """Return (flops, params) for ``model``, optionally as formatted strings.

    Unlike :func:`get_model_flops`, the dummy batch is created on the default
    device (no explicit ``.to``).
    """
    assert type(input_res) is tuple
    assert len(input_res) >= 3
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()

    if input_constructor:
        kwargs = input_constructor(input_res)
        _ = flops_model(**kwargs)
    else:
        dummy = torch.FloatTensor(1, *input_res)
        _ = flops_model(dummy)

    if print_per_layer_stat:
        print_model_with_flops(flops_model)
    flops_count = flops_model.compute_average_flops_cost()
    params_count = get_model_parameters_number(flops_model)
    flops_model.stop_flops_count()

    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)
    return flops_count, params_count
74
+
75
+
76
def flops_to_string(flops, units='GMac', precision=2):
    """Format a raw MAC count as a human-readable string.

    With ``units=None`` the largest fitting unit is chosen automatically;
    otherwise the requested unit ('GMac'/'MMac'/'KMac') is used, and any
    unrecognised unit falls back to raw ' Mac'.
    """
    if units is None:
        # auto-select the largest unit whose integer quotient is non-zero
        if flops // 10 ** 9 > 0:
            return str(round(flops / 10. ** 9, precision)) + ' GMac'
        if flops // 10 ** 6 > 0:
            return str(round(flops / 10. ** 6, precision)) + ' MMac'
        if flops // 10 ** 3 > 0:
            return str(round(flops / 10. ** 3, precision)) + ' KMac'
        return str(flops) + ' Mac'

    divisors = {'GMac': 10. ** 9, 'MMac': 10. ** 6, 'KMac': 10. ** 3}
    if units in divisors:
        return str(round(flops / divisors[units], precision)) + ' ' + units
    return str(flops) + ' Mac'
95
+
96
+
97
def params_to_string(params_num):
    """Format a parameter count as 'X M', 'X k', or the raw number."""
    if params_num // 10 ** 6 > 0:
        return f"{round(params_num / 10 ** 6, 2)} M"
    if params_num // 10 ** 3:
        return f"{round(params_num / 10 ** 3, 2)} k"
    return str(params_num)
104
+
105
+
106
def print_model_with_flops(model, units='GMac', precision=3):
    """Print the module tree with per-layer FLOPs appended to each repr.

    Temporarily monkey-patches each submodule's ``extra_repr`` so that
    ``print(model)`` shows accumulated FLOPs and their share of the total,
    then restores the original reprs.
    """
    total_flops = model.compute_average_flops_cost()

    def accumulate_flops(self):
        # Leaf counters hold __flops__; containers sum over their children.
        # NOTE(review): divides by model.__batch_counter__, which is set by
        # batch-counting hooks not visible here — confirm it is initialised.
        if is_supported_instance(self):
            return self.__flops__ / model.__batch_counter__
        else:
            sum = 0
            for m in self.children():
                sum += m.accumulate_flops()
            return sum

    def flops_repr(self):
        # "<flops>, <percent> MACs, <original extra_repr>"
        accumulated_flops_cost = self.accumulate_flops()
        return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
                          '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        # Bind the helpers as methods and swap in the flops-aware repr,
        # keeping the original so it can be restored afterwards.
        m.accumulate_flops = accumulate_flops.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        # Undo everything add_extra_repr installed.
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    print(model)
    model.apply(del_extra_repr)
142
+
143
+
144
def get_model_parameters_number(model):
    """Count the trainable (requires_grad) parameters of ``model``."""
    trainable = (p.numel() for p in model.parameters() if p.requires_grad)
    return sum(trainable)
147
+
148
+
149
def add_flops_counting_methods(net_main_module):
    """Attach the flops-counting API to ``net_main_module`` in place.

    Binds start/stop/reset/compute functions as methods so each one receives
    the network as ``self``, then resets the counters. Returns the module.
    """
    for fn in (start_flops_count, stop_flops_count,
               reset_flops_count, compute_average_flops_cost):
        setattr(net_main_module, fn.__name__, fn.__get__(net_main_module))

    net_main_module.reset_flops_count()
    return net_main_module
160
+
161
+
162
def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Returns current mean flops consumption per image.

    """
    # Sum the per-module accumulators of every instrumented layer.
    return sum(module.__flops__
               for module in self.modules()
               if is_supported_instance(module))
177
+
178
+
179
def start_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Activates the computation of mean flops consumption per image.
    Call it before you run the network.

    """
    # Installs a forward hook on every supported submodule.
    self.apply(add_flops_counter_hook_function)
189
+
190
+
191
def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Stops computing the mean flops consumption per image.
    Call whenever you want to pause the computation.

    """
    # Removes the forward hooks installed by start_flops_count.
    self.apply(remove_flops_counter_hook_function)
201
+
202
+
203
def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Resets statistics computed so far.

    """
    # Zeroes the __flops__ accumulator on every supported submodule.
    self.apply(add_flops_counter_variable_or_reset)
212
+
213
+
214
def add_flops_counter_hook_function(module):
    """Register the flops-counting forward hook that matches ``module``'s type.

    No-op for unsupported modules or modules that already carry a hook handle.
    """
    if not is_supported_instance(module):
        return
    if hasattr(module, '__flops_handle__'):
        return  # hook already installed

    if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
        hook = conv_flops_counter_hook
    elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)):
        hook = relu_flops_counter_hook
    elif isinstance(module, nn.Linear):
        hook = linear_flops_counter_hook
    elif isinstance(module, nn.BatchNorm2d):
        hook = bn_flops_counter_hook
    else:
        hook = empty_flops_counter_hook

    # Keep the handle so the hook can be removed later.
    module.__flops_handle__ = module.register_forward_hook(hook)
230
+
231
+
232
+ def remove_flops_counter_hook_function(module):
233
+ if is_supported_instance(module):
234
+ if hasattr(module, '__flops_handle__'):
235
+ module.__flops_handle__.remove()
236
+ del module.__flops_handle__
237
+
238
+
239
+ def add_flops_counter_variable_or_reset(module):
240
+ if is_supported_instance(module):
241
+ module.__flops__ = 0
242
+
243
+
244
+ # ---- Internal functions
245
def is_supported_instance(module):
    """Return True for module types the FLOPs counter knows how to handle.

    Covers 2-D convolutions (plain and transposed), BatchNorm2d, Linear,
    and the common element-wise activations.
    """
    supported = (
        nn.Conv2d, nn.ConvTranspose2d,
        nn.BatchNorm2d,
        nn.Linear,
        nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
    )
    return isinstance(module, supported)
256
+
257
+
258
def conv_flops_counter_hook(conv_module, input, output):
    """Forward hook: accumulate multiply-accumulate FLOPs of a convolution.

    FLOPs = (kernel elements * in_channels * out_channels / groups)
            * batch * spatial output positions.
    The count is added to ``conv_module.__flops__``.
    """
    batch = output.shape[0]
    spatial_positions = np.prod(list(output.shape[2:]))

    kernel_elems = np.prod(list(conv_module.kernel_size))
    filters_per_group = conv_module.out_channels // conv_module.groups
    per_position = kernel_elems * conv_module.in_channels * filters_per_group

    active = batch * spatial_positions
    conv_module.__flops__ += int(int(per_position) * int(active))
280
+
281
+
282
def relu_flops_counter_hook(module, input, output):
    """Forward hook: count one FLOP per output element for activations."""
    module.__flops__ += int(output.numel())
287
+
288
+
289
def linear_flops_counter_hook(module, input, output):
    """Forward hook: accumulate FLOPs of a Linear layer (in * out * batch)."""
    x = input[0]
    if len(x.shape) == 1:
        # unbatched input: a single vector
        module.__flops__ += int(x.shape[0] * output.shape[0])
    else:
        batch = x.shape[0]
        module.__flops__ += int(batch * x.shape[1] * output.shape[1])
297
+
298
+
299
def bn_flops_counter_hook(module, input, output):
    """Forward hook: accumulate FLOPs of BatchNorm (normalize, x2 if affine)."""
    spatial = np.prod(output.shape[2:])
    flops = output.shape[0] * module.num_features * spatial
    if module.affine:
        # scale and shift double the per-element work
        flops *= 2
    module.__flops__ += int(flops)
313
+
314
+
315
+ # ---- Count the number of convolutional layers and the activation
316
+ def add_activation_counting_methods(net_main_module):
317
+ # adding additional methods to the existing module object,
318
+ # this is done this way so that each function has access to self object
319
+ # embed()
320
+ net_main_module.start_activation_count = start_activation_count.__get__(net_main_module)
321
+ net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module)
322
+ net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module)
323
+ net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module)
324
+
325
+ net_main_module.reset_activation_count()
326
+ return net_main_module
327
+
328
+
329
+ def compute_average_activation_cost(self):
330
+ """
331
+ A method that will be available after add_activation_counting_methods() is called
332
+ on a desired net object.
333
+
334
+ Returns current mean activation consumption per image.
335
+
336
+ """
337
+
338
+ activation_sum = 0
339
+ num_conv = 0
340
+ for module in self.modules():
341
+ if is_supported_instance_for_activation(module):
342
+ activation_sum += module.__activation__
343
+ num_conv += module.__num_conv__
344
+ return activation_sum, num_conv
345
+
346
+
347
+ def start_activation_count(self):
348
+ """
349
+ A method that will be available after add_activation_counting_methods() is called
350
+ on a desired net object.
351
+
352
+ Activates the computation of mean activation consumption per image.
353
+ Call it before you run the network.
354
+
355
+ """
356
+ self.apply(add_activation_counter_hook_function)
357
+
358
+
359
+ def stop_activation_count(self):
360
+ """
361
+ A method that will be available after add_activation_counting_methods() is called
362
+ on a desired net object.
363
+
364
+ Stops computing the mean activation consumption per image.
365
+ Call whenever you want to pause the computation.
366
+
367
+ """
368
+ self.apply(remove_activation_counter_hook_function)
369
+
370
+
371
+ def reset_activation_count(self):
372
+ """
373
+ A method that will be available after add_activation_counting_methods() is called
374
+ on a desired net object.
375
+
376
+ Resets statistics computed so far.
377
+
378
+ """
379
+ self.apply(add_activation_counter_variable_or_reset)
380
+
381
+
382
+ def add_activation_counter_hook_function(module):
383
+ if is_supported_instance_for_activation(module):
384
+ if hasattr(module, '__activation_handle__'):
385
+ return
386
+
387
+ if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
388
+ handle = module.register_forward_hook(conv_activation_counter_hook)
389
+ module.__activation_handle__ = handle
390
+
391
+
392
+ def remove_activation_counter_hook_function(module):
393
+ if is_supported_instance_for_activation(module):
394
+ if hasattr(module, '__activation_handle__'):
395
+ module.__activation_handle__.remove()
396
+ del module.__activation_handle__
397
+
398
+
399
+ def add_activation_counter_variable_or_reset(module):
400
+ if is_supported_instance_for_activation(module):
401
+ module.__activation__ = 0
402
+ module.__num_conv__ = 0
403
+
404
+
405
def is_supported_instance_for_activation(module):
    """Return True for module types whose output activations are counted."""
    supported = (
        nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d, nn.Linear, nn.ConvTranspose1d,
    )
    return isinstance(module, supported)
413
+
414
def conv_activation_counter_hook(module, input, output):
    """Forward hook: count output activations of a convolution.

    Reference: Radosavovic et al., "Designing Network Design Spaces"
    (activation count as a complexity metric).
    Increments ``module.__activation__`` by the number of output elements
    and ``module.__num_conv__`` by one per forward call.
    """
    elements = output.numel()
    module.__activation__ += elements
    module.__num_conv__ += 1
425
+
426
+
427
def empty_flops_counter_hook(module, input, output):
    """No-op hook: registered for supported-but-costless modules (adds 0)."""
    module.__flops__ += 0
429
+
430
+
431
def upsample_flops_counter_hook(module, input, output):
    """Forward hook: count one FLOP per element of ``output[0]``.

    NOTE(review): ``output[0]`` indexes the first item of the output —
    for a plain tensor output this drops the batch dimension; verify the
    intended convention against how the hook is registered.
    """
    first = output[0]
    count = 1
    for dim in first.shape:
        count *= dim
    module.__flops__ += int(count)
438
+
439
+
440
def pool_flops_counter_hook(module, input, output):
    """Forward hook: count one FLOP per input element for pooling layers."""
    tensor = input[0]
    module.__flops__ += int(np.prod(tensor.shape))
443
+
444
+
445
+ def dconv_flops_counter_hook(dconv_module, input, output):
446
+ input = input[0]
447
+
448
+ batch_size = input.shape[0]
449
+ output_dims = list(output.shape[2:])
450
+
451
+ m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape
452
+ out_channels, _, kernel_dim2, _, = dconv_module.projection.shape
453
+ # groups = dconv_module.groups
454
+
455
+ # filters_per_channel = out_channels // groups
456
+ conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels
457
+ conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels
458
+ active_elements_count = batch_size * np.prod(output_dims)
459
+
460
+ overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count
461
+ overall_flops = overall_conv_flops
462
+
463
+ dconv_module.__flops__ += int(overall_flops)
464
+ # dconv_module.__output_dims__ = output_dims
465
+
utils/test.bmp ADDED

Git LFS Details

  • SHA256: 43534bd9f59f06ac7b0d6c8b991137baa4c764b8241a5bd1136b9dd810f4dba2
  • Pointer size: 131 Bytes
  • Size of remote file: 197 kB
utils/utils_image.py ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+ from torchvision.utils import make_grid
8
+ from datetime import datetime
9
+ # import torchvision.transforms as transforms
10
+ import matplotlib.pyplot as plt
11
+
12
+
13
+ IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
14
+
15
+
16
def is_image_file(filename):
    """Return True if *filename* ends with one of IMG_EXTENSIONS."""
    for extension in IMG_EXTENSIONS:
        if filename.endswith(extension):
            return True
    return False
18
+
19
+
20
def get_timestamp():
    """Return the current local time formatted as 'yymmdd-HHMMSS'."""
    now = datetime.now()
    return now.strftime('%y%m%d-%H%M%S')
22
+
23
+
24
+ def imshow(x, title=None, cbar=False, figsize=None):
25
+ plt.figure(figsize=figsize)
26
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
27
+ if title:
28
+ plt.title(title)
29
+ if cbar:
30
+ plt.colorbar()
31
+ plt.show()
32
+
33
+
34
+ '''
35
+ # =======================================
36
+ # get image pathes of files
37
+ # =======================================
38
+ '''
39
+
40
+
41
+ def get_image_paths(dataroot):
42
+ paths = None # return None if dataroot is None
43
+ if dataroot is not None:
44
+ paths = sorted(_get_paths_from_images(dataroot))
45
+ return paths
46
+
47
+
48
+ def _get_paths_from_images(path):
49
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
50
+ images = []
51
+ for dirpath, _, fnames in sorted(os.walk(path)):
52
+ for fname in sorted(fnames):
53
+ if is_image_file(fname):
54
+ img_path = os.path.join(dirpath, fname)
55
+ images.append(img_path)
56
+ assert images, '{:s} has no valid image file'.format(path)
57
+ return images
58
+
59
+
60
+ '''
61
+ # =======================================
62
+ # makedir
63
+ # =======================================
64
+ '''
65
+
66
+
67
def mkdir(path):
    """Create directory *path* (with parents) if it does not already exist.

    Fix: the original check-then-create pattern
    (``if not os.path.exists: os.makedirs``) can raise FileExistsError when
    another process creates the directory between the two calls;
    ``exist_ok=True`` makes the operation atomic and idempotent.
    """
    os.makedirs(path, exist_ok=True)
70
+
71
+
72
+ def mkdirs(paths):
73
+ if isinstance(paths, str):
74
+ mkdir(paths)
75
+ else:
76
+ for path in paths:
77
+ mkdir(path)
78
+
79
+
80
+ def mkdir_and_rename(path):
81
+ if os.path.exists(path):
82
+ new_name = path + '_archived_' + get_timestamp()
83
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
84
+ os.rename(path, new_name)
85
+ os.makedirs(path)
86
+
87
+
88
+ '''
89
+ # =======================================
90
+ # read image from path
91
+ # Note: opencv is fast
92
+ # but read BGR numpy image
93
+ # =======================================
94
+ '''
95
+
96
+
97
+ # ----------------------------------------
98
+ # get single image of size HxWxn_channles (BGR)
99
+ # ----------------------------------------
100
+ def read_img(path):
101
+ # read image by cv2
102
+ # return: Numpy float32, HWC, BGR, [0,1]
103
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
104
+ img = img.astype(np.float32) / 255.
105
+ if img.ndim == 2:
106
+ img = np.expand_dims(img, axis=2)
107
+ # some images have 4 channels
108
+ if img.shape[2] > 3:
109
+ img = img[:, :, :3]
110
+ return img
111
+
112
+
113
+ # ----------------------------------------
114
+ # get uint8 image of size HxWxn_channles (RGB)
115
+ # ----------------------------------------
116
+ def imread_uint(path, n_channels=3):
117
+ # input: path
118
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
119
+ if n_channels == 1:
120
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
121
+ img = np.expand_dims(img, axis=2) # HxWx1
122
+ elif n_channels == 3:
123
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
124
+ if img.ndim == 2:
125
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
126
+ else:
127
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
128
+ return img
129
+
130
+
131
+ def imsave(img, img_path):
132
+ img = np.squeeze(img)
133
+ if img.ndim == 3:
134
+ img = img[:, :, [2, 1, 0]]
135
+ cv2.imwrite(img_path, img)
136
+
137
+
138
+ '''
139
+ # =======================================
140
+ # numpy(single) <---> numpy(uint)
141
+ # numpy(single) <---> tensor
142
+ # numpy(uint) <---> tensor
143
+ # =======================================
144
+ '''
145
+
146
+
147
+ # --------------------------------
148
+ # numpy(single) <---> numpy(uint)
149
+ # --------------------------------
150
+
151
+
152
def uint2single(img):
    """Convert a uint8 image in [0, 255] to float32 in [0, 1]."""
    return (img / 255.).astype(np.float32)
155
+
156
+
157
+ def uint2single1(img):
158
+
159
+ return np.float32(np.squeeze(img)/255.)
160
+
161
+
162
def single2uint(img):
    """Convert a float image in [0, 1] to uint8 in [0, 255] (clip + round)."""
    clipped = np.clip(img, 0, 1)
    return ((clipped * 255.).round()).astype(np.uint8)
165
+
166
+
167
def uint162single(img):
    """Convert a uint16 image in [0, 65535] to float32 in [0, 1]."""
    return (img / 65535.).astype(np.float32)
170
+
171
+
172
def single2uint16(img):
    """Convert a float image in [0, 1] to uint16 in [0, 65535].

    Fix: the original cast the scaled values with ``np.uint8``, which
    wraps every value above 255 modulo 256; a 16-bit conversion must use
    ``np.uint16`` (mirroring ``uint162single``).
    """
    return np.uint16((img.clip(0, 1)*65535.).round())
175
+
176
+
177
+ # --------------------------------
178
+ # numpy(uint) <---> tensor
179
+ # uint (HxWxn_channels (RGB) or G)
180
+ # --------------------------------
181
+
182
+
183
+ # convert uint (HxWxn_channels) to 4-dimensional torch tensor
184
+ def uint2tensor4(img, data_range):
185
+ if img.ndim == 2:
186
+ img = np.expand_dims(img, axis=2)
187
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255./data_range).unsqueeze(0)
188
+
189
+
190
+ # convert uint (HxWxn_channels) to 3-dimensional torch tensor
191
+ def uint2tensor3(img):
192
+ if img.ndim == 2:
193
+ img = np.expand_dims(img, axis=2)
194
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
195
+
196
+
197
+ # convert torch tensor to uint
198
+ def tensor2uint(img, data_range):
199
+ img = img.data.squeeze().float().clamp_(0, 1*data_range).cpu().numpy()
200
+ if img.ndim == 3:
201
+ img = np.transpose(img, (1, 2, 0))
202
+ return np.uint8((img*255.0/data_range).round())
203
+
204
+
205
+ # --------------------------------
206
+ # numpy(single) <---> tensor
207
+ # single (HxWxn_channels (RGB) or G)
208
+ # --------------------------------
209
+
210
+
211
+ # convert single (HxWxn_channels) to 4-dimensional torch tensor
212
+ def single2tensor4(img):
213
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
214
+
215
+
216
+ # convert single (HxWxn_channels) to 3-dimensional torch tensor
217
+ def single2tensor3(img):
218
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
219
+
220
+
221
+ # convert torch tensor to single
222
+ def tensor2single(img):
223
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
224
+ if img.ndim == 3:
225
+ img = np.transpose(img, (1, 2, 0))
226
+
227
+ return img
228
+
229
+ def tensor2single3(img):
230
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
231
+ if img.ndim == 3:
232
+ img = np.transpose(img, (1, 2, 0))
233
+ elif img.ndim == 2:
234
+ img = np.expand_dims(img, axis=2)
235
+ return img
236
+
237
+
238
+ # from skimage.io import imread, imsave
239
+ def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
240
+ '''
241
+ Converts a torch Tensor into an image Numpy array of BGR channel order
242
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
243
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
244
+ '''
245
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
246
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
247
+ n_dim = tensor.dim()
248
+ if n_dim == 4:
249
+ n_img = len(tensor)
250
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
251
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
252
+ elif n_dim == 3:
253
+ img_np = tensor.numpy()
254
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
255
+ elif n_dim == 2:
256
+ img_np = tensor.numpy()
257
+ else:
258
+ raise TypeError(
259
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
260
+ if out_type == np.uint8:
261
+ img_np = (img_np * 255.0).round()
262
+ # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
263
+ return img_np.astype(out_type)
264
+
265
+
266
+ '''
267
+ # =======================================
268
+ # image processing process on numpy image
269
+ # augment(img_list, hflip=True, rot=True):
270
+ # =======================================
271
+ '''
272
+
273
+
274
+ def augment_img(img, mode=0):
275
+ if mode == 0:
276
+ return img
277
+ elif mode == 1:
278
+ return np.flipud(np.rot90(img))
279
+ elif mode == 2:
280
+ return np.flipud(img)
281
+ elif mode == 3:
282
+ return np.rot90(img, k=3)
283
+ elif mode == 4:
284
+ return np.flipud(np.rot90(img, k=2))
285
+ elif mode == 5:
286
+ return np.rot90(img)
287
+ elif mode == 6:
288
+ return np.rot90(img, k=2)
289
+ elif mode == 7:
290
+ return np.flipud(np.rot90(img, k=3))
291
+
292
+
293
+ def augment_img_np3(img, mode=0):
294
+ if mode == 0:
295
+ return img
296
+ elif mode == 1:
297
+ return img.transpose(1, 0, 2)
298
+ elif mode == 2:
299
+ return img[::-1, :, :]
300
+ elif mode == 3:
301
+ img = img[::-1, :, :]
302
+ img = img.transpose(1, 0, 2)
303
+ return img
304
+ elif mode == 4:
305
+ return img[:, ::-1, :]
306
+ elif mode == 5:
307
+ img = img[:, ::-1, :]
308
+ img = img.transpose(1, 0, 2)
309
+ return img
310
+ elif mode == 6:
311
+ img = img[:, ::-1, :]
312
+ img = img[::-1, :, :]
313
+ return img
314
+ elif mode == 7:
315
+ img = img[:, ::-1, :]
316
+ img = img[::-1, :, :]
317
+ img = img.transpose(1, 0, 2)
318
+ return img
319
+
320
+
321
+ def augment_img_tensor(img, mode=0):
322
+ img_size = img.size()
323
+ img_np = img.data.cpu().numpy()
324
+ if len(img_size) == 3:
325
+ img_np = np.transpose(img_np, (1, 2, 0))
326
+ elif len(img_size) == 4:
327
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
328
+ img_np = augment_img(img_np, mode=mode)
329
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
330
+ if len(img_size) == 3:
331
+ img_tensor = img_tensor.permute(2, 0, 1)
332
+ elif len(img_size) == 4:
333
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
334
+
335
+ return img_tensor.type_as(img)
336
+
337
+
338
+ def augment_imgs(img_list, hflip=True, rot=True):
339
+ # horizontal flip OR rotate
340
+ hflip = hflip and random.random() < 0.5
341
+ vflip = rot and random.random() < 0.5
342
+ rot90 = rot and random.random() < 0.5
343
+
344
+ def _augment(img):
345
+ if hflip:
346
+ img = img[:, ::-1, :]
347
+ if vflip:
348
+ img = img[::-1, :, :]
349
+ if rot90:
350
+ img = img.transpose(1, 0, 2)
351
+ return img
352
+
353
+ return [_augment(img) for img in img_list]
354
+
355
+
356
+ '''
357
+ # =======================================
358
+ # image processing process on numpy image
359
+ # channel_convert(in_c, tar_type, img_list):
360
+ # rgb2ycbcr(img, only_y=True):
361
+ # bgr2ycbcr(img, only_y=True):
362
+ # ycbcr2rgb(img):
363
+ # modcrop(img_in, scale):
364
+ # =======================================
365
+ '''
366
+
367
+
368
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype as the input.
    '''
    in_img_type = img.dtype
    # fix: astype() returns a new array -- the original discarded the result,
    # so the conversion never happened and the in-place `*= 255.` below
    # silently mutated the caller's float array
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
390
+
391
+
392
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype as the input.
    '''
    in_img_type = img.dtype
    # fix: astype() returns a new array -- the original discarded the result,
    # so the conversion never happened and the in-place `*= 255.` below
    # silently mutated the caller's float array
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
410
+
411
+
412
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    Output has the same dtype as the input.
    '''
    in_img_type = img.dtype
    # fix: astype() returns a new array -- the original discarded the result,
    # so the conversion never happened and the in-place `*= 255.` below
    # silently mutated the caller's float array
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
434
+
435
+
436
def modcrop(img_in, scale):
    """Crop height and width down to exact multiples of *scale*.

    Accepts HWC or HW numpy arrays; operates on a copy of the input.
    Raises ValueError for any other dimensionality.
    """
    img = np.copy(img_in)
    if img.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    H, W = img.shape[0], img.shape[1]
    new_h = H - H % scale
    new_w = W - W % scale
    if img.ndim == 2:
        return img[:new_h, :new_w]
    return img[:new_h, :new_w, :]
450
+
451
+
452
def shave(img_in, border=0):
    """Trim *border* pixels from each side of an HWC or HW numpy image.

    Returns a cropped copy of the input.
    """
    img = np.copy(img_in)
    h, w = img.shape[:2]
    return img[border:h - border, border:w - border]
458
+
459
+
460
+ def channel_convert(in_c, tar_type, img_list):
461
+ # conversion among BGR, gray and y
462
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
463
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
464
+ return [np.expand_dims(img, axis=2) for img in gray_list]
465
+ elif in_c == 3 and tar_type == 'y': # BGR to y
466
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
467
+ return [np.expand_dims(img, axis=2) for img in y_list]
468
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
469
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
470
+ else:
471
+ return img_list
472
+
473
+
474
+ '''
475
+ # =======================================
476
+ # metric, PSNR and SSIM
477
+ # =======================================
478
+ '''
479
+
480
+
481
+ # ----------
482
+ # PSNR
483
+ # ----------
484
def calculate_psnr(img1, img2, border=0):
    """Return PSNR (dB) between two [0, 255] images, ignoring *border* pixels.

    Raises ValueError if the shapes differ; returns inf for identical crops.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    a = img1[border:h - border, border:w - border].astype(np.float64)
    b = img2[border:h - border, border:w - border].astype(np.float64)

    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
498
+
499
+
500
+ # ----------
501
+ # SSIM
502
+ # ----------
503
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]
    Crops *border* pixels from each side, then averages per-channel SSIM
    for 3-channel images.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                # fix: compare channel i against channel i; the original
                # passed the full 3-channel images three times, computing
                # the identical value redundantly instead of per-channel SSIM
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')
526
+
527
+
528
def ssim(img1, img2):
    """Single-scale SSIM matching MATLAB's implementation.

    Inputs are [0, 255] images; uses an 11x11 Gaussian window (sigma 1.5)
    and 'valid' borders (5 pixels trimmed on each side). Returns the mean
    of the SSIM map.
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    mu_a = cv2.filter2D(a, -1, window)[5:-5, 5:-5]  # valid region only
    mu_b = cv2.filter2D(b, -1, window)[5:-5, 5:-5]
    mu_a_sq = mu_a**2
    mu_b_sq = mu_b**2
    mu_ab = mu_a * mu_b
    var_a = cv2.filter2D(a**2, -1, window)[5:-5, 5:-5] - mu_a_sq
    var_b = cv2.filter2D(b**2, -1, window)[5:-5, 5:-5] - mu_b_sq
    cov_ab = cv2.filter2D(a * b, -1, window)[5:-5, 5:-5] - mu_ab

    numerator = (2 * mu_ab + c1) * (2 * cov_ab + c2)
    denominator = (mu_a_sq + mu_b_sq + c1) * (var_a + var_b + c2)
    return (numerator / denominator).mean()
549
+
550
+
551
+ '''
552
+ # =======================================
553
+ # pytorch version of matlab imresize
554
+ # =======================================
555
+ '''
556
+
557
+
558
+ # matlab 'imresize' function, now only support 'bicubic'
559
+ def cubic(x):
560
+ absx = torch.abs(x)
561
+ absx2 = absx**2
562
+ absx3 = absx**3
563
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
564
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
565
+
566
+
567
+ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
568
+ if (scale < 1) and (antialiasing):
569
+ # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
570
+ kernel_width = kernel_width / scale
571
+
572
+ # Output-space coordinates
573
+ x = torch.linspace(1, out_length, out_length)
574
+
575
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
576
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
577
+ # space maps to 1.5 in input space.
578
+ u = x / scale + 0.5 * (1 - 1 / scale)
579
+
580
+ # What is the left-most pixel that can be involved in the computation?
581
+ left = torch.floor(u - kernel_width / 2)
582
+
583
+ # What is the maximum number of pixels that can be involved in the
584
+ # computation? Note: it's OK to use an extra pixel here; if the
585
+ # corresponding weights are all zero, it will be eliminated at the end
586
+ # of this function.
587
+ P = math.ceil(kernel_width) + 2
588
+
589
+ # The indices of the input pixels involved in computing the k-th output
590
+ # pixel are in row k of the indices matrix.
591
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
592
+ 1, P).expand(out_length, P)
593
+
594
+ # The weights used to compute the k-th output pixel are in row k of the
595
+ # weights matrix.
596
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
597
+ # apply cubic kernel
598
+ if (scale < 1) and (antialiasing):
599
+ weights = scale * cubic(distance_to_center * scale)
600
+ else:
601
+ weights = cubic(distance_to_center)
602
+ # Normalize the weights matrix so that each row sums to 1.
603
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
604
+ weights = weights / weights_sum.expand(out_length, P)
605
+
606
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
607
+ weights_zero_tmp = torch.sum((weights == 0), 0)
608
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
609
+ indices = indices.narrow(1, 1, P - 2)
610
+ weights = weights.narrow(1, 1, P - 2)
611
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
612
+ indices = indices.narrow(1, 0, P - 2)
613
+ weights = weights.narrow(1, 0, P - 2)
614
+ weights = weights.contiguous()
615
+ indices = indices.contiguous()
616
+ sym_len_s = -indices.min() + 1
617
+ sym_len_e = indices.max() - in_length
618
+ indices = indices + sym_len_s - 1
619
+ return weights, indices, int(sym_len_s), int(sym_len_e)
620
+
621
+
622
+ # --------------------------------
623
+ # imresize for tensor image
624
+ # --------------------------------
625
+ def imresize(img, scale, antialiasing=True):
626
+ # Now the scale should be the same for H and W
627
+ # input: img: pytorch tensor, CHW or HW [0,1]
628
+ # output: CHW or HW [0,1] w/o round
629
+ need_squeeze = True if img.dim() == 2 else False
630
+ if need_squeeze:
631
+ img.unsqueeze_(0)
632
+ in_C, in_H, in_W = img.size()
633
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
634
+ kernel_width = 4
635
+ kernel = 'cubic'
636
+
637
+ # Return the desired dimension order for performing the resize. The
638
+ # strategy is to perform the resize first along the dimension with the
639
+ # smallest scale factor.
640
+ # Now we do not support this.
641
+
642
+ # get weights and indices
643
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
644
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
645
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
646
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
647
+ # process H dimension
648
+ # symmetric copying
649
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
650
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
651
+
652
+ sym_patch = img[:, :sym_len_Hs, :]
653
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
654
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
655
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
656
+
657
+ sym_patch = img[:, -sym_len_He:, :]
658
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
659
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
660
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
661
+
662
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
663
+ kernel_width = weights_H.size(1)
664
+ for i in range(out_H):
665
+ idx = int(indices_H[i][0])
666
+ for j in range(out_C):
667
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
668
+
669
+ # process W dimension
670
+ # symmetric copying
671
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
672
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
673
+
674
+ sym_patch = out_1[:, :, :sym_len_Ws]
675
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
676
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
677
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
678
+
679
+ sym_patch = out_1[:, :, -sym_len_We:]
680
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
681
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
682
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
683
+
684
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
685
+ kernel_width = weights_W.size(1)
686
+ for i in range(out_W):
687
+ idx = int(indices_W[i][0])
688
+ for j in range(out_C):
689
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
690
+ if need_squeeze:
691
+ out_2.squeeze_()
692
+ return out_2
693
+
694
+
695
+ # --------------------------------
696
+ # imresize for numpy image
697
+ # --------------------------------
698
+ def imresize_np(img, scale, antialiasing=True):
699
+ # Now the scale should be the same for H and W
700
+ # input: img: Numpy, HWC or HW [0,1]
701
+ # output: HWC or HW [0,1] w/o round
702
+ img = torch.from_numpy(img)
703
+ need_squeeze = True if img.dim() == 2 else False
704
+ if need_squeeze:
705
+ img.unsqueeze_(2)
706
+
707
+ in_H, in_W, in_C = img.size()
708
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
709
+ kernel_width = 4
710
+ kernel = 'cubic'
711
+
712
+ # Return the desired dimension order for performing the resize. The
713
+ # strategy is to perform the resize first along the dimension with the
714
+ # smallest scale factor.
715
+ # Now we do not support this.
716
+
717
+ # get weights and indices
718
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
719
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
720
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
721
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
722
+ # process H dimension
723
+ # symmetric copying
724
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
725
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
726
+
727
+ sym_patch = img[:sym_len_Hs, :, :]
728
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
729
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
730
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
731
+
732
+ sym_patch = img[-sym_len_He:, :, :]
733
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
734
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
735
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
736
+
737
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
738
+ kernel_width = weights_H.size(1)
739
+ for i in range(out_H):
740
+ idx = int(indices_H[i][0])
741
+ for j in range(out_C):
742
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
743
+
744
+ # process W dimension
745
+ # symmetric copying
746
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
747
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
748
+
749
+ sym_patch = out_1[:, :sym_len_Ws, :]
750
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
751
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
752
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
753
+
754
+ sym_patch = out_1[:, -sym_len_We:, :]
755
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
756
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
757
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
758
+
759
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
760
+ kernel_width = weights_W.size(1)
761
+ for i in range(out_W):
762
+ idx = int(indices_W[i][0])
763
+ for j in range(out_C):
764
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
765
+ if need_squeeze:
766
+ out_2.squeeze_()
767
+
768
+ return out_2.numpy()
769
+
770
+
771
+ if __name__ == '__main__':
772
+ img = imread_uint('test.bmp',3)
utils/utils_logger.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import datetime
4
+ import logging
5
+
6
+
7
def log(*args, **kwargs):
    """Print *args* prefixed with a 'YYYY-MM-DD HH:MM:SS:' timestamp.

    Extra keyword arguments are forwarded to print() (e.g. file=, end=).
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:")
    print(stamp, *args, **kwargs)
9
+
10
+
11
+ '''
12
+ # ===============================
13
+ # logger
14
+ # logger_name = None = 'base' ???
15
+ # ===============================
16
+ '''
17
+
18
+
19
+ def logger_info(logger_name, log_path='default_logger.log'):
20
+ ''' set up logger
21
+ modified by Kai Zhang (github: https://github.com/cszn)
22
+ '''
23
+ log = logging.getLogger(logger_name)
24
+ if log.hasHandlers():
25
+ print('LogHandlers exist!')
26
+ else:
27
+ print('LogHandlers setup!')
28
+ level = logging.INFO
29
+ formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
30
+ fh = logging.FileHandler(log_path, mode='a')
31
+ fh.setFormatter(formatter)
32
+ log.setLevel(level)
33
+ log.addHandler(fh)
34
+ # print(len(log.handlers))
35
+
36
+ sh = logging.StreamHandler()
37
+ sh.setFormatter(formatter)
38
+ log.addHandler(sh)
39
+
40
+
41
+ '''
42
+ # ===============================
43
+ # print to file and std_out simultaneously
44
+ # ===============================
45
+ '''
46
+
47
+
48
class logger_print(object):
    """File-like tee: everything written goes to stdout AND a log file.

    Intended for `sys.stdout = logger_print(path)` style redirection.
    """

    def __init__(self, log_path="default.log"):
        # keep a handle on the real stdout so output still reaches the console
        self.terminal = sys.stdout
        self.log = open(log_path, 'a')

    def write(self, message):
        """Echo *message* to the console and append it to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """No-op; present only to satisfy the file-like protocol."""
        pass