import math
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F
|
|
class LoRALayer:
    """Mixin holding the hyper-parameters shared by LoRA-adapted layers."""

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout on the LoRA branch; identity when disabled.
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Tracks whether the low-rank update is folded into the frozen weight.
        self.merged = False
        self.merge_weights = merge_weights
|
|
class Lora_Conv2d(nn.Conv2d, LoRALayer):
    """nn.Conv2d with a trainable low-rank (LoRA) update over a frozen kernel."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        merge_weights: bool = True,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)
        assert isinstance(kernel_size, int)

        if r > 0:
            # A and B are shaped so that (B @ A) can be viewed back into the
            # full (out_channels, in_channels, k, k) kernel tensor.
            self.lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels * kernel_size, r * kernel_size))
            )
            self.scaling = self.lora_alpha / self.r
            # Freeze the pretrained kernel; only lora_A / lora_B are trained.
            self.weight.requires_grad = False

        self.reset_parameters()
|
|
    def reset_parameters(self):
        nn.Conv2d.reset_parameters(self)
        if hasattr(self, 'lora_A'):
            # A gets a standard Kaiming init; B starts at zero, so the LoRA
            # branch contributes nothing until it is trained.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
|
|
    def train(self, mode: bool = True):
        nn.Conv2d.train(self, mode)
        if mode:
            # Entering training: un-merge so lora_A / lora_B receive gradients.
            if self.merge_weights and self.merged:
                self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
                self.merged = False
        else:
            # Entering eval: fold the update into the kernel so inference runs
            # as a single plain convolution.
            if self.merge_weights and not self.merged:
                self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
                self.merged = True
|
|
    def forward(self, x: torch.Tensor):
        if self.r > 0 and not self.merged:
            # Unmerged path: convolve with the frozen kernel plus the scaled
            # low-rank update, computed on the fly.
            return F.conv2d(
                x,
                self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
                self.bias, self.stride, self.padding, self.dilation, self.groups
            )
        return nn.Conv2d.forward(self, x)
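
# Usage sketch for Lora_Conv2d (illustrative; the shapes and hyper-parameters
# below are assumptions, not taken from the original file):
#
#     conv = Lora_Conv2d(64, 64, kernel_size=3, r=4, lora_alpha=8, padding=1)
#     y = conv(torch.randn(1, 64, 32, 32))  # training path: weight + scaled B @ A
#     conv.eval()                           # merges the update into conv.weight
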
def _make_pair(value):
    if isinstance(value, int):
        value = (value,) * 2
    return value
|
|
|
|
def conv_layer(in_channels,
               out_channels,
               kernel_size,
               bias=True):
    """
    Convolution layer with 'same'-style padding derived from the kernel size.
    """
    kernel_size = _make_pair(kernel_size)
    padding = (int((kernel_size[0] - 1) / 2),
               int((kernel_size[1] - 1) / 2))
    return nn.Conv2d(in_channels,
                     out_channels,
                     kernel_size,
                     padding=padding,
                     bias=bias)
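
# Usage sketch (illustrative; values are assumptions): with kernel_size=3 the
# derived padding is (1, 1), so spatial dimensions are preserved.
#
#     conv = conv_layer(26, 26, kernel_size=3)
#     conv(torch.randn(1, 26, 64, 64)).shape  # -> torch.Size([1, 26, 64, 64])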
|
|
|
|
def activation(act_type, inplace=True, neg_slope=0.05, n_prelu=1):
    """
    Activation factory for ['relu', 'lrelu', 'prelu'].

    Parameters
    ----------
    act_type: str
        One of ['relu', 'lrelu', 'prelu'].
    inplace: bool
        Whether to use the in-place variant of the operator.
    neg_slope: float
        Slope of the negative region for `lrelu`, or `init` for `prelu`.
    n_prelu: int
        `num_parameters` for `prelu`.
    """
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'lrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError(
            'activation layer [{:s}] is not found'.format(act_type))
    return layer
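
# Usage sketch (illustrative; values are assumptions):
#
#     act = activation('lrelu', neg_slope=0.1)
#     act(torch.tensor([-1.0, 2.0]))  # -> tensor([-0.1000, 2.0000])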
|
|
|
|
def sequential(*args):
    """
    Modules will be added to a Sequential Container in the order they
    are passed; nested nn.Sequential containers are flattened.

    Parameters
    ----------
    args: Definition of Modules in order.
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError(
                'sequential does not support OrderedDict input.')
        return args[0]
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            # Unpack nested containers so the result is one flat Sequential.
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)
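
# Usage sketch (illustrative; layer sizes are assumptions): nested containers
# come out flattened.
#
#     block = sequential(nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),
#                        nn.Conv2d(8, 3, 3))
#     len(block)  # -> 3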
|
|
|
|
def pixelshuffle_block(in_channels,
                       out_channels,
                       upscale_factor=2,
                       kernel_size=3):
    """
    Upsample features according to `upscale_factor`.
    """
    conv = conv_layer(in_channels,
                      out_channels * (upscale_factor ** 2),
                      kernel_size)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)
    return sequential(conv, pixel_shuffle)
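
# Usage sketch (illustrative; shapes are assumptions). This is sub-pixel
# upsampling: the conv expands channels by upscale_factor**2, then PixelShuffle
# rearranges them into space, mapping (N, C, H, W) to (N, out_channels, 4H, 4W).
#
#     up = pixelshuffle_block(26, 3, upscale_factor=4)
#     up(torch.randn(1, 26, 64, 64)).shape  # -> torch.Size([1, 3, 256, 256])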
|
|
class Conv3XC(nn.Module):
    """
    Inference-time 3x3 convolution block. The training-time
    re-parameterization branches are not present in this variant, so
    `gain1`/`gain2` and the `weight_concat`/`bias_concat`/`update_params_flag`
    fields are inert and kept only for signature compatibility.
    """

    def __init__(self, c_in, c_out, gain1=1, gain2=0, s=1, bias=True, relu=False):
        super(Conv3XC, self).__init__()
        self.weight_concat = None
        self.bias_concat = None
        self.update_params_flag = False
        self.stride = s
        self.has_relu = relu

        self.eval_conv = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1, stride=s, bias=bias)

    def forward(self, x):
        out = self.eval_conv(x)
        if self.has_relu:
            out = F.leaky_relu(out, negative_slope=0.05)
        return out
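
# Usage sketch (illustrative; values are assumptions): a stride-1 Conv3XC
# preserves spatial size.
#
#     conv = Conv3XC(3, 26, gain1=2, s=1)
#     conv(torch.randn(1, 3, 64, 64)).shape  # -> torch.Size([1, 26, 64, 64])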
|
|
|
|
class SPAB(nn.Module):
    """Swift parameter-free attention block: three Conv3XC stages followed by
    a sigmoid-based attention map that adds no extra learnable weights."""

    def __init__(self,
                 in_channels,
                 mid_channels=None,
                 out_channels=None,
                 bias=False):
        super(SPAB, self).__init__()
        if mid_channels is None:
            mid_channels = in_channels
        if out_channels is None:
            out_channels = in_channels

        # `bias` is accepted for API compatibility but is not used below.
        self.in_channels = in_channels
        self.c1_r = Conv3XC(in_channels, mid_channels, gain1=2, s=1)
        self.c2_r = Conv3XC(mid_channels, mid_channels, gain1=2, s=1)
        self.c3_r = Conv3XC(mid_channels, out_channels, gain1=2, s=1)
        self.act1 = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        out1 = self.c1_r(x)
        out1_act = self.act1(out1)

        out2 = self.c2_r(out1_act)
        out2_act = self.act1(out2)

        out3 = self.c3_r(out2_act)

        # Parameter-free attention: sigmoid(out3) - 0.5 re-weights the
        # residual sum (out3 + x).
        sim_att = torch.sigmoid(out3) - 0.5
        out = (out3 + x) * sim_att

        return out, out1, out2, out3
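
# Usage sketch (illustrative; sizes are assumptions): SPAB returns the
# attended output plus the three intermediate conv outputs.
#
#     block = SPAB(26)
#     out, out1, out2, out3 = block(torch.randn(1, 26, 64, 64))
#     out.shape  # -> torch.Size([1, 26, 64, 64])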
|
|
|
|
class DSCF(nn.Module):
    """
    DSCF super-resolution network built from SPAN-style blocks (SPAB, from
    "Swift Parameter-free Attention Network for Efficient Super-Resolution").
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 feature_channels=26,
                 upscale=4,
                 bias=True,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)
                 ):
        super(DSCF, self).__init__()

        in_channels = num_in_ch
        out_channels = num_out_ch
        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_1 = Conv3XC(in_channels, feature_channels, gain1=2, s=1)
        self.block_1 = SPAB(feature_channels, bias=bias)
        self.block_2 = SPAB(feature_channels, bias=bias)
        self.block_3 = SPAB(feature_channels, bias=bias)
        self.block_4 = SPAB(feature_channels, bias=bias)
        self.block_5 = SPAB(feature_channels, bias=bias)
        self.block_6 = SPAB(feature_channels, bias=bias)

        self.conv_cat = conv_layer(feature_channels * 4, feature_channels, kernel_size=1, bias=True)
        self.conv_2 = Conv3XC(feature_channels, feature_channels, gain1=2, s=1)

        self.upsampler = pixelshuffle_block(feature_channels, out_channels, upscale_factor=upscale)

    def forward(self, x, return_features=False):
        # Normalize the input with the dataset RGB mean and image range.
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        out_feature = self.conv_1(x)

        out_b1, out_b1_1, out_b1_2, out_b1_3 = self.block_1(out_feature)
        out_b2, out_b2_1, out_b2_2, out_b2_3 = self.block_2(out_b1)
        out_b3, out_b3_1, out_b3_2, out_b3_3 = self.block_3(out_b2)

        out_b4, _, _, _ = self.block_4(out_b3)
        out_b5, _, _, _ = self.block_5(out_b4)
        out_b6, out_b5_2, _, _ = self.block_6(out_b5)

        out_b6 = self.conv_2(out_b6)
        out = self.conv_cat(torch.cat([out_feature, out_b6, out_b1, out_b5_2], 1))
        output = self.upsampler(out)

        if return_features:
            # Which tensors to expose as features is an assumption here: the
            # shallow feature and the output of each SPAB block.
            features = [out_feature, out_b1, out_b2, out_b3, out_b4, out_b5, out_b6]
            return output, features
        return output
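

if __name__ == "__main__":
    # Minimal CPU smoke test (illustrative): one forward pass through a x4
    # model should map a 256x256 input to a 1024x1024 output.
    model = DSCF(num_in_ch=3, num_out_ch=3, upscale=4).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 256, 256))
    print(out.shape)  # expected: torch.Size([1, 3, 1024, 1024])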