| import os |
| import numpy as np |
| import glob |
| import math |
| import torch |
| import torchvision |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn import CrossEntropyLoss, Linear, MSELoss |
| from torch.nn import ConvTranspose2d, Conv2d, MaxPool2d, BatchNorm2d |
| |
| import torchvision.models as models |
| from torchvision import datasets, transforms |
| from torchvision.io import read_image |
| from torch.utils.data import DataLoader, Dataset |
| import torch.optim as optim |
| from torch.autograd import Variable |
| from torchsummary import summary |
class Nothing(nn.Module):
    """Identity baseline: returns both inputs untouched."""

    def __init__(self):
        super(Nothing, self).__init__()

    def forward(self, radar, satellite):
        """Pass radar and satellite straight through unchanged."""
        return radar, satellite
|
|
class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> batch-norm -> ReLU stages at constant spatial size."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        layers = []
        channels = in_channels
        # First stage maps in_channels -> out_channels, second refines in place.
        for _ in range(2):
            layers.append(nn.Conv2d(channels, out_channels, kernel_size=3,
                                    stride=1, padding='same', bias=True))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
            channels = out_channels
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages and return the feature map."""
        return self.conv(x)
|
|
class UpConv(nn.Module):
    """Double the spatial resolution, then refine with a 3x3 conv -> BN -> ReLU."""

    def __init__(self, in_channels, out_channels):
        super(UpConv, self).__init__()
        steps = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      stride=1, padding='same', bias=True),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.up = nn.Sequential(*steps)

    def forward(self, x):
        """Return x upsampled by 2x and projected to out_channels."""
        return self.up(x)
class AttentionBlock(nn.Module):
    """Additive attention gate with learnable parameters.

    Projects the gating signal and the skip connection into a shared
    coefficient space, turns their sum into a per-pixel weight in (0, 1)
    via a 1x1 conv + sigmoid, and scales the skip connection by it.
    """

    def __init__(self, F_g, F_l, n_coefficients):
        """
        :param F_g: number of feature maps (channels) in previous layer
        :param F_l: number of feature maps in corresponding encoder layer, transferred via skip connection
        :param n_coefficients: number of learnable multi-dimensional attention coefficients
        """
        super(AttentionBlock, self).__init__()

        def project(channels_in, channels_out):
            # 1x1 conv + batch-norm projection shared by both branches.
            return nn.Sequential(
                nn.Conv2d(channels_in, channels_out, kernel_size=1,
                          stride=1, padding=0, bias=True),
                nn.BatchNorm2d(channels_out),
            )

        self.W_gate = project(F_g, n_coefficients)
        self.W_x = project(F_l, n_coefficients)
        self.psi = nn.Sequential(
            nn.Conv2d(n_coefficients, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate, skip_connection):
        """
        :param gate: gating signal from previous layer
        :param skip_connection: activation from corresponding encoder layer
        :return: skip_connection weighted by the computed attention map
        """
        projected = self.relu(self.W_gate(gate) + self.W_x(skip_connection))
        attention = self.psi(projected)
        return skip_connection * attention
|
|
class Recurrent_block(nn.Module):
    """Recurrent conv unit: one conv stage applied repeatedly with the input
    added back in before every pass (as in R2U-Net).

    NOTE(review): on the first loop step the stage runs twice (once on x,
    once on x + conv(x)), so t=2 performs three convolutions in total.
    This quirk matches the reference R2U-Net implementation and is
    deliberately preserved. Requires t >= 1.
    """

    def __init__(self, ch_out, t=2):
        super(Recurrent_block, self).__init__()
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run t recurrent passes of the conv stage and return the final state."""
        for step in range(self.t):
            if step == 0:
                state = self.conv(x)
            state = self.conv(x + state)
        return state
