|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
|
|
| class ESA(nn.Module):
|
| def __init__(self, n_feats, conv):
|
| super(ESA, self).__init__()
|
| f = n_feats // 4
|
| self.conv1 = conv(n_feats, f, kernel_size=1)
|
| self.conv_f = conv(f, f, kernel_size=1)
|
| self.conv_max = conv(f, f, kernel_size=3, padding=1)
|
| self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
|
| self.conv3 = conv(f, f, kernel_size=3, padding=1)
|
| self.conv3_ = conv(f, f, kernel_size=3, padding=1)
|
| self.conv4 = conv(f, n_feats, kernel_size=1)
|
| self.sigmoid = nn.Sigmoid()
|
| self.relu = nn.ReLU(inplace=True)
|
|
|
| def forward(self, x):
|
| c1_ = (self.conv1(x))
|
| c1 = self.conv2(c1_)
|
| v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
|
| v_range = self.relu(self.conv_max(v_max))
|
| c3 = self.relu(self.conv3(v_range))
|
| c3 = self.conv3_(c3)
|
| c3 = F.interpolate(c3, (x.size(2), x.size(3)), mode='bilinear', align_corners=False)
|
| cf = self.conv_f(c1_)
|
| c4 = self.conv4(c3+cf)
|
| m = self.sigmoid(c4)
|
|
|
| return x * m
|
|
|
|
|
| class conv(nn.Module):
|
| def __init__(self, n_feats):
|
| super(conv, self).__init__()
|
| self.conv1x1 = nn.Conv2d(n_feats, n_feats, 1, 1, 0)
|
| self.act = nn.PReLU(num_parameters=n_feats)
|
| def forward(self, x):
|
| return self.act(self.conv1x1(x))
|
|
|
|
|
| class Cell(nn.Module):
|
| def __init__(self, n_feats=48, dynamic = True, deploy = False, L= None, with_13=False):
|
| super(Cell, self).__init__()
|
|
|
| self.conv1 = conv(n_feats)
|
| self.conv2 = EDBB_deploy(n_feats,n_feats)
|
| self.conv3 = EDBB_deploy(n_feats,n_feats)
|
|
|
| self.fuse = nn.Conv2d(n_feats*2, n_feats, 1, 1, 0)
|
|
|
| self.att = ESA(n_feats, nn.Conv2d)
|
|
|
| self.branch = nn.ModuleList([nn.Conv2d(n_feats, n_feats//2, 1, 1, 0) for _ in range(4)])
|
|
|
| def forward(self, x):
|
| out1 = self.conv1(x)
|
| out2 = self.conv2(out1)
|
| out3 = self.conv3(out2)
|
|
|
|
|
| out = self.fuse(torch.cat([self.branch[0](x), self.branch[1](out1), self.branch[2](out2), self.branch[3](out3)], dim=1))
|
| out = self.att(out)
|
| out += x
|
|
|
| return out
|
|
|
|
|
| class EFDN(nn.Module):
|
| def __init__(self, scale=4, in_channels=3, n_feats=48, out_channels=3):
|
| super(EFDN, self).__init__()
|
| self.head = nn.Conv2d(in_channels, n_feats, 3, 1, 1)
|
|
|
| self.cells = nn.ModuleList([Cell(n_feats) for _ in range(4)])
|
|
|
|
|
| self.local_fuse = nn.ModuleList([nn.Conv2d(n_feats*2, n_feats, 1, 1, 0) for _ in range(3)])
|
|
|
| self.tail = nn.Sequential(
|
| nn.Conv2d(n_feats, out_channels*(scale**2), 3, 1, 1),
|
| nn.PixelShuffle(scale)
|
| )
|
|
|
| def forward(self, x):
|
|
|
| out0 = self.head(x)
|
|
|
|
|
| out1 = self.cells[0](out0)
|
| out2 = self.cells[1](out1)
|
| out2_fuse = self.local_fuse[0](torch.cat([out1, out2], dim=1))
|
| out3 = self.cells[2](out2_fuse)
|
| out3_fuse = self.local_fuse[1](torch.cat([out2, out3], dim=1))
|
| out4 = self.cells[3](out3_fuse)
|
| out4_fuse = self.local_fuse[2](torch.cat([out2, out4], dim=1))
|
|
|
| out = out4_fuse + out0
|
|
|
|
|
| out = self.tail(out)
|
|
|
| return out.clamp(0,1)
|
|
|
|
|
|
|
|
|
| def multiscale(kernel, target_kernel_size):
|
| H_pixels_to_pad = (target_kernel_size - kernel.size(2)) // 2
|
| W_pixels_to_pad = (target_kernel_size - kernel.size(3)) // 2
|
| return F.pad(kernel, [H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad])
|
|
|
|
|
| class SeqConv3x3(nn.Module):
|
| def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier):
|
| super(SeqConv3x3, self).__init__()
|
|
|
| self.type = seq_type
|
| self.inp_planes = inp_planes
|
| self.out_planes = out_planes
|
|
|
| if self.type == 'conv1x1-conv3x3':
|
| self.mid_planes = int(out_planes * depth_multiplier)
|
| conv0 = torch.nn.Conv2d(self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
|
| self.k0 = conv0.weight
|
| self.b0 = conv0.bias
|
|
|
| conv1 = torch.nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3)
|
| self.k1 = conv1.weight
|
| self.b1 = conv1.bias
|
|
|
| elif self.type == 'conv1x1-sobelx':
|
| conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
|
| self.k0 = conv0.weight
|
| self.b0 = conv0.bias
|
|
|
|
|
| scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
|
| self.scale = nn.Parameter(scale)
|
|
|
|
|
|
|
| bias = torch.randn(self.out_planes) * 1e-3
|
| bias = torch.reshape(bias, (self.out_planes,))
|
| self.bias = nn.Parameter(bias)
|
|
|
| self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
|
| for i in range(self.out_planes):
|
| self.mask[i, 0, 0, 0] = 1.0
|
| self.mask[i, 0, 1, 0] = 2.0
|
| self.mask[i, 0, 2, 0] = 1.0
|
| self.mask[i, 0, 0, 2] = -1.0
|
| self.mask[i, 0, 1, 2] = -2.0
|
| self.mask[i, 0, 2, 2] = -1.0
|
| self.mask = nn.Parameter(data=self.mask, requires_grad=False)
|
|
|
| elif self.type == 'conv1x1-sobely':
|
| conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
|
| self.k0 = conv0.weight
|
| self.b0 = conv0.bias
|
|
|
|
|
| scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
|
| self.scale = nn.Parameter(torch.FloatTensor(scale))
|
|
|
|
|
|
|
| bias = torch.randn(self.out_planes) * 1e-3
|
| bias = torch.reshape(bias, (self.out_planes,))
|
| self.bias = nn.Parameter(torch.FloatTensor(bias))
|
|
|
| self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
|
| for i in range(self.out_planes):
|
| self.mask[i, 0, 0, 0] = 1.0
|
| self.mask[i, 0, 0, 1] = 2.0
|
| self.mask[i, 0, 0, 2] = 1.0
|
| self.mask[i, 0, 2, 0] = -1.0
|
| self.mask[i, 0, 2, 1] = -2.0
|
| self.mask[i, 0, 2, 2] = -1.0
|
| self.mask = nn.Parameter(data=self.mask, requires_grad=False)
|
|
|
| elif self.type == 'conv1x1-laplacian':
|
| conv0 = torch.nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
|
| self.k0 = conv0.weight
|
| self.b0 = conv0.bias
|
|
|
|
|
| scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
|
| self.scale = nn.Parameter(torch.FloatTensor(scale))
|
|
|
|
|
|
|
| bias = torch.randn(self.out_planes) * 1e-3
|
| bias = torch.reshape(bias, (self.out_planes,))
|
| self.bias = nn.Parameter(torch.FloatTensor(bias))
|
|
|
| self.mask = torch.zeros((self.out_planes, 1, 3, 3), dtype=torch.float32)
|
| for i in range(self.out_planes):
|
| self.mask[i, 0, 0, 1] = 1.0
|
| self.mask[i, 0, 1, 0] = 1.0
|
| self.mask[i, 0, 1, 2] = 1.0
|
| self.mask[i, 0, 2, 1] = 1.0
|
| self.mask[i, 0, 1, 1] = -4.0
|
| self.mask = nn.Parameter(data=self.mask, requires_grad=False)
|
| else:
|
| raise ValueError('the type of seqconv is not supported!')
|
|
|
| def forward(self, x):
|
| if self.type == 'conv1x1-conv3x3':
|
|
|
| y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
|
|
|
| y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
|
| b0_pad = self.b0.view(1, -1, 1, 1)
|
| y0[:, :, 0:1, :] = b0_pad
|
| y0[:, :, -1:, :] = b0_pad
|
| y0[:, :, :, 0:1] = b0_pad
|
| y0[:, :, :, -1:] = b0_pad
|
|
|
| y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
|
| else:
|
| y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
|
|
|
| y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
|
| b0_pad = self.b0.view(1, -1, 1, 1)
|
| y0[:, :, 0:1, :] = b0_pad
|
| y0[:, :, -1:, :] = b0_pad
|
| y0[:, :, :, 0:1] = b0_pad
|
| y0[:, :, :, -1:] = b0_pad
|
|
|
| y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_planes)
|
| return y1
|
|
|
| def rep_params(self):
|
| device = self.k0.get_device()
|
| if device < 0:
|
| device = None
|
|
|
| if self.type == 'conv1x1-conv3x3':
|
|
|
| RK = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
|
|
|
| RB = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
|
| RB = F.conv2d(input=RB, weight=self.k1).view(-1,) + self.b1
|
| else:
|
| tmp = self.scale * self.mask
|
| k1 = torch.zeros((self.out_planes, self.out_planes, 3, 3), device=device)
|
| for i in range(self.out_planes):
|
| k1[i, i, :, :] = tmp[i, 0, :, :]
|
| b1 = self.bias
|
|
|
| RK = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
|
|
|
| RB = torch.ones(1, self.out_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
|
| RB = F.conv2d(input=RB, weight=k1).view(-1,) + b1
|
| return RK, RB
|
|
|
|
|
| class EDBB(nn.Module):
|
| def __init__(self, inp_planes, out_planes, depth_multiplier=None, act_type='prelu', with_idt = False, deploy=False, with_13=False, gv=False):
|
| super(EDBB, self).__init__()
|
|
|
| self.deploy = deploy
|
| self.act_type = act_type
|
|
|
| self.inp_planes = inp_planes
|
| self.out_planes = out_planes
|
|
|
| self.gv = gv
|
|
|
| if depth_multiplier is None:
|
| self.depth_multiplier = 1.0
|
| else:
|
| self.depth_multiplier = depth_multiplier
|
|
|
| if deploy:
|
| self.rep_conv = nn.Conv2d(in_channels=inp_planes, out_channels=out_planes, kernel_size=3, stride=1,
|
| padding=1, bias=True)
|
| else:
|
| self.with_13 = with_13
|
| if with_idt and (self.inp_planes == self.out_planes):
|
| self.with_idt = True
|
| else:
|
| self.with_idt = False
|
|
|
| self.rep_conv = nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=3, padding=1)
|
| self.conv1x1 = nn.Conv2d(self.inp_planes, self.out_planes, kernel_size=1, padding=0)
|
| self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes, self.out_planes, self.depth_multiplier)
|
| self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes, self.out_planes, -1)
|
| self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes, self.out_planes, -1)
|
| self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes, self.out_planes, -1)
|
|
|
| if self.act_type == 'prelu':
|
| self.act = nn.PReLU(num_parameters=self.out_planes)
|
| elif self.act_type == 'relu':
|
| self.act = nn.ReLU(inplace=True)
|
| elif self.act_type == 'rrelu':
|
| self.act = nn.RReLU(lower=-0.05, upper=0.05)
|
| elif self.act_type == 'softplus':
|
| self.act = nn.Softplus()
|
| elif self.act_type == 'linear':
|
| pass
|
| else:
|
| raise ValueError('The type of activation if not support!')
|
|
|
| def forward(self, x):
|
| if self.deploy:
|
| y = self.rep_conv(x)
|
| elif self.gv:
|
| y = self.rep_conv(x) + \
|
| self.conv1x1_sbx(x) + \
|
| self.conv1x1_sby(x) + \
|
| self.conv1x1_lpl(x) + x
|
| else:
|
| y = self.rep_conv(x) + \
|
| self.conv1x1(x) + \
|
| self.conv1x1_sbx(x) + \
|
| self.conv1x1_sby(x) + \
|
| self.conv1x1_lpl(x)
|
|
|
| if self.with_idt:
|
| y += x
|
| if self.with_13:
|
| y += self.conv1x1_3x3(x)
|
|
|
| if self.act_type != 'linear':
|
| y = self.act(y)
|
| return y
|
|
|
| def switch_to_gv(self):
|
| if self.gv:
|
| return
|
| self.gv = True
|
|
|
| K0, B0 = self.rep_conv.weight, self.rep_conv.bias
|
| K1, B1 = self.conv1x1_3x3.rep_params()
|
| K5, B5 = multiscale(self.conv1x1.weight,3), self.conv1x1.bias
|
| RK, RB = (K0+K5), (B0+B5)
|
| if self.with_13:
|
| RK, RB = RK + K1, RB + B1
|
|
|
| self.rep_conv.weight.data = RK
|
| self.rep_conv.bias.data = RB
|
|
|
| for para in self.parameters():
|
| para.detach_()
|
|
|
|
|
| def switch_to_deploy(self):
|
|
|
| if self.deploy:
|
| return
|
| self.deploy = True
|
|
|
| K0, B0 = self.rep_conv.weight, self.rep_conv.bias
|
| K1, B1 = self.conv1x1_3x3.rep_params()
|
| K2, B2 = self.conv1x1_sbx.rep_params()
|
| K3, B3 = self.conv1x1_sby.rep_params()
|
| K4, B4 = self.conv1x1_lpl.rep_params()
|
| K5, B5 = multiscale(self.conv1x1.weight,3), self.conv1x1.bias
|
| if self.gv:
|
| RK, RB = (K0+K2+K3+K4), (B0+B2+B3+B4)
|
| else:
|
| RK, RB = (K0+K2+K3+K4+K5), (B0+B2+B3+B4+B5)
|
| if self.with_13:
|
| RK, RB = RK + K1, RB + B1
|
| if self.with_idt:
|
| device = RK.get_device()
|
| if device < 0:
|
| device = None
|
| K_idt = torch.zeros(self.out_planes, self.out_planes, 3, 3, device=device)
|
| for i in range(self.out_planes):
|
| K_idt[i, i, 1, 1] = 1.0
|
| B_idt = 0.0
|
| RK, RB = RK + K_idt, RB + B_idt
|
|
|
|
|
| self.rep_conv = nn.Conv2d(in_channels=self.inp_planes, out_channels=self.out_planes, kernel_size=3, stride=1,
|
| padding=1, bias=True)
|
| self.rep_conv.weight.data = RK
|
| self.rep_conv.bias.data = RB
|
|
|
| for para in self.parameters():
|
| para.detach_()
|
|
|
|
|
| self.__delattr__('conv1x1_3x3')
|
| self.__delattr__('conv1x1')
|
| self.__delattr__('conv1x1_sbx')
|
| self.__delattr__('conv1x1_sby')
|
| self.__delattr__('conv1x1_lpl')
|
|
|
|
|
| class EDBB_deploy(nn.Module):
|
| def __init__(self, inp_planes, out_planes):
|
| super(EDBB_deploy, self).__init__()
|
|
|
| self.rep_conv = nn.Conv2d(in_channels=inp_planes, out_channels=out_planes, kernel_size=3, stride=1,
|
| padding=1, bias=True)
|
|
|
| self.act = nn.PReLU(num_parameters=out_planes)
|
|
|
| def forward(self, x):
|
| y = self.rep_conv(x)
|
| y = self.act(y)
|
|
|
| return y
|
| |