Upload 22 files
Browse files- .gitattributes +2 -0
- LICENSE +21 -0
- NTIRE2025-EfficientSR.log +0 -0
- README.md +70 -3
- figs/DSCF_arch.png +3 -0
- figs/logo.png +0 -0
- model_zoo/team00_EFDN.pth +3 -0
- model_zoo/team23_DSCF.pth +3 -0
- models/__pycache__/team00_EFDN.cpython-311.pyc +0 -0
- models/__pycache__/team23_DSCF.cpython-311.pyc +0 -0
- models/team00_EFDN.py +410 -0
- models/team23_DSCF.py +466 -0
- requirements.txt +136 -0
- results.txt +3 -0
- run.sh +29 -0
- test_demo.py +299 -0
- utils/__pycache__/model_summary.cpython-311.pyc +0 -0
- utils/__pycache__/utils_image.cpython-311.pyc +0 -0
- utils/__pycache__/utils_logger.cpython-311.pyc +0 -0
- utils/model_summary.py +465 -0
- utils/test.bmp +3 -0
- utils/utils_image.py +772 -0
- utils/utils_logger.py +58 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
figs/DSCF_arch.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
utils/test.bmp filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 BinRen
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
NTIRE2025-EfficientSR.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
README.md
CHANGED
|
@@ -1,3 +1,70 @@
|
|
| 1 |
-
--
|
| 2 |
-
|
| 3 |
-
--
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!--
|
| 2 |
+
* @Author: Yaozzz666
|
| 3 |
+
* @Date: 2025-03-21 13:49:25
|
| 4 |
+
* @LastEditors: Yaozzz666
|
| 5 |
+
* @LastEditTime: 2025-03-22 11:11:04
|
| 6 |
+
*
|
| 7 |
+
* Copyright (c) 2025 by ${Yaozzz666}, All Rights Reserved.
|
| 8 |
+
-->
|
| 9 |
+
# [NTIRE 2025 Challenge on Efficient Super-Resolution](https://cvlai.net/ntire/2025/) @ [CVPR 2025](https://cvpr.thecvf.com/)
|
| 10 |
+
|
| 11 |
+
## Distillation Supervised ConvLora Finetuning for SR
|
| 12 |
+
|
| 13 |
+
<div align=center>
|
| 14 |
+
<img src="https://github.com/Yaozzz666/DSCF-SR/blob/main/figs/DSCF_arch.png" width="800px"/>
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
- An overview of our DSCF-SR
|
| 18 |
+
|
| 19 |
+
## The Environments
|
| 20 |
+
|
| 21 |
+
The evaluation environment we adopt is recorded in `requirements.txt`. After you have built your own basic Python setup (Python = 3.9 in our setting) via either a *virtual environment* or *anaconda*, please try to keep it similar via:
|
| 22 |
+
|
| 23 |
+
- Step1: install Pytorch first:
|
| 24 |
+
`pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117`
|
| 25 |
+
|
| 26 |
+
- Step2: install other libs via:
|
| 27 |
+
```pip install -r requirements.txt```
|
| 28 |
+
|
| 29 |
+
or take it as a reference based on your original environments.
|
| 30 |
+
|
| 31 |
+
## How to test the model?
|
| 32 |
+
1. Run the [`run.sh`](./run.sh)
|
| 33 |
+
```bash
|
| 34 |
+
CUDA_VISIBLE_DEVICES=0 python test_demo.py --data_dir [path to your data dir] --save_dir [path to your save dir] --model_id 23
|
| 35 |
+
```
|
| 36 |
+
- Be sure to change the directories `--data_dir` and `--save_dir`.
|
| 37 |
+
|
| 38 |
+
## How to calculate the number of parameters, FLOPs, and activations
|
| 39 |
+
|
| 40 |
+
```python
|
| 41 |
+
from utils.model_summary import get_model_flops, get_model_activation
|
| 42 |
+
from models.team00_EFDN import EFDN
|
| 43 |
+
from fvcore.nn import FlopCountAnalysis
|
| 44 |
+
|
| 45 |
+
model = EFDN()
|
| 46 |
+
|
| 47 |
+
input_dim = (3, 256, 256) # set the input dimension
|
| 48 |
+
activations, num_conv = get_model_activation(model, input_dim)
|
| 49 |
+
activations = activations / 10 ** 6
|
| 50 |
+
print("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
|
| 51 |
+
print("{:>16s} : {:<d}".format("#Conv2d", num_conv))
|
| 52 |
+
|
| 53 |
+
# The FLOPs calculation in previous NTIRE_ESR Challenge
|
| 54 |
+
# flops = get_model_flops(model, input_dim, False)
|
| 55 |
+
# flops = flops / 10 ** 9
|
| 56 |
+
# print("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))
|
| 57 |
+
|
| 58 |
+
# fvcore is used in NTIRE2025_ESR for FLOPs calculation
|
| 59 |
+
input_fake = torch.rand(1, 3, 256, 256).to(device)
|
| 60 |
+
flops = FlopCountAnalysis(model, input_fake).total()
|
| 61 |
+
flops = flops/10**9
|
| 62 |
+
print("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))
|
| 63 |
+
|
| 64 |
+
num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
|
| 65 |
+
num_parameters = num_parameters / 10 ** 6
|
| 66 |
+
print("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
## License and Acknowledgement
|
| 70 |
+
This code repository is release under [MIT License](LICENSE).
|
figs/DSCF_arch.png
ADDED
|
Git LFS Details
|
figs/logo.png
ADDED
|
model_zoo/team00_EFDN.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:71069f10ef234cd123f45ac8e69099f0a1fdc0c16afcfe2189e50071d47ce477
|
| 3 |
+
size 1153119
|
model_zoo/team23_DSCF.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:239e76bc4e4e738491c3963805c0a79f524400acf5fbde34b2cb1a855dd62cb1
|
| 3 |
+
size 535137
|
models/__pycache__/team00_EFDN.cpython-311.pyc
ADDED
|
Binary file (25.8 kB). View file
|
|
|
models/__pycache__/team23_DSCF.cpython-311.pyc
ADDED
|
Binary file (19.9 kB). View file
|
|
|
models/team00_EFDN.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ESA(nn.Module):
    """Enhanced Spatial Attention.

    Builds a per-pixel gate in [0, 1] from a channel-reduced view of the
    input and multiplies the gate back onto the input.
    """

    def __init__(self, n_feats, conv):
        super(ESA, self).__init__()
        reduced = n_feats // 4  # work in a 4x smaller channel space
        self.conv1 = conv(n_feats, reduced, kernel_size=1)
        self.conv_f = conv(reduced, reduced, kernel_size=1)
        self.conv_max = conv(reduced, reduced, kernel_size=3, padding=1)
        self.conv2 = conv(reduced, reduced, kernel_size=3, stride=2, padding=0)
        self.conv3 = conv(reduced, reduced, kernel_size=3, padding=1)
        self.conv3_ = conv(reduced, reduced, kernel_size=3, padding=1)
        self.conv4 = conv(reduced, n_feats, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        squeezed = self.conv1(x)
        # strided conv + max-pool shrink the spatial extent before the 3x3 stack
        down = self.conv2(squeezed)
        pooled = F.max_pool2d(down, kernel_size=7, stride=3)
        body = self.relu(self.conv_max(pooled))
        body = self.relu(self.conv3(body))
        body = self.conv3_(body)
        # restore the original spatial size so the gate matches x
        upsampled = F.interpolate(
            body, (x.size(2), x.size(3)), mode='bilinear', align_corners=False)
        shortcut = self.conv_f(squeezed)
        gate = self.sigmoid(self.conv4(upsampled + shortcut))
        return x * gate
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class conv(nn.Module):
    """1x1 convolution followed by a channel-wise PReLU activation."""

    def __init__(self, n_feats):
        super(conv, self).__init__()
        self.conv1x1 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)
        self.act = nn.PReLU(num_parameters=n_feats)

    def forward(self, x):
        projected = self.conv1x1(x)
        return self.act(projected)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Cell(nn.Module):
    """One EFDN body cell: three stacked convs whose intermediate features are
    half-width projected, concatenated, fused, attended (ESA) and residually
    added to the input.

    The extra constructor arguments (dynamic, deploy, L, with_13) are kept for
    signature compatibility; the deployed cell does not use them.
    """

    def __init__(self, n_feats=48, dynamic = True, deploy = False, L= None, with_13=False):
        super(Cell, self).__init__()

        self.conv1 = conv(n_feats)
        self.conv2 = EDBB_deploy(n_feats,n_feats)
        self.conv3 = EDBB_deploy(n_feats,n_feats)

        self.fuse = nn.Conv2d(n_feats*2, n_feats, 1, 1, 0)

        self.att = ESA(n_feats, nn.Conv2d)

        # one half-width 1x1 projection per tapped feature (x, out1, out2, out3)
        self.branch = nn.ModuleList([nn.Conv2d(n_feats, n_feats//2, 1, 1, 0) for _ in range(4)])

    def forward(self, x):
        stage1 = self.conv1(x)
        stage2 = self.conv2(stage1)
        stage3 = self.conv3(stage2)

        # project every tap to n_feats//2 channels, then fuse back to n_feats
        taps = (x, stage1, stage2, stage3)
        halves = [proj(feat) for proj, feat in zip(self.branch, taps)]
        fused = self.fuse(torch.cat(halves, dim=1))
        fused = self.att(fused)
        fused += x  # residual connection

        return fused
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class EFDN(nn.Module):
    """Efficient Feature Distillation Network for single-image super-resolution.

    Head conv -> 4 Cells with pairwise local fusion -> global residual ->
    pixel-shuffle tail upsampling by `scale`. Output is clamped to [0, 1].
    """

    def __init__(self, scale=4, in_channels=3, n_feats=48, out_channels=3):
        super(EFDN, self).__init__()
        # shallow feature extraction
        self.head = nn.Conv2d(in_channels, n_feats, 3, 1, 1)
        # deep feature body
        self.cells = nn.ModuleList([Cell(n_feats) for _ in range(4)])
        # 1x1 convs merging pairs of cell outputs
        self.local_fuse = nn.ModuleList(
            [nn.Conv2d(n_feats * 2, n_feats, 1, 1, 0) for _ in range(3)])
        # upsampling reconstruction
        self.tail = nn.Sequential(
            nn.Conv2d(n_feats, out_channels * (scale ** 2), 3, 1, 1),
            nn.PixelShuffle(scale),
        )

    def forward(self, x):
        shallow = self.head(x)

        feat1 = self.cells[0](shallow)
        feat2 = self.cells[1](feat1)
        fused2 = self.local_fuse[0](torch.cat([feat1, feat2], dim=1))
        feat3 = self.cells[2](fused2)
        fused3 = self.local_fuse[1](torch.cat([feat2, feat3], dim=1))
        feat4 = self.cells[3](fused3)
        # NOTE(review): fuses feat2 (not feat3) with feat4 — kept as-is to
        # mirror the released pretrained weights.
        fused4 = self.local_fuse[2](torch.cat([feat2, feat4], dim=1))

        deep = fused4 + shallow  # global residual
        return self.tail(deep).clamp(0, 1)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# -------------------------------------------------
|
| 109 |
+
#This part code based on DBB(https://github.com/DingXiaoH/DiverseBranchBlock) and ECB(https://github.com/xindongzhang/ECBSR)
|
| 110 |
+
def multiscale(kernel, target_kernel_size):
    """Zero-pad a conv kernel spatially to target_kernel_size x target_kernel_size.

    Args:
        kernel: weight tensor of shape (out, in, kH, kW).
        target_kernel_size: desired spatial size (assumed >= kH, kW, same parity).

    Returns:
        Padded tensor of shape (out, in, target_kernel_size, target_kernel_size).

    Bug fix: ``F.pad`` takes padding as (left, right, top, bottom) — the WIDTH
    amounts come first. The original passed the height amount first, which
    silently swapped H/W padding for non-square kernels (square kernels, the
    only current callers, were unaffected).
    """
    pad_h = (target_kernel_size - kernel.size(2)) // 2
    pad_w = (target_kernel_size - kernel.size(3)) // 2
    return F.pad(kernel, [pad_w, pad_w, pad_h, pad_h])
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class SeqConv3x3(nn.Module):
    """Sequential two-stage convolution used as a re-parameterizable branch.

    Supported `seq_type` values:
      * 'conv1x1-conv3x3'  — learnable 1x1 followed by learnable 3x3;
      * 'conv1x1-sobelx' / 'conv1x1-sobely' / 'conv1x1-laplacian' —
        learnable 1x1 followed by a fixed depthwise edge filter with a
        learnable per-channel scale and bias.

    `rep_params()` collapses the two stages into one 3x3 kernel + bias.
    """

    # fixed 3x3 depthwise filter templates for the edge-detector variants
    _EDGE_TEMPLATES = {
        'conv1x1-sobelx': [[1.0, 0.0, -1.0],
                           [2.0, 0.0, -2.0],
                           [1.0, 0.0, -1.0]],
        'conv1x1-sobely': [[1.0, 2.0, 1.0],
                           [0.0, 0.0, 0.0],
                           [-1.0, -2.0, -1.0]],
        'conv1x1-laplacian': [[0.0, 1.0, 0.0],
                              [1.0, -4.0, 1.0],
                              [0.0, 1.0, 0.0]],
    }

    def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier):
        super(SeqConv3x3, self).__init__()

        self.type = seq_type
        self.inp_planes = inp_planes
        self.out_planes = out_planes

        if self.type == 'conv1x1-conv3x3':
            self.mid_planes = int(out_planes * depth_multiplier)
            # keep Conv2d-style initialization but store raw weight/bias
            stage0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = stage0.weight
            self.b0 = stage0.bias
            stage1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3)
            self.k1 = stage1.weight
            self.b1 = stage1.bias
        elif self.type in self._EDGE_TEMPLATES:
            stage0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = stage0.weight
            self.b0 = stage0.bias

            # learnable per-channel scale & bias applied to the fixed filter
            # (RNG draw order matches the original: scale first, then bias)
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            self.bias = nn.Parameter(bias.reshape(self.out_planes,))

            # frozen depthwise mask of shape (out_planes, 1, 3, 3)
            template = torch.tensor(self._EDGE_TEMPLATES[self.type], dtype=torch.float32)
            mask = template.expand(self.out_planes, 3, 3).reshape(self.out_planes, 1, 3, 3)
            self.mask = nn.Parameter(data=mask, requires_grad=False)
        else:
            raise ValueError('the type of seqconv is not supported!')

    def _pad_with_bias(self, y0):
        # Pad one pixel on each side and fill the border with the 1x1 bias so
        # the following 3x3 conv sees the bias (not zero) outside the image.
        y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
        border = self.b0.view(1, -1, 1, 1)
        y0[:, :, 0:1, :] = border
        y0[:, :, -1:, :] = border
        y0[:, :, :, 0:1] = border
        y0[:, :, :, -1:] = border
        return y0

    def forward(self, x):
        y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
        y0 = self._pad_with_bias(y0)
        if self.type == 'conv1x1-conv3x3':
            return F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        # depthwise fixed filter, scaled per output channel
        return F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias,
                        stride=1, groups=self.out_planes)

    def rep_params(self):
        """Return (RK, RB): the single 3x3 kernel and bias equivalent to this branch."""
        device = self.k0.get_device()
        if device < 0:
            device = None  # CPU tensors report device -1

        if self.type == 'conv1x1-conv3x3':
            k3x3, b3x3, mid = self.k1, self.b1, self.mid_planes
        else:
            # expand the depthwise filter to a full (out, out, 3, 3) kernel
            scaled = self.scale * self.mask
            k3x3 = torch.zeros((self.out_planes, self.out_planes, 3, 3), device=device)
            for ch in range(self.out_planes):
                k3x3[ch, ch, :, :] = scaled[ch, 0, :, :]
            b3x3, mid = self.bias, self.out_planes

        # merge: conv(conv(x, k0), k3x3) == conv(x, RK) + RB
        RK = F.conv2d(input=k3x3, weight=self.k0.permute(1, 0, 2, 3))
        RB = torch.ones(1, mid, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
        RB = F.conv2d(input=RB, weight=k3x3).view(-1,) + b3x3
        return RK, RB
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class EDBB(nn.Module):
    """Edge-oriented Diverse Branch Block (training-time form).

    Sums several parallel branches (plain 3x3, 1x1, 1x1->sobel-x, 1x1->sobel-y,
    1x1->laplacian, optionally 1x1->3x3 and identity) and can later be collapsed
    into a single 3x3 conv via `switch_to_deploy` (or partially via
    `switch_to_gv`). Based on DBB (DingXiaoH/DiverseBranchBlock) and
    ECB (xindongzhang/ECBSR).
    """

    def __init__(self, inp_planes, out_planes, depth_multiplier=None, act_type='prelu', with_idt = False, deploy=False, with_13=False, gv=False):
        super(EDBB, self).__init__()

        self.deploy = deploy        # True: single merged rep_conv only
        self.act_type = act_type    # activation applied after the branch sum

        self.inp_planes = inp_planes
        self.out_planes = out_planes

        self.gv = gv                # "gv" mode: 1x1 (and 1x3) branches pre-merged

        if depth_multiplier is None:
            self.depth_multiplier = 1.0
        else:
            self.depth_multiplier = depth_multiplier  # For mobilenet, it is better to have 2X internal channels

        if deploy:
            # single equivalent 3x3 conv; no extra branches are created
            self.rep_conv = nn.Conv2d(in_channels=inp_planes, out_channels=out_planes, kernel_size=3, stride=1,
                                      padding=1, bias=True)
        else:
            self.with_13 = with_13
            # identity branch only makes sense when shapes match
            if with_idt and (self.inp_planes == self.out_planes):
                self.with_idt = True
            else:
                self.with_idt = False

            self.rep_conv = nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1)
            self.conv1x1 = nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes, self.out_planes, self.depth_multiplier)
            self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes, self.out_planes, -1)
            self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes, self.out_planes, -1)
            self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes, self.out_planes, -1)

        if self.act_type == 'prelu':
            self.act = nn.PReLU(num_parameters=self.out_planes)
        elif self.act_type == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.act_type == 'rrelu':
            self.act = nn.RReLU(lower=-0.05, upper=0.05)
        elif self.act_type == 'softplus':
            self.act = nn.Softplus()
        elif self.act_type == 'linear':
            pass  # no activation module needed
        else:
            raise ValueError('The type of activation if not support!')

    def forward(self, x):
        if self.deploy:
            # everything already merged into one conv
            y = self.rep_conv(x)
        elif self.gv:
            # 1x1 (and optional 1x3) already folded into rep_conv by switch_to_gv;
            # identity is added explicitly here
            y = self.rep_conv(x) + \
                self.conv1x1_sbx(x) + \
                self.conv1x1_sby(x) + \
                self.conv1x1_lpl(x) + x
        else:
            # full multi-branch training-time sum
            y = self.rep_conv(x) + \
                self.conv1x1(x) + \
                self.conv1x1_sbx(x) + \
                self.conv1x1_sby(x) + \
                self.conv1x1_lpl(x)
            #self.conv1x1_3x3(x) + \
            if self.with_idt:
                y += x
            if self.with_13:
                y += self.conv1x1_3x3(x)

        if self.act_type != 'linear':
            y = self.act(y)
        return y

    def switch_to_gv(self):
        """Fold the 1x1 (and optional 1x1->3x3) branch into rep_conv in place."""
        if self.gv:
            return
        self.gv = True

        K0, B0 = self.rep_conv.weight, self.rep_conv.bias
        K1, B1 = self.conv1x1_3x3.rep_params()
        # pad the 1x1 kernel to 3x3 so it can be summed with rep_conv
        K5, B5 = multiscale(self.conv1x1.weight,3), self.conv1x1.bias
        RK, RB = (K0+K5), (B0+B5)
        if self.with_13:
            RK, RB = RK + K1, RB + B1

        self.rep_conv.weight.data = RK
        self.rep_conv.bias.data = RB

        for para in self.parameters():
            para.detach_()


    def switch_to_deploy(self):
        """Collapse all branches into a single 3x3 rep_conv and drop the extras."""

        if self.deploy:
            return
        self.deploy = True

        # equivalent 3x3 kernel/bias of each branch
        K0, B0 = self.rep_conv.weight, self.rep_conv.bias
        K1, B1 = self.conv1x1_3x3.rep_params()
        K2, B2 = self.conv1x1_sbx.rep_params()
        K3, B3 = self.conv1x1_sby.rep_params()
        K4, B4 = self.conv1x1_lpl.rep_params()
        K5, B5 = multiscale(self.conv1x1.weight,3), self.conv1x1.bias
        if self.gv:
            # the 1x1 branch was already merged by switch_to_gv
            RK, RB = (K0+K2+K3+K4), (B0+B2+B3+B4)
        else:
            RK, RB = (K0+K2+K3+K4+K5), (B0+B2+B3+B4+B5)
        if self.with_13:
            RK, RB = RK + K1, RB + B1
        if self.with_idt:
            # identity as a 3x3 kernel: 1 at the center of the diagonal channels
            device = RK.get_device()
            if device < 0:
                device = None
            K_idt = torch.zeros(self.out_planes, self.out_planes, 3, 3, device=device)
            for i in range(self.out_planes):
                K_idt[i, i, 1, 1] = 1.0
            B_idt = 0.0
            RK, RB = RK + K_idt, RB + B_idt


        self.rep_conv = nn.Conv2d(in_channels=self.inp_planes, out_channels=self.out_planes, kernel_size=3, stride=1,
                                  padding=1, bias=True)
        self.rep_conv.weight.data = RK
        self.rep_conv.bias.data = RB

        for para in self.parameters():
            para.detach_()

        #self.__delattr__('conv3x3')
        self.__delattr__('conv1x1_3x3')
        self.__delattr__('conv1x1')
        self.__delattr__('conv1x1_sbx')
        self.__delattr__('conv1x1_sby')
        self.__delattr__('conv1x1_lpl')
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class EDBB_deploy(nn.Module):
    """Inference-time EDBB: the multi-branch block collapsed into a single
    3x3 convolution followed by a channel-wise PReLU."""

    def __init__(self, inp_planes, out_planes):
        super(EDBB_deploy, self).__init__()
        self.rep_conv = nn.Conv2d(in_channels=inp_planes, out_channels=out_planes,
                                  kernel_size=3, stride=1, padding=1, bias=True)
        self.act = nn.PReLU(num_parameters=out_planes)

    def forward(self, x):
        return self.act(self.rep_conv(x))
|
| 410 |
+
# -------------------------------------------------
|
models/team23_DSCF.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import math
|
| 6 |
+
from typing import Optional, List
|
| 7 |
+
# from IPython import embed
|
| 8 |
+
|
| 9 |
+
class LoRALayer():
    """Book-keeping mixin shared by LoRA-augmented layers.

    Stores the LoRA rank/scaling hyper-parameters, the dropout applied to the
    LoRA path, and whether the low-rank update is currently merged into the
    base weight.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # identity function when no dropout was requested
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0. else (lambda x: x)
        self.merged = False  # True once lora_B @ lora_A is folded into the weight
        self.merge_weights = merge_weights
|
| 27 |
+
|
| 28 |
+
class Lora_Conv2d(nn.Conv2d, LoRALayer):
    """nn.Conv2d with an optional low-rank (LoRA) additive update.

    When ``r > 0`` the effective weight is
    ``weight + (lora_B @ lora_A).view(weight.shape) * scaling``; the base
    weight is frozen and only the low-rank factors train. ``train()/eval()``
    unmerge/merge the update into ``weight.data`` when ``merge_weights`` is set,
    so eval-time inference uses the plain merged conv.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,                   # LoRA rank; 0 disables LoRA entirely
        lora_alpha: int = 1,          # scaling numerator: scaling = lora_alpha / r
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        # square kernels only: the factor shapes below bake in a single int size
        assert type(kernel_size) is int
        # Actual trainable parameters
        if r > 0:
            # lora_B @ lora_A has (out*k) x (in*k) entries == weight.numel() / groups... reshaped below
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Freeze the bias
            # if self.bias is not None:
            #     self.bias.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        """Re-init the conv weight, then A ~ kaiming-uniform and B = 0 so the
        LoRA update starts as an exact zero (no behavior change at step 0)."""
        nn.Conv2d.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):  # True for train and False for eval
        """Switch mode and (un)merge the LoRA update into weight.data.

        Training: subtract the update so the LoRA path is applied explicitly in
        forward(). Eval: add it once so forward() can use the plain conv.
        """
        nn.Conv2d.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        # Unmerged training path: apply the low-rank update on the fly.
        if self.r > 0 and not self.merged:
            return F.conv2d(
                x,
                self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
                self.bias, self.stride, self.padding, self.dilation, self.groups
            )
        # r == 0 or already merged: behave exactly like a plain Conv2d.
        return nn.Conv2d.forward(self, x)
|
| 101 |
+
def _make_pair(value):
|
| 102 |
+
if isinstance(value, int):
|
| 103 |
+
value = (value,) * 2
|
| 104 |
+
return value
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def conv_layer(in_channels,
               out_channels,
               kernel_size,
               bias=True):
    """
    Build a Conv2d whose padding keeps the spatial size for odd kernels.

    `kernel_size` may be an int (expanded to a square kernel) or a pair;
    padding is (k - 1) // 2 per dimension.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size,) * 2
    padding = (int((kernel_size[0] - 1) / 2),
               int((kernel_size[1] - 1) / 2))
    return nn.Conv2d(in_channels,
                     out_channels,
                     kernel_size,
                     padding=padding,
                     bias=bias)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
    """
    Construct an activation module by name.

    Parameters
    ----------
    act_type: str
        one of ['relu', 'lrelu', 'prelu'] (case-insensitive).
    inplace: bool
        whether to use inplace operator (relu/lrelu only).
    neg_slope: float
        slope of negative region for `lrelu`, or init value for `prelu`.
    n_prelu: int
        `num_parameters` for `prelu`.

    Raises
    ------
    NotImplementedError for any other activation name.
    """
    key = act_type.lower()
    if key == 'relu':
        return nn.ReLU(inplace)
    if key == 'lrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if key == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    raise NotImplementedError(
        'activation layer [{:s}] is not found'.format(key))
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def sequential(*args):
    """
    Flatten the given modules into a single nn.Sequential.

    nn.Sequential arguments are unpacked into their children; a single
    non-OrderedDict argument is returned as-is (not wrapped).

    Raises
    ------
    NotImplementedError when the sole argument is an OrderedDict.
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError(
                'sequential does not support OrderedDict input.')
        return args[0]
    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def pixelshuffle_block(in_channels,
                       out_channels,
                       upscale_factor=2,
                       kernel_size=3):
    """
    Upsample features by `upscale_factor` via channel expansion + PixelShuffle.
    """
    expand = conv_layer(in_channels,
                        out_channels * upscale_factor ** 2,
                        kernel_size)
    return sequential(expand, nn.PixelShuffle(upscale_factor))
|
| 189 |
+
|
| 190 |
+
class Conv3XC(nn.Module):
    """Inference-time 3x3 convolution (collapsed branch of the Conv3XC design).

    Only `eval_conv` is used here; `gain1`/`gain2` are accepted for interface
    compatibility but unused in this inference-only variant.
    """

    def __init__(self, c_in, c_out, gain1=1, gain2=0, s=1, bias=True, relu=False):
        super(Conv3XC, self).__init__()
        # Kept for checkpoint/interface compatibility with the training-time
        # multi-branch implementation.
        self.weight_concat = None
        self.bias_concat = None
        self.update_params_flag = False
        self.stride = s
        self.has_relu = relu
        self.eval_conv = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3,
                                   padding=1, stride=s, bias=bias)

    def forward(self, x):
        y = self.eval_conv(x)
        # Optional LeakyReLU with the module's fixed 0.05 slope.
        return F.leaky_relu(y, negative_slope=0.05) if self.has_relu else y
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class SPAB(nn.Module):
    """Swift parameter-free attention block.

    Three 3x3 convs with SiLU between them, gated by a parameter-free
    attention map (centered sigmoid of the third conv's output) applied to
    the residual sum. Returns the gated output plus the three raw conv
    outputs for feature distillation.
    """

    def __init__(self,
                 in_channels,
                 mid_channels=None,
                 out_channels=None,
                 bias=False):
        super(SPAB, self).__init__()
        mid_channels = in_channels if mid_channels is None else mid_channels
        out_channels = in_channels if out_channels is None else out_channels

        self.in_channels = in_channels
        self.c1_r = Conv3XC(in_channels, mid_channels, gain1=2, s=1)
        self.c2_r = Conv3XC(mid_channels, mid_channels, gain1=2, s=1)
        self.c3_r = Conv3XC(mid_channels, out_channels, gain1=2, s=1)
        self.act1 = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        out1 = self.c1_r(x)
        out2 = self.c2_r(self.act1(out1))
        out3 = self.c3_r(self.act1(out2))

        # Parameter-free attention: sigmoid centered at zero, in (-0.5, 0.5).
        sim_att = torch.sigmoid(out3) - 0.5
        out = (out3 + x) * sim_att

        return out, out1, out2, out3
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class DSCF(nn.Module):
    """DSCF efficient super-resolution network (SPAN-style design:
    "Swift Parameter-free Attention Network for Efficient Super-Resolution").

    Pipeline: stem Conv3XC -> six SPAB blocks -> 1x1 fusion of selected
    intermediate features -> Conv3XC -> pixel-shuffle upsampler.

    Parameters
    ----------
    num_in_ch : int
        Input channel count. NOTE(review): the mean buffer is shaped
        (1, 3, 1, 1), so forward effectively requires 3 input channels.
    num_out_ch : int
        Output channel count.
    feature_channels : int
        Width of the internal feature maps.
    upscale : int
        Super-resolution scale factor.
    bias : bool
        Whether SPAB blocks use biased convolutions.
    img_range : float
        Scale applied to the mean-subtracted input.
    rgb_mean : tuple of float
        Per-channel mean subtracted from the input.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 feature_channels=26,
                 upscale=4,
                 bias=True,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)
                 ):
        super(DSCF, self).__init__()

        in_channels = num_in_ch
        out_channels = num_out_ch
        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_1 = Conv3XC(in_channels, feature_channels, gain1=2, s=1)
        self.block_1 = SPAB(feature_channels, bias=bias)
        self.block_2 = SPAB(feature_channels, bias=bias)
        self.block_3 = SPAB(feature_channels, bias=bias)
        self.block_4 = SPAB(feature_channels, bias=bias)
        self.block_5 = SPAB(feature_channels, bias=bias)
        self.block_6 = SPAB(feature_channels, bias=bias)

        # Fuses stem output, final conv output, block_1 output and a block_6
        # intermediate (4 * feature_channels in total).
        self.conv_cat = conv_layer(feature_channels * 4, feature_channels, kernel_size=1, bias=True)
        self.conv_2 = Conv3XC(feature_channels, feature_channels, gain1=2, s=1)

        self.upsampler = pixelshuffle_block(feature_channels, out_channels, upscale_factor=upscale)

        # Warm-up forward pass (presumably to trigger lazy CUDA/cuDNN setup
        # before the first timed inference — confirm intent with the authors).
        # Bug fix: the original called .eval().cuda() and ran a CUDA tensor
        # unconditionally, crashing on CPU-only machines; the warm-up is now
        # skipped when CUDA is unavailable and wrapped in no_grad to avoid
        # building a throwaway autograd graph.
        if torch.cuda.is_available():
            self.eval().cuda()
            with torch.no_grad():
                self(torch.randn(1, 3, 256, 256).cuda())

    def forward(self, x, return_features=False):
        """Super-resolve `x`.

        Returns the SR tensor; when `return_features` is True, also returns a
        list of the nine intermediate conv outputs of blocks 1-3.
        Bug fix: the original referenced an undefined `features` name in the
        `return_features` branch (the appends were commented out), raising
        NameError.
        """
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        out_feature = self.conv_1(x)

        out_b1, out_b1_1, out_b1_2, out_b1_3 = self.block_1(out_feature)
        out_b2, out_b2_1, out_b2_2, out_b2_3 = self.block_2(out_b1)
        out_b3, out_b3_1, out_b3_2, out_b3_3 = self.block_3(out_b2)

        out_b4, _, _, _ = self.block_4(out_b3)
        out_b5, _, _, _ = self.block_5(out_b4)
        out_b6, out_b5_2, _, _ = self.block_6(out_b5)

        out_b6 = self.conv_2(out_b6)
        out = self.conv_cat(torch.cat([out_feature, out_b6, out_b1, out_b5_2], 1))
        output = self.upsampler(out)

        if return_features:
            features = [out_b1_1, out_b1_2, out_b1_3,
                        out_b2_1, out_b2_2, out_b2_3,
                        out_b3_1, out_b3_2, out_b3_3]
            return output, features  # Return output and intermediate features
        return output
|
| 466 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.0.0
|
| 2 |
+
anyio==4.0.0
|
| 3 |
+
appdirs==1.4.4
|
| 4 |
+
beartype==0.16.4
|
| 5 |
+
blessed==1.20.0
|
| 6 |
+
brotlipy==0.7.0
|
| 7 |
+
cachetools==5.3.1
|
| 8 |
+
certifi==2023.7.22
|
| 9 |
+
cffi==1.15.1
|
| 10 |
+
charset-normalizer==2.0.4
|
| 11 |
+
click==8.1.7
|
| 12 |
+
clip==0.2.0
|
| 13 |
+
cmake==3.27.7
|
| 14 |
+
contourpy==1.1.1
|
| 15 |
+
cryptography==41.0.3
|
| 16 |
+
cycler==0.12.1
|
| 17 |
+
docker-pycreds==0.4.0
|
| 18 |
+
dpcpp-cpp-rt==2024.0.2
|
| 19 |
+
einops==0.7.0
|
| 20 |
+
ema-pytorch==0.3.1
|
| 21 |
+
exceptiongroup==1.1.3
|
| 22 |
+
filelock==3.12.4
|
| 23 |
+
fonttools==4.43.1
|
| 24 |
+
fsspec==2023.9.2
|
| 25 |
+
ftfy==6.1.1
|
| 26 |
+
fvcore==0.1.5.post20221221
|
| 27 |
+
gitdb==4.0.10
|
| 28 |
+
GitPython==3.1.40
|
| 29 |
+
google-auth==2.23.3
|
| 30 |
+
google-auth-oauthlib==1.0.0
|
| 31 |
+
gpustat==1.1.1
|
| 32 |
+
grpcio==1.59.0
|
| 33 |
+
h11==0.14.0
|
| 34 |
+
h5py==3.10.0
|
| 35 |
+
huggingface-hub==0.18.0
|
| 36 |
+
idna==3.4
|
| 37 |
+
imageio==2.31.5
|
| 38 |
+
importlib-metadata==6.8.0
|
| 39 |
+
importlib-resources==6.1.0
|
| 40 |
+
intel-cmplr-lib-rt==2024.0.2
|
| 41 |
+
intel-cmplr-lic-rt==2024.0.2
|
| 42 |
+
intel-opencl-rt==2024.0.2
|
| 43 |
+
intel-openmp==2024.0.2
|
| 44 |
+
iopath==0.1.10
|
| 45 |
+
itsdangerous==2.1.2
|
| 46 |
+
kiwisolver==1.4.5
|
| 47 |
+
lit==17.0.3
|
| 48 |
+
lpips==0.1.4
|
| 49 |
+
Markdown==3.5
|
| 50 |
+
markdown-it-py==2.2.0
|
| 51 |
+
MarkupSafe==2.1.3
|
| 52 |
+
matplotlib==3.7.3
|
| 53 |
+
mdurl==0.1.2
|
| 54 |
+
mkl==2024.0.0
|
| 55 |
+
mkl-fft==1.3.6
|
| 56 |
+
mkl-random==1.2.2
|
| 57 |
+
mkl-service==2.4.0
|
| 58 |
+
mpmath==1.3.0
|
| 59 |
+
multidict==6.0.4
|
| 60 |
+
networkx==3.1
|
| 61 |
+
numpy==1.24.3
|
| 62 |
+
nvidia-ml-py==12.535.108
|
| 63 |
+
oauthlib==3.2.2
|
| 64 |
+
opencv-python==4.8.1.78
|
| 65 |
+
ordered-set==4.1.0
|
| 66 |
+
orjson==3.8.9
|
| 67 |
+
packaging==23.1
|
| 68 |
+
pandas==1.5.3
|
| 69 |
+
pathtools==0.1.2
|
| 70 |
+
Pillow==9.4.0
|
| 71 |
+
portalocker==2.8.2
|
| 72 |
+
protobuf==4.24.4
|
| 73 |
+
psutil==5.9.6
|
| 74 |
+
py-cpuinfo==9.0.0
|
| 75 |
+
pyasn1==0.5.0
|
| 76 |
+
pyasn1-modules==0.3.0
|
| 77 |
+
pycparser==2.21
|
| 78 |
+
pydantic==1.10.7
|
| 79 |
+
Pygments==2.16.1
|
| 80 |
+
PyJWT==2.6.0
|
| 81 |
+
pyOpenSSL==23.2.0
|
| 82 |
+
pyparsing==3.0.9
|
| 83 |
+
PySocks==1.7.1
|
| 84 |
+
python-dateutil==2.8.2
|
| 85 |
+
pytorch-fid==0.3.0
|
| 86 |
+
pytz==2023.3.post1
|
| 87 |
+
PyWavelets==1.4.1
|
| 88 |
+
PyYAML==6.0
|
| 89 |
+
readchar==4.0.5
|
| 90 |
+
regex==2023.10.3
|
| 91 |
+
requests==2.28.2
|
| 92 |
+
requests-oauthlib==1.3.1
|
| 93 |
+
rfc3986==1.5.0
|
| 94 |
+
rich==13.3.3
|
| 95 |
+
rsa==4.9
|
| 96 |
+
scikit-image==0.19.3
|
| 97 |
+
scikit-video==1.1.11
|
| 98 |
+
scipy==1.10.1
|
| 99 |
+
seaborn==0.12.2
|
| 100 |
+
sentry-sdk==1.14.0
|
| 101 |
+
setproctitle==1.3.2
|
| 102 |
+
six==1.16.0
|
| 103 |
+
smmap==5.0.0
|
| 104 |
+
sniffio==1.3.0
|
| 105 |
+
soupsieve==2.4
|
| 106 |
+
starlette==0.22.0
|
| 107 |
+
starsessions==1.3.0
|
| 108 |
+
sympy==1.11.1
|
| 109 |
+
tabulate==0.9.0
|
| 110 |
+
tbb==2021.11.0
|
| 111 |
+
tensorboard==2.14.0
|
| 112 |
+
tensorboard-data-server==0.7.1
|
| 113 |
+
termcolor==2.4.0
|
| 114 |
+
tifffile==2023.1.23.1
|
| 115 |
+
timm==0.6.12
|
| 116 |
+
torchmetrics==0.11.4
|
| 117 |
+
torchsummary==1.5.1
|
| 118 |
+
tqdm==4.66.1
|
| 119 |
+
triton==2.0.0
|
| 120 |
+
tsnecuda==3.0.1
|
| 121 |
+
typing_extensions==4.5.0
|
| 122 |
+
ujson==5.7.0
|
| 123 |
+
urllib3==1.26.14
|
| 124 |
+
uvicorn==0.21.1
|
| 125 |
+
uvloop==0.17.0
|
| 126 |
+
wandb==0.13.9
|
| 127 |
+
warmup-scheduler==0.3
|
| 128 |
+
watchfiles==0.19.0
|
| 129 |
+
wcwidth==0.2.8
|
| 130 |
+
websocket-client==1.5.1
|
| 131 |
+
websockets==11.0.1
|
| 132 |
+
Werkzeug==3.0.0
|
| 133 |
+
yacs==0.1.8
|
| 134 |
+
yapf==0.32.0
|
| 135 |
+
yarl==1.8.2
|
| 136 |
+
zipp==3.17.0
|
results.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Model Val PSNR Val Time [ms] Params [M] FLOPs [G] Acts [M] Mem [M] Conv
|
| 2 |
+
00_EFDN_baseline 26.93 34.31 0.276 16.70 111.12 662.89 65
|
| 3 |
+
23_DSCF 26.92 8.37 0.131 8.54 38.93 728.41 22
|
run.sh
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- Evaluation on LSDIR_DIV2K_valid datasets for One Method: ---
|
| 2 |
+
CUDA_VISIBLE_DEVICES=0 python test_demo.py \
|
| 3 |
+
--data_dir ../ \
|
| 4 |
+
--save_dir ../results \
|
| 5 |
+
--model_id 23
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# --- When only LSDIR_DIV2K_test datasets are included (For Organizer) ---
|
| 9 |
+
# CUDA_VISIBLE_DEVICES=0 python test_demo.py \
|
| 10 |
+
# --data_dir ../ \
|
| 11 |
+
# --save_dir ../results \
|
| 12 |
+
# --include_test \
|
| 13 |
+
# --model_id 0
|
| 14 |
+
|
| 15 |
+
# --- Test all the methods (For Organizer) ---
|
| 16 |
+
#!/bin/bash
|
| 17 |
+
# DATA_DIR="/Your/Validate/Datasets/Path"
|
| 18 |
+
# SAVE_DIR="./results"
|
| 19 |
+
# MODEL_IDS=(
|
| 20 |
+
# 0 1 3 4 5 7 10 11 13 15
|
| 21 |
+
# 16 17 18 19 21 23 25 26
|
| 22 |
+
# 28 29 30 31 33 34 38 39
|
| 23 |
+
# 41 42 43 44 45 46 48
|
| 24 |
+
# )
|
| 25 |
+
|
| 26 |
+
# for model_id in "${MODEL_IDS[@]}"
|
| 27 |
+
# do
|
| 28 |
+
# CUDA_VISIBLE_DEVICES=0 python test_demo.py --data_dir "$DATA_DIR" --save_dir "$SAVE_DIR" --include_test --model_id "$model_id"
|
| 29 |
+
# done
|
test_demo.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
import argparse
|
| 5 |
+
import json
|
| 6 |
+
import glob
|
| 7 |
+
|
| 8 |
+
from pprint import pprint
|
| 9 |
+
from fvcore.nn import FlopCountAnalysis
|
| 10 |
+
from utils.model_summary import get_model_activation, get_model_flops
|
| 11 |
+
from utils import utils_logger
|
| 12 |
+
from utils import utils_image as util
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def select_model(args, device):
    """Instantiate the model selected by `args.model_id` and move it to `device`.

    Model ID is assigned according to the order of the submissions.
    Different networks are trained with input range of either [0,1] or
    [0,255]; the range is determined manually per model.

    Returns
    -------
    (model, name, data_range, tile) — model in eval mode with gradients
    disabled; `tile` is None (whole-image inference).
    """
    model_id = args.model_id
    if model_id == 0:
        # Baseline: The 1st Place of the `Overall Performance` of the NTIRE 2023 Efficient SR Challenge
        # Edge-enhanced Feature Distillation Network for Efficient Super-Resolution
        # arXiv: https://arxiv.org/pdf/2204.08759
        # Original Code: https://github.com/icandle/EFDN
        # Ckpts: EFDN_gv.pth
        from models.team00_EFDN import EFDN
        name, data_range = f"{model_id:02}_EFDN_baseline", 1.0
        model_path = os.path.join('model_zoo', 'team00_EFDN.pth')
        model = EFDN()
        # Bug fix: map_location ensures GPU-saved checkpoints load on
        # CPU-only hosts instead of raising a CUDA deserialization error.
        model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    elif model_id == 23:
        from models.team23_DSCF import DSCF

        name, data_range = f"{model_id:02}_DSCF", 1.0
        model_path = os.path.join('model_zoo', 'team23_DSCF.pth')
        model = DSCF(3, 3, feature_channels=26, upscale=4)
        state_dict = torch.load(model_path, map_location=device)
        # strict=False: checkpoint may omit buffers such as the RGB mean.
        model.load_state_dict(state_dict, strict=False)
    else:
        raise NotImplementedError(f"Model {model_id} is not implemented.")

    model.eval()
    tile = None
    # Freeze all parameters — inference only.
    for v in model.parameters():
        v.requires_grad = False
    model = model.to(device)
    return model, name, data_range, tile
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def select_dataset(data_dir, mode):
    """Return sorted (lr_path, hr_path) pairs for the requested split.

    LR paths are derived from HR paths by swapping the `_HR` directory suffix
    for `_LR` and appending the `x4` scale tag to the file name.

    Raises
    ------
    NotImplementedError for modes other than 'test' and 'valid'.
    """
    # inference on the DIV2K_LSDIR_test / DIV2K_LSDIR_valid sets
    if mode == "test":
        hr_pattern = os.path.join(data_dir, "DIV2K_LSDIR_test_HR/*.png")
    elif mode == "valid":
        hr_pattern = os.path.join(data_dir, "DIV2K_LSDIR_valid_HR/*.png")
    else:
        raise NotImplementedError(f"{mode} is not implemented in select_dataset")

    return [
        (hr.replace("_HR", "_LR").replace(".png", "x4.png"), hr)
        for hr in sorted(glob.glob(hr_pattern))
    ]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def forward(img_lq, model, tile=None, tile_overlap=32, scale=4):
    """Run `model` on `img_lq`, whole-image or in overlapping tiles.

    With tiling, overlapping SR patches are accumulated and averaged by the
    per-pixel coverage count.
    """
    if tile is None:
        # test the image as a whole
        return model(img_lq)

    # test the image tile by tile
    b, c, h, w = img_lq.size()
    tile = min(tile, h, w)
    sf = scale
    stride = tile - tile_overlap
    # Tile origins; the final origin is clamped so the last tile ends exactly
    # at the image border.
    h_starts = list(range(0, h - tile, stride)) + [h - tile]
    w_starts = list(range(0, w - tile, stride)) + [w - tile]

    accum = torch.zeros(b, c, h * sf, w * sf).type_as(img_lq)
    coverage = torch.zeros_like(accum)

    for top in h_starts:
        for left in w_starts:
            patch = img_lq[..., top:top + tile, left:left + tile]
            sr_patch = model(patch)
            region = (..., slice(top * sf, (top + tile) * sf),
                      slice(left * sf, (left + tile) * sf))
            accum[region].add_(sr_patch)
            coverage[region].add_(torch.ones_like(sr_patch))

    return accum.div_(coverage)
|
| 103 |
+
|
| 104 |
+
def run(model, model_name, data_range, tile, logger, device, args, mode="test"):
    """Evaluate `model` on one split, logging per-image and aggregate metrics.

    Reads (LR, HR) path pairs for `mode`, super-resolves each LR image,
    measures per-image runtime with CUDA events, and computes PSNR (and SSIM
    when args.ssim is set) against the mod-cropped HR image.

    Returns a dict keyed "{mode}_runtime", "{mode}_psnr", optional
    "{mode}_ssim", plus "{mode}_memory" (peak MiB) and per-metric averages.
    """
    sf = 4
    border = sf  # PSNR/SSIM ignore a `scale`-wide border, per SR convention
    results = dict()
    results[f"{mode}_runtime"] = []
    results[f"{mode}_psnr"] = []
    if args.ssim:
        results[f"{mode}_ssim"] = []

    # --------------------------------
    # dataset path
    # --------------------------------
    data_path = select_dataset(args.data_dir, mode)
    save_path = os.path.join(args.save_dir, model_name, mode)
    util.mkdir(save_path)

    # CUDA events give GPU-side timing; elapsed_time is in milliseconds.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for i, (img_lr, img_hr) in enumerate(data_path):

        # --------------------------------
        # (1) img_lr: load LR image and move to device
        # --------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img_hr))
        img_lr = util.imread_uint(img_lr, n_channels=3)
        img_lr = util.uint2tensor4(img_lr, data_range)
        img_lr = img_lr.to(device)

        # --------------------------------
        # (2) img_sr: timed super-resolution forward pass
        # --------------------------------
        start.record()
        img_sr = forward(img_lr, model, tile)
        end.record()
        # Events are asynchronous; synchronize before reading elapsed time.
        torch.cuda.synchronize()
        results[f"{mode}_runtime"].append(start.elapsed_time(end))  # milliseconds
        img_sr = util.tensor2uint(img_sr, data_range)

        # --------------------------------
        # (3) img_hr: load reference and crop to a multiple of the scale
        # --------------------------------
        img_hr = util.imread_uint(img_hr, n_channels=3)
        img_hr = img_hr.squeeze()
        img_hr = util.modcrop(img_hr, sf)

        # --------------------------------
        # PSNR and SSIM
        # --------------------------------
        psnr = util.calculate_psnr(img_sr, img_hr, border=border)
        results[f"{mode}_psnr"].append(psnr)

        if args.ssim:
            ssim = util.calculate_ssim(img_sr, img_hr, border=border)
            results[f"{mode}_ssim"].append(ssim)
            logger.info("{:s} - PSNR: {:.2f} dB; SSIM: {:.4f}.".format(img_name + ext, psnr, ssim))
        else:
            logger.info("{:s} - PSNR: {:.2f} dB".format(img_name + ext, psnr))

        # --- Save Restored Images (disabled) ---
        # util.imsave(img_sr, os.path.join(save_path, img_name+ext))

    # Aggregate metrics over the whole split.
    results[f"{mode}_memory"] = torch.cuda.max_memory_allocated(torch.cuda.current_device()) / 1024 ** 2
    results[f"{mode}_ave_runtime"] = sum(results[f"{mode}_runtime"]) / len(results[f"{mode}_runtime"]) #/ 1000.0
    results[f"{mode}_ave_psnr"] = sum(results[f"{mode}_psnr"]) / len(results[f"{mode}_psnr"])
    if args.ssim:
        results[f"{mode}_ave_ssim"] = sum(results[f"{mode}_ssim"]) / len(results[f"{mode}_ssim"])
    logger.info("{:>16s} : {:<.3f} [M]".format("Max Memory", results[f"{mode}_memory"]))  # peak GPU memory
    logger.info("------> Average runtime of ({}) is : {:.6f} milliseconds".format("test" if mode == "test" else "valid", results[f"{mode}_ave_runtime"]))
    logger.info("------> Average PSNR of ({}) is : {:.6f} dB".format("test" if mode == "test" else "valid", results[f"{mode}_ave_psnr"]))

    return results
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def main(args):
    """Benchmark the selected efficient-SR model: run inference on the
    DIV2K_LSDIR validation (and optionally test) split, measure
    PSNR/runtime/memory/FLOPs/params, then persist everything to
    results.json and a tab-separated results.txt summary table.
    """

    utils_logger.logger_info("NTIRE2025-EfficientSR", log_path="NTIRE2025-EfficientSR.log")
    logger = logging.getLogger("NTIRE2025-EfficientSR")

    # --------------------------------
    # basic settings
    # --------------------------------
    torch.cuda.current_device()
    torch.cuda.empty_cache()
    # cuDNN autotuning off so per-image runtimes stay comparable across models
    torch.backends.cudnn.benchmark = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load any previous results so repeated runs accumulate into one JSON file.
    json_dir = os.path.join(os.getcwd(), "results.json")
    if not os.path.exists(json_dir):
        results = dict()
    else:
        with open(json_dir, "r") as f:
            results = json.load(f)

    # --------------------------------
    # load model
    # --------------------------------
    model, model_name, data_range, tile = select_model(args, device)
    logger.info(model_name)

    # if model not in results:
    if True:  # deliberate always-run: the cached-results check above is disabled
        # --------------------------------
        # restore image
        # --------------------------------

        # inference on the DIV2K_LSDIR_valid set
        valid_results = run(model, model_name, data_range, tile, logger, device, args, mode="valid")
        # record PSNR, runtime
        results[model_name] = valid_results

        # inference conducted by the Organizer on DIV2K_LSDIR_test set
        if args.include_test:
            test_results = run(model, model_name, data_range, tile, logger, device, args, mode="test")
            results[model_name].update(test_results)

        input_dim = (3, 256, 256)  # set the input dimension
        activations, num_conv = get_model_activation(model, input_dim)
        activations = activations/10**6  # report in millions of elements
        logger.info("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
        logger.info("{:>16s} : {:<d}".format("#Conv2d", num_conv))

        # The FLOPs calculation in previous NTIRE_ESR Challenge
        # flops = get_model_flops(model, input_dim, False)
        # flops = flops/10**9
        # logger.info("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))

        # fvcore is used in NTIRE2025_ESR for FLOPs calculation
        input_fake = torch.rand(1, 3, 256, 256).to(device)
        flops = FlopCountAnalysis(model, input_fake).total()
        flops = flops/10**9  # report in GFLOPs
        logger.info("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))

        num_parameters = sum(map(lambda x: x.numel(), model.parameters()))
        num_parameters = num_parameters/10**6  # report in millions of parameters
        logger.info("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))
        results[model_name].update({"activations": activations, "num_conv": num_conv, "flops": flops, "num_parameters": num_parameters})

        with open(json_dir, "w") as f:
            json.dump(results, f)
    # Build the tab-separated summary table over all models recorded so far.
    # NOTE(review): indentation reconstructed — this section is assumed to sit
    # at function level (outside the `if True:` block), matching the upstream
    # NTIRE test_demo.py; confirm against the original file.
    if args.include_test:
        fmt = "{:20s}\t{:10s}\t{:10s}\t{:14s}\t{:14s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Test PSNR", "Val Time [ms]", "Test Time [ms]", "Ave Time [ms]",
                       "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    else:
        fmt = "{:20s}\t{:10s}\t{:14s}\t{:10s}\t{:10s}\t{:8s}\t{:8s}\t{:8s}\n"
        s = fmt.format("Model", "Val PSNR", "Val Time [ms]", "Params [M]", "FLOPs [G]", "Acts [M]", "Mem [M]", "Conv")
    for k, v in results.items():
        val_psnr = f"{v['valid_ave_psnr']:2.2f}"
        val_time = f"{v['valid_ave_runtime']:3.2f}"
        mem = f"{v['valid_memory']:2.2f}"

        num_param = f"{v['num_parameters']:2.3f}"
        flops = f"{v['flops']:2.2f}"
        acts = f"{v['activations']:2.2f}"
        conv = f"{v['num_conv']:4d}"
        if args.include_test:
            # from IPython import embed; embed()
            test_psnr = f"{v['test_ave_psnr']:2.2f}"
            test_time = f"{v['test_ave_runtime']:3.2f}"
            ave_time = f"{(v['valid_ave_runtime'] + v['test_ave_runtime']) / 2:3.2f}"
            s += fmt.format(k, val_psnr, test_psnr, val_time, test_time, ave_time, num_param, flops, acts, mem, conv)
        else:
            s += fmt.format(k, val_psnr, val_time, num_param, flops, acts, mem, conv)
    with open(os.path.join(os.getcwd(), 'results.txt'), "w") as f:
        f.write(s)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# Command-line entry point: parse benchmark options and launch main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser("NTIRE2025-EfficientSR")
    # Root directory containing the DIV2K_LSDIR validation/test image folders.
    parser.add_argument("--data_dir", default="../", type=str)
    # Destination for restored images (actual saving is commented out in run()).
    parser.add_argument("--save_dir", default="../results", type=str)
    # Selects which model-zoo entry to benchmark — see select_model for the mapping.
    parser.add_argument("--model_id", default=0, type=int)
    parser.add_argument("--include_test", action="store_true", help="Inference on the `DIV2K_LSDIR_test` set")
    parser.add_argument("--ssim", action="store_true", help="Calculate SSIM")

    args = parser.parse_args()
    pprint(args)

    main(args)
|
utils/__pycache__/model_summary.cpython-311.pyc
ADDED
|
Binary file (22.3 kB). View file
|
|
|
utils/__pycache__/utils_image.cpython-311.pyc
ADDED
|
Binary file (38.5 kB). View file
|
|
|
utils/__pycache__/utils_logger.cpython-311.pyc
ADDED
|
Binary file (2.87 kB). View file
|
|
|
utils/model_summary.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
'''
|
| 6 |
+
---- 1) FLOPs: floating point operations
|
| 7 |
+
---- 2) #Activations: the number of elements of all ‘Conv2d’ outputs
|
| 8 |
+
---- 3) #Conv2d: the number of ‘Conv2d’ layers
|
| 9 |
+
'''
|
| 10 |
+
|
| 11 |
+
def get_model_flops(model, input_res, print_per_layer_stat=True,
                    input_constructor=None):
    """Instrument `model` with flops-counting hooks, run one dummy forward
    pass, and return the total flops for an input of size `input_res`.

    Args:
        model: nn.Module to analyze (hooks are removed again before return).
        input_res: tuple (C, H, W); a batch dimension of 1 is prepended.
        print_per_layer_stat: if True, print the per-module flops breakdown.
        input_constructor: optional callable mapping input_res to a kwargs
            dict for model(**kwargs) instead of a plain tensor.

    Returns:
        Total flops (int) accumulated over all supported submodules.
    """
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    if input_constructor:
        input = input_constructor(input_res)
        _ = flops_model(**input)
    else:
        # Allocate the dummy batch on the same device as the model parameters
        # so CUDA-resident models do not fail the forward pass.
        device = list(flops_model.parameters())[-1].device
        batch = torch.FloatTensor(1, *input_res).to(device)
        _ = flops_model(batch)

    if print_per_layer_stat:
        print_model_with_flops(flops_model)
    flops_count = flops_model.compute_average_flops_cost()
    flops_model.stop_flops_count()

    return flops_count
|
| 31 |
+
|
| 32 |
+
def get_model_activation(model, input_res, input_constructor=None):
    """Run one dummy forward pass and return (#activation elements, #conv layers).

    "#Activations" is the total number of elements produced by Conv2d /
    ConvTranspose2d outputs (see conv_activation_counter_hook).

    Args:
        model: nn.Module to analyze (hooks are removed again before return).
        input_res: tuple (C, H, W); a batch dimension of 1 is prepended.
        input_constructor: optional callable mapping input_res to a kwargs
            dict for model(**kwargs) instead of a plain tensor.

    Returns:
        (activation_count, num_conv) tuple of ints.
    """
    assert type(input_res) is tuple, 'Please provide the size of the input image.'
    assert len(input_res) >= 3, 'Input image should have 3 dimensions.'
    activation_model = add_activation_counting_methods(model)
    activation_model.eval().start_activation_count()
    if input_constructor:
        input = input_constructor(input_res)
        _ = activation_model(**input)
    else:
        # Allocate the dummy batch on the model's device (CPU or CUDA).
        device = list(activation_model.parameters())[-1].device
        batch = torch.FloatTensor(1, *input_res).to(device)
        _ = activation_model(batch)

    activation_count, num_conv = activation_model.compute_average_activation_cost()
    activation_model.stop_activation_count()

    return activation_count, num_conv
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_model_complexity_info(model, input_res, print_per_layer_stat=True, as_strings=True,
                              input_constructor=None):
    """Compute (flops, params) for `model` fed one dummy input of size `input_res`.

    Args:
        model: nn.Module to analyze (hooks are removed again before return).
        input_res: tuple (C, H, W); a batch dimension of 1 is prepended.
        print_per_layer_stat: if True, print the per-module flops breakdown.
        as_strings: if True, return human-readable strings ('X GMac', 'Y M').
        input_constructor: optional callable mapping input_res to a kwargs
            dict for model(**kwargs) instead of a plain tensor.

    Returns:
        (flops, params) as numbers, or as formatted strings when `as_strings`.
    """
    assert type(input_res) is tuple
    assert len(input_res) >= 3
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    if input_constructor:
        input = input_constructor(input_res)
        _ = flops_model(**input)
    else:
        # Fix: build the dummy batch on the model's device. The original
        # always allocated on CPU, which raised a device-mismatch error for
        # CUDA-resident models; this now mirrors get_model_flops.
        device = list(flops_model.parameters())[-1].device
        batch = torch.FloatTensor(1, *input_res).to(device)
        _ = flops_model(batch)

    if print_per_layer_stat:
        print_model_with_flops(flops_model)
    flops_count = flops_model.compute_average_flops_cost()
    params_count = get_model_parameters_number(flops_model)
    flops_model.stop_flops_count()

    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)

    return flops_count, params_count
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def flops_to_string(flops, units='GMac', precision=2):
    """Format a raw MAC count as a human-readable string.

    Args:
        flops: raw multiply-accumulate count.
        units: 'GMac' | 'MMac' | 'KMac' for a fixed unit, None to auto-scale,
            or any other value to return the raw count as '<flops> Mac'.
        precision: decimal places used when rounding.
    """
    unit_exponents = (('GMac', 9), ('MMac', 6), ('KMac', 3))

    if units is None:
        # Auto-select the largest unit whose scaled value is at least 1.
        for name, exp in unit_exponents:
            if flops // 10 ** exp > 0:
                return str(round(flops / 10. ** exp, precision)) + ' ' + name
        return str(flops) + ' Mac'

    for name, exp in unit_exponents:
        if units == name:
            return str(round(flops / 10. ** exp, precision)) + ' ' + units
    # Unrecognized unit: fall back to the raw count.
    return str(flops) + ' Mac'
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def params_to_string(params_num):
    """Format a raw parameter count as 'X M', 'Y k', or the plain number.

    Behavior is unchanged; the second branch now uses an explicit `> 0`
    comparison for consistency with the first (the bare floor-division
    truthiness test was equivalent but inconsistent in style).
    """
    if params_num // 10 ** 6 > 0:
        return str(round(params_num / 10 ** 6, 2)) + ' M'
    elif params_num // 10 ** 3 > 0:
        return str(round(params_num / 10 ** 3, 2)) + ' k'
    else:
        return str(params_num)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def print_model_with_flops(model, units='GMac', precision=3):
    """Print the model tree with each submodule's flops share appended to its repr.

    Temporarily monkey-patches `extra_repr` on every submodule (bound via
    __get__), prints the model, then restores the original representations.
    """
    total_flops = model.compute_average_flops_cost()

    def accumulate_flops(self):
        # Supported leaf modules report their own counter; containers recurse
        # over children and sum.
        if is_supported_instance(self):
            # NOTE(review): model.__batch_counter__ is never assigned anywhere
            # in this file — this division would raise AttributeError here;
            # confirm it is set elsewhere before relying on this function.
            return self.__flops__ / model.__batch_counter__
        else:
            sum = 0
            for m in self.children():
                sum += m.accumulate_flops()
            return sum

    def flops_repr(self):
        # Per-module line: "<flops>, <percent> MACs, <original extra_repr>"
        accumulated_flops_cost = self.accumulate_flops()
        return ', '.join([flops_to_string(accumulated_flops_cost, units=units, precision=precision),
                          '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            # Keep the original repr so del_extra_repr can restore it.
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        # Undo the patching: restore the pristine repr and drop helpers.
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    print(model)
    model.apply(del_extra_repr)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def get_model_parameters_number(model):
    """Count the elements of all trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def add_flops_counting_methods(net_main_module):
    """Attach start/stop/reset/compute flops-counting methods to the module.

    Each free function is bound to the module instance via __get__ so it
    receives the module as `self`; counters are zeroed before returning the
    same (mutated) module.
    """
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    # embed()
    net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)

    net_main_module.reset_flops_count()
    return net_main_module
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Returns the total flops accumulated over all supported submodules.
    (Despite the name, no per-image averaging happens in this function.)
    """
    return sum(module.__flops__
               for module in self.modules()
               if is_supported_instance(module))
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def start_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Activates the computation of mean flops consumption per image.
    Call it before you run the network.

    """
    # Installs per-module forward hooks that accumulate into __flops__.
    self.apply(add_flops_counter_hook_function)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Stops computing the mean flops consumption per image.
    Call whenever you want to pause the computation.

    """
    # Removes the forward hooks installed by start_flops_count().
    self.apply(remove_flops_counter_hook_function)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Resets statistics computed so far.

    """
    # Zeroes __flops__ on every supported submodule.
    self.apply(add_flops_counter_variable_or_reset)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def add_flops_counter_hook_function(module):
    """Register the appropriate flops-counting forward hook on `module`.

    Idempotent: a module that already carries __flops_handle__ is skipped.
    """
    if is_supported_instance(module):
        if hasattr(module, '__flops_handle__'):
            return  # hook already installed

        # NOTE(review): nn.Conv3d is dispatched here but is absent from
        # is_supported_instance(), so Conv3d modules never reach this branch
        # — confirm whether Conv3d support was intended.
        if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
            handle = module.register_forward_hook(conv_flops_counter_hook)
        elif isinstance(module, (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)):
            handle = module.register_forward_hook(relu_flops_counter_hook)
        elif isinstance(module, nn.Linear):
            handle = module.register_forward_hook(linear_flops_counter_hook)
        elif isinstance(module, (nn.BatchNorm2d)):
            handle = module.register_forward_hook(bn_flops_counter_hook)
        else:
            # Supported type without a dedicated formula: counts zero flops.
            handle = module.register_forward_hook(empty_flops_counter_hook)
        module.__flops_handle__ = handle
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def remove_flops_counter_hook_function(module):
    """Detach and discard the flops forward hook, if one was installed."""
    if is_supported_instance(module):
        handle = getattr(module, '__flops_handle__', None)
        if handle is not None:
            handle.remove()
            del module.__flops_handle__
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def add_flops_counter_variable_or_reset(module):
    """(Re)initialize the per-module flops counter to zero."""
    if is_supported_instance(module):
        module.__flops__ = 0
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# ---- Internal functions
|
| 245 |
+
def is_supported_instance(module):
    """Return True if `module` is a layer type the flops counter can hook.

    isinstance() already returns a bool, so the `if ...: return True /
    return False` ladder collapses to a single return.
    """
    return isinstance(
        module,
        (
            nn.Conv2d, nn.ConvTranspose2d,
            nn.BatchNorm2d,
            nn.Linear,
            nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6,
        ))
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def conv_flops_counter_hook(conv_module, input, output):
    """Forward hook: accumulate the multiply-accumulate count of a conv layer.

    Counts kernel MACs only (bias additions are ignored). Grouped
    convolutions divide the filter count by `groups`.
    """
    out_batch = output.shape[0]
    spatial_dims = list(output.shape[2:])

    kernel_elems = int(np.prod(conv_module.kernel_size))
    filters_per_group_channel = conv_module.out_channels // conv_module.groups
    per_position = kernel_elems * conv_module.in_channels * filters_per_group_channel

    positions = out_batch * int(np.prod(spatial_dims))

    conv_module.__flops__ += int(per_position) * int(positions)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def relu_flops_counter_hook(module, input, output):
    """Forward hook: one op per output element for elementwise activations."""
    module.__flops__ += int(output.numel())
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def linear_flops_counter_hook(module, input, output):
    """Forward hook: accumulate MACs for an nn.Linear layer.

    Every output element costs `in_features` multiply-accumulates, so the
    total is numel(output) * in_features. This reproduces the original
    counts for 1-D and 2-D inputs exactly, and additionally fixes inputs
    with extra leading dimensions (e.g. (batch, seq, features)), which the
    previous `batch * input.shape[1] * output.shape[1]` formula miscounted
    (shape[1] is the sequence length there, not the feature width).
    """
    input = input[0]
    in_features = input.shape[-1]
    module.__flops__ += int(output.numel() * in_features)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def bn_flops_counter_hook(module, input, output):
    """Forward hook: count normalization ops for BatchNorm layers.

    One op per output element; doubled when the layer is affine, since the
    learned scale-and-shift adds a second elementwise pass.
    """
    # input = input[0]
    # TODO: need to check here
    n = output.shape[0]
    spatial = output.shape[2:]
    ops = n * module.num_features * np.prod(spatial)
    if module.affine:
        ops *= 2
    module.__flops__ += int(ops)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# ---- Count the number of convolutional layers and the activation
|
| 316 |
+
def add_activation_counting_methods(net_main_module):
    """Attach start/stop/reset/compute activation-counting methods to the module.

    Mirrors add_flops_counting_methods: each free function is bound to the
    module instance via __get__; counters are zeroed before returning the
    same (mutated) module.
    """
    # adding additional methods to the existing module object,
    # this is done this way so that each function has access to self object
    # embed()
    net_main_module.start_activation_count = start_activation_count.__get__(net_main_module)
    net_main_module.stop_activation_count = stop_activation_count.__get__(net_main_module)
    net_main_module.reset_activation_count = reset_activation_count.__get__(net_main_module)
    net_main_module.compute_average_activation_cost = compute_average_activation_cost.__get__(net_main_module)

    net_main_module.reset_activation_count()
    return net_main_module
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def compute_average_activation_cost(self):
    """
    A method that will be available after add_activation_counting_methods()
    is called on a desired net object.

    Returns (total activation element count, number of counted conv layers)
    summed over all supported submodules.
    """
    total_activations = 0
    total_convs = 0
    for submodule in self.modules():
        if not is_supported_instance_for_activation(submodule):
            continue
        total_activations += submodule.__activation__
        total_convs += submodule.__num_conv__
    return total_activations, total_convs
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def start_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Activates the computation of mean activation consumption per image.
    Call it before you run the network.

    """
    # Installs per-module forward hooks that accumulate into __activation__.
    self.apply(add_activation_counter_hook_function)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def stop_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Stops computing the mean activation consumption per image.
    Call whenever you want to pause the computation.

    """
    # Removes the forward hooks installed by start_activation_count().
    self.apply(remove_activation_counter_hook_function)
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def reset_activation_count(self):
    """
    A method that will be available after add_activation_counting_methods() is called
    on a desired net object.

    Resets statistics computed so far.

    """
    # Zeroes __activation__ and __num_conv__ on every supported submodule.
    self.apply(add_activation_counter_variable_or_reset)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def add_activation_counter_hook_function(module):
    """Register the activation-counting forward hook on conv modules.

    Idempotent: a module that already carries __activation_handle__ is skipped.
    NOTE(review): Conv1d/Linear/ConvTranspose1d pass the supported-instance
    check but receive no hook here, so they contribute 0 activations —
    confirm this (conv-only counting) is the intended challenge metric.
    """
    if is_supported_instance_for_activation(module):
        if hasattr(module, '__activation_handle__'):
            return

        if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
            handle = module.register_forward_hook(conv_activation_counter_hook)
            module.__activation_handle__ = handle
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def remove_activation_counter_hook_function(module):
    """Detach and discard the activation-counting hook, if one was installed."""
    if is_supported_instance_for_activation(module):
        handle = getattr(module, '__activation_handle__', None)
        if handle is not None:
            handle.remove()
            del module.__activation_handle__
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
def add_activation_counter_variable_or_reset(module):
    """(Re)initialize the per-module activation counters to zero."""
    if is_supported_instance_for_activation(module):
        module.__activation__ = 0
        module.__num_conv__ = 0
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def is_supported_instance_for_activation(module):
    """Return True if `module` participates in activation counting.

    isinstance() already returns a bool, so the `if ...: return True /
    return False` ladder collapses to a single return.
    """
    return isinstance(
        module,
        (
            nn.Conv2d, nn.ConvTranspose2d, nn.Conv1d, nn.Linear, nn.ConvTranspose1d,
        ))
|
| 413 |
+
|
| 414 |
+
def conv_activation_counter_hook(module, input, output):
    """
    Forward hook: record the number of output elements ("activations")
    produced by a convolution and bump the counted-layer tally.
    Reference: Radosavovic et al., "Designing Network Design Spaces".

    :param module: hooked layer carrying __activation__ / __num_conv__ counters
    :param input: forward inputs (unused)
    :param output: forward output tensor
    :return: None
    """
    module.__num_conv__ += 1
    module.__activation__ += output.numel()
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def empty_flops_counter_hook(module, input, output):
    """Fallback forward hook for supported layers with no flops formula.

    The `+= 0` leaves the counter unchanged while still requiring that
    __flops__ was initialized by reset_flops_count().
    """
    module.__flops__ += 0
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def upsample_flops_counter_hook(module, input, output):
    """Forward hook for upsampling layers: one op per output element.

    NOTE(review): `output[0]` suggests this expects a tuple-style output;
    this hook is never registered by add_flops_counter_hook_function in
    this file — confirm before wiring it up.
    """
    first = output[0]
    count = first.shape[0]
    for dim in first.shape[1:]:
        count *= dim
    module.__flops__ += int(count)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def pool_flops_counter_hook(module, input, output):
    """Forward hook for pooling layers: one op per input element."""
    first_input = input[0]
    module.__flops__ += int(first_input.numel())
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def dconv_flops_counter_hook(dconv_module, input, output):
    """Forward hook: flops for a factorized ("decomposed") conv module.

    Assumes a custom layer exposing `weight` of shape (m, in, k1, k1) and
    `projection` of shape (out, _, k2, k2) — TODO confirm against the actual
    module definition; this hook is not registered anywhere in this file.
    """
    input = input[0]

    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])

    m_channels, in_channels, kernel_dim1, _, = dconv_module.weight.shape
    out_channels, _, kernel_dim2, _, = dconv_module.projection.shape
    # groups = dconv_module.groups

    # filters_per_channel = out_channels // groups
    # MACs per output position for each of the two factor convolutions.
    conv_per_position_flops1 = kernel_dim1 ** 2 * in_channels * m_channels
    conv_per_position_flops2 = kernel_dim2 ** 2 * out_channels * m_channels
    active_elements_count = batch_size * np.prod(output_dims)

    overall_conv_flops = (conv_per_position_flops1 + conv_per_position_flops2) * active_elements_count
    overall_flops = overall_conv_flops

    dconv_module.__flops__ += int(overall_flops)
    # dconv_module.__output_dims__ = output_dims
|
| 465 |
+
|
utils/test.bmp
ADDED
|
Git LFS Details
|
utils/utils_image.py
ADDED
|
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import cv2
|
| 7 |
+
from torchvision.utils import make_grid
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
# import torchvision.transforms as transforms
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def is_image_file(filename):
    """Return True if *filename* ends with one of the known image extensions."""
    return any(map(filename.endswith, IMG_EXTENSIONS))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_timestamp():
    """Return the current time as a 'yymmdd-HHMMSS' string."""
    return '{:%y%m%d-%H%M%S}'.format(datetime.now())
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def imshow(x, title=None, cbar=False, figsize=None):
    """Display an image array with matplotlib using the 'gray' colormap.

    Args:
        x: array-like image; singleton dimensions are squeezed before display.
        title: optional figure title.
        cbar: if True, draw a colorbar next to the image.
        figsize: optional matplotlib figure size tuple (width, height).
    """
    plt.figure(figsize=figsize)
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
'''
|
| 35 |
+
# =======================================
|
| 36 |
+
# get image pathes of files
|
| 37 |
+
# =======================================
|
| 38 |
+
'''
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def get_image_paths(dataroot):
    """Return a sorted list of image file paths under *dataroot* (None -> None)."""
    if dataroot is None:
        return None
    return sorted(_get_paths_from_images(dataroot))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _get_paths_from_images(path):
    """Recursively collect all image file paths below *path* (sorted walk)."""
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = [
        os.path.join(dirpath, fname)
        for dirpath, _, fnames in sorted(os.walk(path))
        for fname in sorted(fnames)
        if is_image_file(fname)
    ]
    assert images, '{:s} has no valid image file'.format(path)
    return images
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
'''
|
| 61 |
+
# =======================================
|
| 62 |
+
# makedir
|
| 63 |
+
# =======================================
|
| 64 |
+
'''
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def mkdir(path):
    """Create *path* (including parents) if it does not already exist.

    Uses os.makedirs(exist_ok=True) instead of the original
    os.path.exists() guard, which was a check-then-create race: two
    processes could both pass the check and one would then crash.
    """
    os.makedirs(path, exist_ok=True)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def mkdirs(paths):
    """Run mkdir() for a single path string or for each path in an iterable."""
    path_list = [paths] if isinstance(paths, str) else paths
    for p in path_list:
        mkdir(p)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def mkdir_and_rename(path):
    """Create a fresh directory at *path*, archiving any existing one.

    If *path* already exists it is renamed to
    '<path>_archived_<timestamp>' first, so the new directory always
    starts empty.
    """
    if os.path.exists(path):
        new_name = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(new_name))
        os.rename(path, new_name)
    os.makedirs(path)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
'''
|
| 89 |
+
# =======================================
|
| 90 |
+
# read image from path
|
| 91 |
+
# Note: opencv is fast
|
| 92 |
+
# but read BGR numpy image
|
| 93 |
+
# =======================================
|
| 94 |
+
'''
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ----------------------------------------
|
| 98 |
+
# get single image of size HxWxn_channles (BGR)
|
| 99 |
+
# ----------------------------------------
|
| 100 |
+
def read_img(path):
    """Read an image with OpenCV as float32 HxWxC BGR in [0, 1].

    Grayscale images come back as HxWx1; any channels beyond the first
    three (e.g. alpha) are dropped.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file. cv2.imread
            returns None on failure instead of raising, which previously
            surfaced as a confusing AttributeError on `.astype`.
    """
    # read image by cv2
    # return: Numpy float32, HWC, BGR, [0,1]
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
    if img is None:
        raise FileNotFoundError('Failed to read image: {:s}'.format(path))
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ----------------------------------------
|
| 114 |
+
# get uint8 image of size HxWxn_channles (RGB)
|
| 115 |
+
# ----------------------------------------
|
| 116 |
+
def imread_uint(path, n_channels=3):
    """Read an image as uint8: HxWx3 (RGB, gray replicated) or HxWx1 (gray).

    Args:
        path: image file path.
        n_channels: 1 for grayscale output (HxWx1), 3 for RGB (HxWx3).
            Any other value leaves `img` unbound and raises NameError.

    NOTE(review): cv2.imread returns None for unreadable paths, which
    makes the subsequent calls raise — confirm callers pass valid paths.
    """
    # input: path
    # output: HxWx3(RGB or GGG), or HxWx1 (G)
    if n_channels == 1:
        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
        img = np.expand_dims(img, axis=2)  # HxWx1
    elif n_channels == 3:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
        else:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
    return img
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def imsave(img, img_path):
    """Save a numpy image to disk via OpenCV.

    Singleton dimensions are squeezed first. For 3-D images the channel
    order is reversed (RGB -> BGR) because cv2.imwrite expects BGR.
    """
    img = np.squeeze(img)
    if img.ndim == 3:
        img = img[:, :, [2, 1, 0]]
    cv2.imwrite(img_path, img)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
'''
|
| 139 |
+
# =======================================
|
| 140 |
+
# numpy(single) <---> numpy(uint)
|
| 141 |
+
# numpy(single) <---> tensor
|
| 142 |
+
# numpy(uint) <---> tensor
|
| 143 |
+
# =======================================
|
| 144 |
+
'''
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# --------------------------------
|
| 148 |
+
# numpy(single) <---> numpy(uint)
|
| 149 |
+
# --------------------------------
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def uint2single(img):
    """Convert a uint8 image to float32 in [0, 1]."""
    return (np.asarray(img) / 255.).astype(np.float32)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def uint2single1(img):
    """Convert a uint8 image to squeezed float32 in [0, 1]."""
    return (np.squeeze(img) / 255.).astype(np.float32)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def single2uint(img):
    """Convert a float image in [0, 1] to uint8 with rounding."""
    return (img.clip(0, 1) * 255.).round().astype(np.uint8)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def uint162single(img):
    """Convert a uint16 image to float32 in [0, 1]."""
    return (np.asarray(img) / 65535.).astype(np.float32)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def single2uint16(img):
    """Convert a float image in [0, 1] to uint16 in [0, 65535] with rounding.

    Bug fix: the original cast the scaled values to np.uint8, silently
    overflowing and returning 8-bit values despite the function's name
    and 65535 scale factor.
    """
    return np.uint16((img.clip(0, 1)*65535.).round())
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
# --------------------------------
|
| 178 |
+
# numpy(uint) <---> tensor
|
| 179 |
+
# uint (HxWxn_channels (RGB) or G)
|
| 180 |
+
# --------------------------------
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# convert uint (HxWxn_channels) to 4-dimensional torch tensor
|
| 184 |
+
def uint2tensor4(img, data_range):
    """Convert a uint HxW(xC) numpy image to a 1xCxHxW float tensor.

    Pixel values are rescaled from [0, 255] to [0, data_range].
    """
    if img.ndim == 2:
        img = img[:, :, None]
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    return chw.float().div(255. / data_range).unsqueeze(0)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# convert uint (HxWxn_channels) to 3-dimensional torch tensor
|
| 191 |
+
def uint2tensor3(img):
    """Convert a uint HxW(xC) numpy image to a CxHxW float tensor in [0, 1]."""
    if img.ndim == 2:
        img = img[:, :, None]
    chw = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1)
    return chw.float().div(255.)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# convert torch tensor to uint
|
| 198 |
+
def tensor2uint(img, data_range):
    """Convert a tensor in [0, data_range] to a uint8 HxWxC / HxW numpy image.

    NOTE(review): clamp_ is in-place and may modify the caller's tensor
    (squeeze/float can return views) — confirm callers do not rely on the
    original values.
    """
    arr = img.data.squeeze().float().clamp_(0, 1*data_range).cpu().numpy()
    if arr.ndim == 3:
        arr = arr.transpose(1, 2, 0)
    return np.uint8((arr*255.0/data_range).round())
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# --------------------------------
|
| 206 |
+
# numpy(single) <---> tensor
|
| 207 |
+
# single (HxWxn_channels (RGB) or G)
|
| 208 |
+
# --------------------------------
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# convert single (HxWxn_channels) to 4-dimensional torch tensor
|
| 212 |
+
def single2tensor4(img):
    """Convert an HxWxC float numpy image to a 1xCxHxW float tensor."""
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1).float().unsqueeze(0)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# convert single (HxWxn_channels) to 3-dimensional torch tensor
|
| 217 |
+
def single2tensor3(img):
    """Convert an HxWxC float numpy image to a CxHxW float tensor."""
    t = torch.from_numpy(np.ascontiguousarray(img))
    return t.permute(2, 0, 1).float()
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# convert torch tensor to single
|
| 222 |
+
def tensor2single(img):
    """Convert a tensor to a float numpy image in [0, 1], HxWxC (or HxW).

    NOTE(review): clamp_ is in-place on a possible view of the input —
    confirm callers do not rely on the original values.
    """
    out = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
    if out.ndim == 3:
        out = out.transpose(1, 2, 0)
    return out
|
| 228 |
+
|
| 229 |
+
def tensor2single3(img):
    """Like tensor2single, but guarantees a 3-D HxWx1 result for 2-D images.

    NOTE(review): clamp_ is in-place on a possible view of the input.
    """
    out = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
    if out.ndim == 3:
        out = out.transpose(1, 2, 0)
    elif out.ndim == 2:
        out = out[:, :, None]
    return out
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# from skimage.io import imread, imsave
|
| 239 |
+
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    '''
    Converts a torch Tensor into an image Numpy array of BGR channel order
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    '''
    # NOTE(review): clamp_ is in-place and may modify the caller's tensor.
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        # batch: tile images into a near-square grid, then RGB -> BGR
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
    return img_np.astype(out_type)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
'''
|
| 267 |
+
# =======================================
|
| 268 |
+
# image processing process on numpy image
|
| 269 |
+
# augment(img_list, hflip=True, rot=True):
|
| 270 |
+
# =======================================
|
| 271 |
+
'''
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def augment_img(img, mode=0):
    """Apply one of Kai Zhang's 8 flip/rotation augmentations to a numpy image.

    Mode 0 is identity; modes 1-7 combine np.rot90 and np.flipud so the 8
    modes cover the full dihedral group of the image plane. Returns None
    for modes outside 0-7, matching the original if/elif chain.
    """
    ops = {
        0: lambda x: x,
        1: lambda x: np.flipud(np.rot90(x)),
        2: np.flipud,
        3: lambda x: np.rot90(x, k=3),
        4: lambda x: np.flipud(np.rot90(x, k=2)),
        5: np.rot90,
        6: lambda x: np.rot90(x, k=2),
        7: lambda x: np.flipud(np.rot90(x, k=3)),
    }
    op = ops.get(mode)
    return op(img) if op is not None else None
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def augment_img_np3(img, mode=0):
    """8-mode flip/transpose augmentation for an HxWxC numpy image.

    Uses stride tricks (reversed slices, axis transpose) instead of rot90,
    so no data is copied. Modes outside 0-7 fall through and return None,
    matching the original if/elif chain.
    """
    if mode == 0:
        return img
    if mode == 1:
        return img.transpose(1, 0, 2)
    if mode == 2:
        return img[::-1, :, :]
    if mode == 3:
        return img[::-1, :, :].transpose(1, 0, 2)
    if mode == 4:
        return img[:, ::-1, :]
    if mode == 5:
        return img[:, ::-1, :].transpose(1, 0, 2)
    if mode == 6:
        # flip W then flip H == flip both axes at once
        return img[::-1, ::-1, :]
    if mode == 7:
        return img[::-1, ::-1, :].transpose(1, 0, 2)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def augment_img_tensor(img, mode=0):
    """Apply augment_img's flip/rotate mode to a 3D (CHW) or 4D (NCHW) tensor.

    The tensor is round-tripped through numpy with H, W as the leading axes
    so augment_img's flips/rotations (which act on the first two axes)
    apply correctly, then permuted back and cast to the input's dtype.
    """
    img_size = img.size()
    img_np = img.data.cpu().numpy()
    if len(img_size) == 3:
        img_np = np.transpose(img_np, (1, 2, 0))   # CHW -> HWC
    elif len(img_size) == 4:
        img_np = np.transpose(img_np, (2, 3, 1, 0))  # NCHW -> HWCN
    img_np = augment_img(img_np, mode=mode)
    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
    if len(img_size) == 3:
        img_tensor = img_tensor.permute(2, 0, 1)   # HWC -> CHW
    elif len(img_size) == 4:
        img_tensor = img_tensor.permute(3, 2, 0, 1)  # HWCN -> NCHW

    return img_tensor.type_as(img)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def augment_imgs(img_list, hflip=True, rot=True):
    """Randomly flip/transpose every image in *img_list* with the same choice.

    Each transform fires with probability 0.5 when its flag allows it; the
    three random draws happen once, so all images get identical treatment.
    Short-circuiting is preserved: no random number is consumed for a
    transform whose flag is False.
    """
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _apply(arr):
        out = arr[:, ::-1, :] if hflip else arr
        if vflip:
            out = out[::-1, :, :]
        if rot90:
            out = out.transpose(1, 0, 2)
        return out

    return [_apply(arr) for arr in img_list]
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
'''
|
| 357 |
+
# =======================================
|
| 358 |
+
# image processing process on numpy image
|
| 359 |
+
# channel_convert(in_c, tar_type, img_list):
|
| 360 |
+
# rgb2ycbcr(img, only_y=True):
|
| 361 |
+
# bgr2ycbcr(img, only_y=True):
|
| 362 |
+
# ycbcr2rgb(img):
|
| 363 |
+
# modcrop(img_in, scale):
|
| 364 |
+
# =======================================
|
| 365 |
+
'''
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    # Bug fix: the original called img.astype(np.float32) and discarded the
    # result (astype returns a new array), then mutated the caller's array
    # in place with `img *= 255.` for float inputs. Rebinding keeps the
    # computation on a private copy; output values are unchanged.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    # Bug fix: astype's result was discarded and float inputs were scaled
    # in place, mutating the caller's array. Rebind to a private copy;
    # output values are unchanged.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    # Bug fix: astype's result was discarded and float inputs were scaled
    # in place, mutating the caller's array. Rebind to a private copy;
    # output values are unchanged.
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def modcrop(img_in, scale):
    """Crop the bottom/right edge so height and width are multiples of *scale*.

    Accepts HxW or HxWxC numpy images and returns a cropped copy.

    Raises:
        ValueError: for any other number of dimensions.
    """
    img = np.copy(img_in)
    if img.ndim not in (2, 3):
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    H, W = img.shape[0], img.shape[1]
    return img[:H - H % scale, :W - W % scale]
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def shave(img_in, border=0):
    """Remove *border* pixels from every side of an HxW(xC) numpy image."""
    img = np.copy(img_in)
    h, w = img.shape[:2]
    return img[border:h - border, border:w - border]
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def channel_convert(in_c, tar_type, img_list):
    """Convert a list of images among BGR, grayscale and Y-channel spaces.

    Args:
        in_c: number of input channels (3 for BGR, 1 for gray/Y).
        tar_type: target space: 'gray', 'y', or 'RGB'.
        img_list: list of numpy images.

    Returns:
        List of converted images (gray/Y outputs get an HxWx1 trailing
        axis); the input list unchanged when no rule matches.
    """
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == 'y':  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
'''
|
| 475 |
+
# =======================================
|
| 476 |
+
# metric, PSNR and SSIM
|
| 477 |
+
# =======================================
|
| 478 |
+
'''
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# ----------
|
| 482 |
+
# PSNR
|
| 483 |
+
# ----------
|
| 484 |
+
def calculate_psnr(img1, img2, border=0):
    """PSNR in dB between two [0, 255] images, cropping *border* px per side.

    Returns float('inf') for identical inputs.

    Raises:
        ValueError: if the images have different shapes.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    crop1 = img1[border:h - border, border:w - border].astype(np.float64)
    crop2 = img2[border:h - border, border:w - border].astype(np.float64)

    mse = np.mean((crop1 - crop2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
# ----------
|
| 501 |
+
# SSIM
|
| 502 |
+
# ----------
|
| 503 |
+
def calculate_ssim(img1, img2, border=0):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    3-channel images are measured per channel and the three SSIM values
    are averaged; `border` pixels are cropped from each side first.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                # Bug fix: measure each channel separately. The original
                # passed the full 3-channel arrays on every iteration, so
                # the "per-channel" loop computed the same (2-D-kernel over
                # 3-D data) value three times.
                ssims.append(ssim(img1[:, :, i], img2[:, :, i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
def ssim(img1, img2):
    """Single-channel SSIM matching MATLAB's reference implementation.

    Uses an 11x11 Gaussian window (sigma=1.5) on images in [0, 255]; the
    outermost 5-pixel frame is excluded so only the 'valid' filter region
    contributes, as in the original MATLAB code.
    """
    # stabilization constants from the SSIM paper, scaled for [0, 255]
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
'''
|
| 552 |
+
# =======================================
|
| 553 |
+
# pytorch version of matlab imresize
|
| 554 |
+
# =======================================
|
| 555 |
+
'''
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
# matlab 'imresize' function, now only support 'bicubic'
|
| 559 |
+
def cubic(x):
    """Bicubic interpolation kernel (Keys, a = -0.5), as in MATLAB's imresize.

    Evaluates the piecewise cubic on |x|: one polynomial for |x| <= 1,
    another for 1 < |x| <= 2, and zero elsewhere.
    """
    absx = torch.abs(x)
    absx2 = absx * absx
    absx3 = absx2 * absx
    inner = (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx))
    outer = (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * \
        (((absx > 1)*(absx <= 2)).type_as(absx))
    return inner + outer
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Compute per-output-pixel interpolation weights and source indices.

    Port of the weight/index computation inside MATLAB's imresize for one
    axis. Returns (weights, indices, sym_len_s, sym_len_e), where the two
    sym_len values are how many symmetrically-mirrored pixels the caller
    must pad at the start/end of the axis so every index is valid.

    NOTE(review): the `kernel` argument is unused — only the cubic kernel
    is implemented; confirm callers never pass anything else.
    """
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # convert out-of-range indices into symmetric-padding offsets
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
# --------------------------------
|
| 623 |
+
# imresize for tensor image
|
| 624 |
+
# --------------------------------
|
| 625 |
+
def imresize(img, scale, antialiasing=True):
    """MATLAB-equivalent bicubic resize for a CHW or HW tensor in [0, 1].

    Resizes H then W using calculate_weights_indices with symmetric
    boundary padding, matching MATLAB's imresize (no final rounding).

    NOTE(review): a 2-D input is unsqueezed IN PLACE (unsqueeze_), so the
    caller's tensor changes shape — confirm callers tolerate this.
    """
    # Now the scale should be the same for H and W
    # input: img: pytorch tensor, CHW or HW [0,1]
    # output: CHW or HW [0,1] w/o round
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # each output row is a weighted combination of kernel_width input rows
    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()
    return out_2
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
# --------------------------------
|
| 696 |
+
# imresize for numpy image
|
| 697 |
+
# --------------------------------
|
| 698 |
+
def imresize_np(img, scale, antialiasing=True):
    """MATLAB-equivalent bicubic resize for an HWC or HW numpy image in [0, 1].

    Numpy counterpart of imresize: the array is wrapped in a tensor, H then
    W are resampled with symmetric boundary padding, and the result is
    returned as numpy (no final rounding).
    """
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    # each output row is a weighted combination of kernel_width input rows
    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
if __name__ == '__main__':
    # Manual smoke test: load a sample image as HxWx3 uint8 RGB.
    img = imread_uint('test.bmp',3)
|
utils/utils_logger.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import datetime
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def log(*args, **kwargs):
    """print() with a leading 'YYYY-mm-dd HH:MM:SS:' timestamp; forwards all args."""
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:")
    print(stamp, *args, **kwargs)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
'''
|
| 12 |
+
# ===============================
|
| 13 |
+
# logger
|
| 14 |
+
# logger_name = None = 'base' ???
|
| 15 |
+
# ===============================
|
| 16 |
+
'''
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def logger_info(logger_name, log_path='default_logger.log'):
    ''' set up logger
    modified by Kai Zhang (github: https://github.com/cszn)

    Configures the named logger at INFO level with two handlers — an
    appending FileHandler at log_path and a StreamHandler — unless the
    logger already has handlers, in which case it is left untouched so
    repeated calls do not duplicate output.
    '''
    log = logging.getLogger(logger_name)
    if log.hasHandlers():
        print('LogHandlers exist!')
    else:
        print('LogHandlers setup!')
        level = logging.INFO
        formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S')
        fh = logging.FileHandler(log_path, mode='a')
        fh.setFormatter(formatter)
        log.setLevel(level)
        log.addHandler(fh)
        # print(len(log.handlers))

        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        log.addHandler(sh)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
'''
|
| 42 |
+
# ===============================
|
| 43 |
+
# print to file and std_out simultaneously
|
| 44 |
+
# ===============================
|
| 45 |
+
'''
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class logger_print(object):
    """Tee for sys.stdout: every write goes to the terminal AND a log file.

    Typical use: ``sys.stdout = logger_print('run.log')``.

    NOTE(review): the log file is opened for append and never explicitly
    closed; it is released when the object is garbage-collected.
    """

    def __init__(self, log_path="default.log"):
        self.terminal = sys.stdout
        self.log = open(log_path, 'a')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)  # write the message

    def flush(self):
        # Bug fix: this was a no-op, so buffered log output could be lost
        # at interpreter exit or on explicit flush; a stdout replacement
        # must forward flush to both sinks.
        self.terminal.flush()
        self.log.flush()
|