| from torch import nn |
| import torch |
| from method.MambaCSSM import MambaCSSM |
|
|
class MambaCSSMUnet(nn.Module):
    """U-Net style network with a ``MambaCSSM`` block at the bottleneck.

    ``forward(t1, t2)`` concatenates its two inputs along the channel
    dimension; the first convolution expects 6 channels in total (presumably
    two 3-channel images -- confirm against callers).

    Encoder: four stages of (3x3 Conv -> BatchNorm -> ReLU) x 2 with
    16, 32, 64 and 128 channels, each followed by a 2x2 max-pool that keeps
    its indices. Bottleneck: the pooled 128-channel map is split into two
    64-channel halves, each flattened to ``(batch, 64, h*w)`` and passed
    through ``MambaCSSM`` together. Decoder: each stage max-unpools with the
    stored encoder indices, concatenates the matching encoder feature map,
    applies a 3x3 Conv -> ReLU, then two stride-1 3x3 transposed convs.
    A final 1x1 convolution maps to ``output_classes`` channels with no
    activation.

    ``MambaCSSM`` is built with ``d_model=256``, so the pooled bottleneck
    must satisfy ``h * w == 256`` (e.g. a 256x256 input, where
    h = w = 256 / 16).

    Args:
        output_classes: number of output channels of the final 1x1 conv.
    """

    # Feature size handed to MambaCSSM; must equal h*w of the pooled bottleneck.
    _MAMBA_D_MODEL = 256

    def __init__(self, output_classes=2):
        super().__init__()

        # Encoder: 6 -> 16 -> 32 -> 64 -> 128 channels. Each pool halves H/W
        # and returns argmax indices for the matching MaxUnpool2d below.
        # (Module creation order matches the original, so parameter order and
        # seeded initialisation are unchanged.)
        self.conv_block_1 = self._double_conv_bn(6, 16)
        self.mp_block_1 = nn.MaxPool2d(2, 2, return_indices=True)
        self.conv_block_2 = self._double_conv_bn(16, 32)
        self.mp_block_2 = nn.MaxPool2d(2, 2, return_indices=True)
        self.conv_block_3 = self._double_conv_bn(32, 64)
        self.mp_block_3 = nn.MaxPool2d(2, 2, return_indices=True)
        self.conv_block_4 = self._double_conv_bn(64, 128)
        self.mp_block_4 = nn.MaxPool2d(2, 2, return_indices=True)

        self.mamba = MambaCSSM(
            num_layers=4, d_model=self._MAMBA_D_MODEL, d_conv=4, d_state=16
        )

        # Decoder: unpool -> fuse concatenated skip (2c -> c) -> two deconvs.
        self.mpu_block_4 = nn.MaxUnpool2d(2, 2)
        self.conv_4 = self._fuse_conv(256, 128)
        self.deconv_4_block = self._double_deconv(128, 64, 64)

        self.mpu_block_3 = nn.MaxUnpool2d(2, 2)
        self.conv_3 = self._fuse_conv(128, 64)
        self.deconv_3_block = self._double_deconv(64, 32, 32)

        self.mpu_block_2 = nn.MaxUnpool2d(2, 2)
        self.conv_2 = self._fuse_conv(64, 32)
        self.deconv_2_block = self._double_deconv(32, 16, 16)

        self.mpu_block_1 = nn.MaxUnpool2d(2, 2)
        self.conv_1 = self._fuse_conv(32, 16)
        self.deconv_1_block = self._double_deconv(16, 8, 6)

        self.conv_final = nn.Conv2d(6, output_classes, 1, 1)

    @staticmethod
    def _double_conv_bn(in_ch, out_ch):
        """Return (3x3 Conv -> BatchNorm -> ReLU) x 2; H and W are preserved."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, 1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    @staticmethod
    def _fuse_conv(in_ch, out_ch):
        """Return 3x3 Conv -> ReLU used to fuse an unpooled map with its skip."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, 1, padding=1),
            nn.ReLU(),
        )

    @staticmethod
    def _double_deconv(in_ch, mid_ch, out_ch):
        """Return (3x3 stride-1 ConvTranspose -> ReLU) x 2; H and W are preserved."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_ch, mid_ch, 3, 1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(mid_ch, out_ch, 3, 1, padding=1),
            nn.ReLU(),
        )

    def _mamba_bottleneck(self, f4):
        """Run MambaCSSM on the two channel halves of the pooled bottleneck.

        Args:
            f4: pooled encoder output of shape (B, C, h, w).

        Returns:
            Tensor of shape (B, C, h, w): the two MambaCSSM outputs reshaped
            back to maps and concatenated along channels.

        Raises:
            ValueError: if ``h * w`` differs from the ``d_model`` MambaCSSM
                was built with. Previously this surfaced as an opaque
                ``view``/unpool shape error.
        """
        b, c, h, w = f4.shape
        if h * w != self._MAMBA_D_MODEL:
            raise ValueError(
                f"MambaCSSMUnet bottleneck is {h}x{w} (h*w={h * w}), but MambaCSSM "
                f"was built with d_model={self._MAMBA_D_MODEL}; use inputs whose "
                f"(H // 16) * (W // 16) == {self._MAMBA_D_MODEL} (e.g. 256x256)."
            )
        half = c // 2
        # NOTE(review): conv_block_1 mixes all 6 input channels, so these two
        # halves are not literally "t1" and "t2" features, just channel halves.
        tokens_1 = f4[:, :half].reshape(b, half, h * w)
        tokens_2 = f4[:, half:].reshape(b, c - half, h * w)
        out_1, out_2 = self.mamba(tokens_1, tokens_2)
        # reshape (not view): tolerates non-contiguous MambaCSSM outputs and
        # is still a zero-copy view whenever view would have worked.
        out_1 = out_1.reshape(-1, half, h, w)
        out_2 = out_2.reshape(-1, c - half, h, w)
        return torch.cat([out_1, out_2], dim=1)

    def forward(self, t1, t2):
        """Predict a per-pixel output map from the pair ``(t1, t2)``.

        Args:
            t1, t2: 4-D tensors (B, C_i, H, W) whose channel counts sum to 6.

        Returns:
            Tensor of shape (B, output_classes, H, W) with raw (pre-activation)
            scores from the final 1x1 convolution.

        Raises:
            ValueError: if (H // 16) * (W // 16) != 256 (see ``_mamba_bottleneck``).
        """
        t = torch.cat([t1, t2], dim=1)

        x1 = self.conv_block_1(t)
        f1, i1 = self.mp_block_1(x1)
        x2 = self.conv_block_2(f1)
        f2, i2 = self.mp_block_2(x2)
        x3 = self.conv_block_3(f2)
        f3, i3 = self.mp_block_3(x3)
        x4 = self.conv_block_4(f3)
        f4, i4 = self.mp_block_4(x4)

        f5 = self._mamba_bottleneck(f4)

        # Every unpool pins output_size to its skip tensor. Without it, odd
        # H/W unpool one pixel short and the following torch.cat fails.
        f6 = self.mpu_block_4(f5, i4, output_size=x4.size())
        # Skip goes FIRST here but SECOND in the stages below. conv_4's trained
        # weights depend on this channel order, so keep it for checkpoints.
        f7 = self.conv_4(torch.cat((x4, f6), dim=1))
        f8 = self.deconv_4_block(f7)

        f9 = self.mpu_block_3(f8, i3, output_size=x3.size())
        f10 = self.conv_3(torch.cat((f9, x3), dim=1))
        f11 = self.deconv_3_block(f10)

        f12 = self.mpu_block_2(f11, i2, output_size=x2.size())
        f13 = self.conv_2(torch.cat((f12, x2), dim=1))
        f14 = self.deconv_2_block(f13)

        f15 = self.mpu_block_1(f14, i1, output_size=x1.size())
        f16 = self.conv_1(torch.cat((f15, x1), dim=1))
        f17 = self.deconv_1_block(f16)
        return self.conv_final(f17)