| import torch.nn as nn
|
|
|
| from torch.nn import functional as F
|
| from timm import create_model
|
|
|
|
|
| __all__ = ['NoiseTransformer']
|
|
|
class NoiseTransformer(nn.Module):
    """Pass a 4-channel map through a Swin-Tiny backbone and back to 4 channels.

    Pipeline applied by :meth:`forward`:

    1. resize the input to 224x224 (``F.interpolate`` default mode),
    2. project 4 -> 3 channels with a 1x1 convolution (``downconv``),
    3. run ``swin.forward_features``,
    4. resize the last two dimensions of the features to
       ``resolution x resolution``,
    5. project 7 -> 4 channels with a 1x1 convolution (``upconv``).

    NOTE(review): step 5 requires the backbone features to be 4-D with size 7
    along dimension 1. This presumably relies on timm returning a
    channels-last ``(B, 7, 7, C)`` map for this model, so that dimension 1 is
    treated as the channel axis from step 4 onwards -- confirm against the
    installed timm version.

    Args:
        resolution: Side length of the square output produced by
            ``downsample`` (and therefore by ``forward``).
    """

    # Spatial size the input is resized to before entering the backbone
    # (the model is created from the "..._224" timm configuration).
    _BACKBONE_SIZE = 224

    def __init__(self, resolution=128):
        super().__init__()
        # Kept as plain data so the resize helpers can be ordinary methods
        # instead of lambdas stored on the instance; lambdas cannot be
        # pickled, which breaks torch.save(model) and spawn-based workers.
        self.resolution = resolution

        # Registration order (upconv, downconv, swin) is kept as in the
        # original so parameter initialisation order and state_dict layout
        # do not change.
        self.upconv = nn.Conv2d(7, 4, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
        self.downconv = nn.Conv2d(4, 3, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))

        # Requests pretrained weights from timm.
        self.swin = create_model("swin_tiny_patch4_window7_224", pretrained=True)

    def upsample(self, x):
        """Resize the last two dimensions of ``x`` to the backbone input size."""
        return F.interpolate(x, [self._BACKBONE_SIZE, self._BACKBONE_SIZE])

    def downsample(self, x):
        """Resize the last two dimensions of ``x`` to ``resolution x resolution``."""
        return F.interpolate(x, [self.resolution, self.resolution])

    def forward(self, x, residual=False):
        """Run the resize -> backbone -> resize pipeline on ``x``.

        Args:
            x: 4-D tensor with 4 channels (required by ``downconv``). When
                ``residual`` is true it must also be broadcastable against the
                ``(B, 4, resolution, resolution)`` output.
            residual: If true, add the original input to the transformed
                output (skip connection).

        Returns:
            The transformed tensor, plus ``x`` when ``residual`` is set.
        """
        out = self.upsample(x)
        out = self.downconv(out)
        out = self.swin.forward_features(out)
        out = self.downsample(out)
        out = self.upconv(out)
        if residual:
            out = out + x
        return out