|
|
class RRCNN_block(nn.Module):
    """Recurrent residual conv block: a 1x1 channel projection followed by two
    recurrent conv units, with the projection added back as a residual."""

    def __init__(self, ch_in, ch_out, t=2):
        super(RRCNN_block, self).__init__()
        self.RCNN = nn.Sequential(
            Recurrent_block(ch_out, t=t),
            Recurrent_block(ch_out, t=t),
        )
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding='same')

    def forward(self, x):
        """Project x to ch_out channels, then return projection + recurrent refinement."""
        projected = self.Conv_1x1(x)
        refined = self.RCNN(projected)
        return projected + refined
|
|
|
|
class single_conv(nn.Module):
    """Single 3x3 conv -> batch-norm -> ReLU stage at constant spatial size."""

    def __init__(self, ch_in, ch_out):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding='same', bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the activated feature map for x."""
        return self.conv(x)
class Unet(nn.Module):
    """U-Net that fuses a high-resolution radar image with a low-resolution
    satellite image at the bottleneck and predicts both modalities.

    The radar branch is pooled log2(rad_size/sat_size) times so it meets the
    satellite input at the bottleneck resolution; the decoder mirrors the
    encoder with skip connections.

    :param rad_channel: channels of the radar input (also the width base)
    :param sat_channel: channels of the satellite input
    :param rad_size: square spatial size of the radar input
    :param sat_size: square spatial size of the satellite input
    """

    def __init__(self, rad_channel=1, sat_channel=1, rad_size=640, sat_size=20):
        super(Unet, self).__init__()
        assert rad_size % sat_size == 0, "rad_size must be divisible by sat_size"
        ratio = rad_size // sat_size
        assert (ratio & (ratio - 1)) == 0, "rad_size/sat_size must be a power of 2"
        # BUGFIX: with ratio == 1 the encoder loop never ran and `out_c`
        # below was unbound (NameError); fail fast with a clear message.
        assert ratio >= 2, "rad_size must be strictly larger than sat_size"
        self.n_pool = int(math.log2(ratio))

        # Encoder: each stage doubles the channels; one 2x pool per stage.
        self.encoder_blocks = nn.ModuleList()
        self.pools = nn.ModuleList()
        for i in range(self.n_pool):
            in_c = rad_channel * (2 ** i)
            out_c = rad_channel * (2 ** (i + 1))
            self.encoder_blocks.append(ConvBlock(in_c, out_c))
            # BUGFIX: the original guarded this with `if i < self.n_pool`,
            # which is always true inside this loop — dead condition removed.
            self.pools.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # Bottleneck: project both modalities to the same width and merge.
        self.mid_conv_1 = single_conv(out_c, out_c)
        self.mid_conv_2 = single_conv(sat_channel, out_c)
        self.mid_merge = ConvBlock(2 * out_c, out_c)

        # Decoder: mirrors the encoder; each stage halves the channels.
        self.up_convs = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        for i in reversed(range(self.n_pool)):
            up_in = rad_channel * (2 ** (i + 2))
            up_out = rad_channel * (2 ** (i + 1))
            self.up_convs.append(UpConv(up_in, up_out))
            self.decoder_blocks.append(ConvBlock(up_in, up_out))
        self.final_decoder = ConvBlock(4 * rad_channel, 2 * rad_channel)
        self.out_conv_R = nn.Conv2d(2 * rad_channel, rad_channel, kernel_size=1, padding='same')
        self.out_conv_S = nn.Conv2d(out_c, sat_channel, kernel_size=1, padding='same')

    def forward(self, radar, satellite):
        """Return (pred_radar, pred_satellite).

        pred_radar is at radar resolution; pred_satellite at satellite
        (bottleneck) resolution.
        """
        # Encoder: keep each stage's pre-pool output as a skip connection.
        # (The original's `if i < self.n_pool` pool guard was always true.)
        enc_feats = []
        x = radar
        for block, pool in zip(self.encoder_blocks, self.pools):
            x = block(x)
            enc_feats.append(x)
            x = pool(x)

        # Bottleneck fusion. single_conv already ends in ReLU, so the
        # original's extra F.relu was an identity op and is dropped.
        x = self.mid_conv_1(x)
        y = self.mid_conv_2(satellite)
        fused = torch.cat((x, y), dim=1)

        # Satellite head taps the merged bottleneck features.
        mid_out = self.mid_merge(fused)
        pred_sat = self.out_conv_S(mid_out)

        # Decoder consumes the raw fused features (not mid_out), matching the
        # bottleneck wiring of the other networks in this file.
        # (The original's no-op `x = x` was removed.)
        x = fused
        for i in range(self.n_pool):
            x = self.up_convs[i](x)
            x = torch.cat((enc_feats[self.n_pool - 1 - i], x), dim=1)
            x = self.decoder_blocks[i](x)
        # NOTE(review): enc_feats[0] is concatenated a second time here (it was
        # already consumed by the last loop iteration); preserved as designed.
        x = torch.cat((enc_feats[0], x), dim=1)
        x = self.final_decoder(x)
        pred_rad = self.out_conv_R(x)
        return pred_rad, pred_sat
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
class R2Unet(nn.Module):
    """Recurrent residual U-Net (R2U-Net) fusing radar and satellite data.

    The radar branch is encoded through five RRCNN stages (each followed by
    2x max-pooling), fused with the projected satellite input at the
    bottleneck, and decoded with skip connections. Two heads are produced:
    a radar prediction at input resolution and a satellite prediction at
    bottleneck resolution.

    NOTE(review): the radar input is hard-coded to 5 channels and the
    satellite input to 2 channels; num_channel only scales internal width —
    confirm against the data pipeline.

    :param num_channel: width multiplier for all internal feature maps
    :param t: recurrence depth of every RRCNN block
    """

    def __init__(self, num_channel=1, t=2):
        super(R2Unet, self).__init__()
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder (radar branch, 5 input channels assumed).
        self.RRCNN1 = RRCNN_block(5, 2*num_channel, t=t)
        self.RRCNN2 = RRCNN_block(2*num_channel, 4*num_channel, t=t)
        self.RRCNN3 = RRCNN_block(4*num_channel, 8*num_channel, t=t)
        self.RRCNN4 = RRCNN_block(8*num_channel, 16*num_channel, t=t)
        self.RRCNN5 = RRCNN_block(16*num_channel, 32*num_channel, t=t)
        # Bottleneck fusion of the two modalities (2 satellite channels assumed).
        self.mid_conv_1 = single_conv(32*num_channel, 32*num_channel)
        self.mid_conv_2 = single_conv(2, 32*num_channel)
        # BUGFIX: t was not forwarded to the bottleneck and decoder RRCNN
        # blocks, so they silently ignored the constructor's recurrence depth.
        self.MidConv = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding='same')
        # Decoder with skip connections.
        self.Up5 = UpConv(64*num_channel, 32*num_channel)
        self.UpRRCNN5 = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.Up4 = UpConv(32*num_channel, 16*num_channel)
        self.UpRRCNN4 = RRCNN_block(32*num_channel, 16*num_channel, t=t)
        self.Up3 = UpConv(16*num_channel, 8*num_channel)
        self.UpRRCNN3 = RRCNN_block(16*num_channel, 8*num_channel, t=t)
        self.Up2 = UpConv(8*num_channel, 4*num_channel)
        self.UpRRCNN2 = RRCNN_block(8*num_channel, 4*num_channel, t=t)
        self.Up1 = UpConv(4*num_channel, 2*num_channel)
        self.UpRRCNN1 = RRCNN_block(4*num_channel, 2*num_channel, t=t)
        self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        """Return (pred_radar, pred_satellite)."""
        # Encoder: keep each stage's output as a decoder skip connection.
        e1 = self.RRCNN1(radar)
        e2 = self.RRCNN2(self.MaxPool(e1))
        e3 = self.RRCNN3(self.MaxPool(e2))
        e4 = self.RRCNN4(self.MaxPool(e3))
        e5 = self.RRCNN5(self.MaxPool(e4))
        e6 = self.MaxPool(e5)
        # Bottleneck: project both modalities and concatenate along channels.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        # Satellite head taps the merged bottleneck features.
        pred_satellite = self.out_conv_S(self.MidConv(X))
        # Decoder: upsample, concatenate the matching encoder feature, refine.
        d5 = self.UpRRCNN5(torch.cat((e5, self.Up5(X)), dim=1))
        d4 = self.UpRRCNN4(torch.cat((e4, self.Up4(d5)), dim=1))
        d3 = self.UpRRCNN3(torch.cat((e3, self.Up3(d4)), dim=1))
        d2 = self.UpRRCNN2(torch.cat((e2, self.Up2(d3)), dim=1))
        d0 = self.UpRRCNN1(torch.cat((e1, self.Up1(d2)), dim=1))
        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite
class AttUnet(nn.Module):
    """Attention U-Net fusing a radar input (5 channels assumed) with a
    satellite input (2 channels assumed); num_channel scales internal width.

    Produces a radar prediction at input resolution and a satellite
    prediction at bottleneck resolution.
    """

    def __init__(self, num_channel=1):
        super(AttUnet, self).__init__()
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder stages (radar branch); each is followed by a 2x max-pool.
        self.Conv1 = ConvBlock(5, 2 * num_channel)
        self.Conv2 = ConvBlock(2 * num_channel, 4 * num_channel)
        self.Conv3 = ConvBlock(4 * num_channel, 8 * num_channel)
        self.Conv4 = ConvBlock(8 * num_channel, 16 * num_channel)
        self.Conv5 = ConvBlock(16 * num_channel, 32 * num_channel)
        # Bottleneck fusion of radar and satellite features.
        self.mid_conv_1 = single_conv(32 * num_channel, 32 * num_channel)
        self.mid_conv_2 = single_conv(2, 32 * num_channel)
        self.MidConv = ConvBlock(64 * num_channel, 32 * num_channel)
        self.out_conv_S = Conv2d(32 * num_channel, 1, (1, 1), padding='same')
        # Decoder stages: upsample, attention-gate the skip, then refine.
        self.Up5 = UpConv(64 * num_channel, 32 * num_channel)
        self.Att5 = AttentionBlock(F_g=32 * num_channel, F_l=32 * num_channel, n_coefficients=16 * num_channel)
        self.UpConv5 = ConvBlock(64 * num_channel, 32 * num_channel)
        self.Up4 = UpConv(32 * num_channel, 16 * num_channel)
        self.Att4 = AttentionBlock(F_g=16 * num_channel, F_l=16 * num_channel, n_coefficients=8 * num_channel)
        self.UpConv4 = ConvBlock(32 * num_channel, 16 * num_channel)
        self.Up3 = UpConv(16 * num_channel, 8 * num_channel)
        self.Att3 = AttentionBlock(F_g=8 * num_channel, F_l=8 * num_channel, n_coefficients=4 * num_channel)
        self.UpConv3 = ConvBlock(16 * num_channel, 8 * num_channel)
        self.Up2 = UpConv(8 * num_channel, 4 * num_channel)
        self.Att2 = AttentionBlock(F_g=4 * num_channel, F_l=4 * num_channel, n_coefficients=2 * num_channel)
        self.UpConv2 = ConvBlock(8 * num_channel, 4 * num_channel)
        self.Up1 = UpConv(4 * num_channel, 2 * num_channel)
        self.Att1 = AttentionBlock(F_g=2 * num_channel, F_l=2 * num_channel, n_coefficients=1 * num_channel)
        self.UpConv1 = ConvBlock(4 * num_channel, 2 * num_channel)
        self.out_conv_R = Conv2d(2 * num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        """Return (pred_radar, pred_satellite)."""
        # Encoder: collect pre-pooling features as skip connections.
        skips = []
        feat = radar
        for encode in (self.Conv1, self.Conv2, self.Conv3, self.Conv4, self.Conv5):
            feat = encode(feat)
            skips.append(feat)
            feat = self.MaxPool(feat)
        # Bottleneck: project both modalities and concatenate along channels.
        radar_feat = F.relu(self.mid_conv_1(feat))
        sat_feat = F.relu(self.mid_conv_2(satellite))
        fused = torch.cat((radar_feat, sat_feat), 1)
        # Satellite head taps the merged bottleneck features.
        pred_satellite = self.out_conv_S(self.MidConv(fused))
        # Decoder: deepest skip first; each stage upsamples, attention-gates
        # the matching skip, concatenates, then refines.
        stages = (
            (self.Up5, self.Att5, self.UpConv5),
            (self.Up4, self.Att4, self.UpConv4),
            (self.Up3, self.Att3, self.UpConv3),
            (self.Up2, self.Att2, self.UpConv2),
            (self.Up1, self.Att1, self.UpConv1),
        )
        d = fused
        for (up, attend, refine), skip in zip(stages, reversed(skips)):
            d = up(d)
            gated = attend(gate=d, skip_connection=skip)
            d = refine(torch.cat((gated, d), dim=1))
        pred_radar = self.out_conv_R(d)
        return pred_radar, pred_satellite
class AttR2Unet(nn.Module):
    """Attention R2U-Net: recurrent residual blocks in the encoder/decoder
    combined with attention-gated skip connections, fusing a radar input
    (5 channels assumed) with a satellite input (2 channels assumed).

    :param num_channel: width multiplier for all internal feature maps
    :param t: recurrence depth of every RRCNN block
    """

    def __init__(self, num_channel=1, t=2):
        super(AttR2Unet, self).__init__()
        self.MaxPool = nn.MaxPool2d(kernel_size=2, stride=2)
        # BUGFIX: the t argument was accepted but never forwarded, so every
        # RRCNN block silently fell back to its default recurrence depth;
        # propagate t throughout.
        self.RRCNN1 = RRCNN_block(5, 2*num_channel, t=t)
        self.RRCNN2 = RRCNN_block(2*num_channel, 4*num_channel, t=t)
        self.RRCNN3 = RRCNN_block(4*num_channel, 8*num_channel, t=t)
        self.RRCNN4 = RRCNN_block(8*num_channel, 16*num_channel, t=t)
        self.RRCNN5 = RRCNN_block(16*num_channel, 32*num_channel, t=t)
        # Bottleneck fusion of the two modalities.
        self.mid_conv_1 = single_conv(32*num_channel, 32*num_channel)
        self.mid_conv_2 = single_conv(2, 32*num_channel)
        self.MidConv = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.out_conv_S = Conv2d(32*num_channel, 1, (1, 1), padding='same')
        # Decoder: upsample, attention-gate the skip, then recurrent refine.
        self.Up5 = UpConv(64*num_channel, 32*num_channel)
        self.Att5 = AttentionBlock(F_g=32*num_channel, F_l=32*num_channel, n_coefficients=16*num_channel)
        self.UpRRCNN5 = RRCNN_block(64*num_channel, 32*num_channel, t=t)
        self.Up4 = UpConv(32*num_channel, 16*num_channel)
        self.Att4 = AttentionBlock(F_g=16*num_channel, F_l=16*num_channel, n_coefficients=8*num_channel)
        self.UpRRCNN4 = RRCNN_block(32*num_channel, 16*num_channel, t=t)
        self.Up3 = UpConv(16*num_channel, 8*num_channel)
        self.Att3 = AttentionBlock(F_g=8*num_channel, F_l=8*num_channel, n_coefficients=4*num_channel)
        self.UpRRCNN3 = RRCNN_block(16*num_channel, 8*num_channel, t=t)
        self.Up2 = UpConv(8*num_channel, 4*num_channel)
        self.Att2 = AttentionBlock(F_g=4*num_channel, F_l=4*num_channel, n_coefficients=2*num_channel)
        self.UpRRCNN2 = RRCNN_block(8*num_channel, 4*num_channel, t=t)
        self.Up1 = UpConv(4*num_channel, 2*num_channel)
        self.Att1 = AttentionBlock(F_g=2*num_channel, F_l=2*num_channel, n_coefficients=1*num_channel)
        self.UpRRCNN1 = RRCNN_block(4*num_channel, 2*num_channel, t=t)
        self.out_conv_R = Conv2d(2*num_channel, 1, (1, 1), padding='same')

    def forward(self, radar, satellite):
        """Return (pred_radar, pred_satellite)."""
        # Encoder: keep each stage's output as a decoder skip connection.
        e1 = self.RRCNN1(radar)
        e2 = self.RRCNN2(self.MaxPool(e1))
        e3 = self.RRCNN3(self.MaxPool(e2))
        e4 = self.RRCNN4(self.MaxPool(e3))
        e5 = self.RRCNN5(self.MaxPool(e4))
        e6 = self.MaxPool(e5)
        # Bottleneck: project both modalities and concatenate along channels.
        X = F.relu(self.mid_conv_1(e6))
        Y = F.relu(self.mid_conv_2(satellite))
        X = torch.cat((X, Y), 1)
        # Satellite head taps the merged bottleneck features.
        pred_satellite = self.out_conv_S(self.MidConv(X))
        # Decoder: upsample, attention-gate the skip, concatenate, refine.
        d5 = self.Up5(X)
        s4 = self.Att5(gate=d5, skip_connection=e5)
        d5 = self.UpRRCNN5(torch.cat((s4, d5), dim=1))
        d4 = self.Up4(d5)
        s3 = self.Att4(gate=d4, skip_connection=e4)
        d4 = self.UpRRCNN4(torch.cat((s3, d4), dim=1))
        d3 = self.Up3(d4)
        s2 = self.Att3(gate=d3, skip_connection=e3)
        d3 = self.UpRRCNN3(torch.cat((s2, d3), dim=1))
        d2 = self.Up2(d3)
        s1 = self.Att2(gate=d2, skip_connection=e2)
        d2 = self.UpRRCNN2(torch.cat((s1, d2), dim=1))
        d1 = self.Up1(d2)
        s0 = self.Att1(gate=d1, skip_connection=e1)
        d0 = self.UpRRCNN1(torch.cat((s0, d1), dim=1))
        pred_radar = self.out_conv_R(d0)
        return pred_radar, pred_satellite
class Network(nn.Module):
    """Factory wrapper that builds one of the fusion networks by name.

    :param model_type: one of "Nothing", "Unet", "R2Unet", "AttUnet", "AttR2Unet"
    :param rad_channel: radar channel count (Unet) / width multiplier (other nets)
    :param sat_channel: satellite channel count (only used by Unet)
    :param rad_size: radar spatial size (only used by Unet)
    :param sat_size: satellite spatial size (only used by Unet)
    :raises ValueError: if model_type is not one of the names above
    """

    def __init__(self, model_type: str, rad_channel: int, sat_channel: int, rad_size: int, sat_size: int):
        super(Network, self).__init__()
        print(model_type)
        if model_type == "Nothing":
            self.net = Nothing()
        elif model_type == "Unet":
            # BUGFIX: the original listed this branch twice; the duplicate
            # was unreachable and has been removed.
            self.net = Unet(rad_channel=rad_channel, sat_channel=sat_channel,
                            rad_size=rad_size, sat_size=sat_size)
        elif model_type == "R2Unet":
            # BUGFIX: R2Unet/AttUnet/AttR2Unet only accept num_channel (a
            # width multiplier); the original passed rad_channel/sat_channel/
            # rad_size/sat_size, which raised TypeError on construction.
            # rad_channel is mapped to num_channel here — NOTE(review):
            # confirm this matches the callers' intent.
            self.net = R2Unet(num_channel=rad_channel)
        elif model_type == "AttUnet":
            self.net = AttUnet(num_channel=rad_channel)
        elif model_type == "AttR2Unet":
            self.net = AttR2Unet(num_channel=rad_channel)
        else:
            raise ValueError("model_type is wrong")

    def forward(self, radar, satellite):
        """Delegate to the wrapped network; return (pred_radar, pred_satellite)."""
        # Call the module itself rather than .forward() so that registered
        # hooks are honoured.
        pred_radar, pred_satellite = self.net(radar, satellite)
        return pred_radar, pred_satellite