Spaces:
Sleeping
Sleeping
| import torch | |
| from torch import nn | |
| from .build_contextpath import build_contextpath | |
| import warnings | |
| warnings.filterwarnings(action='ignore') | |
class ConvBlock(torch.nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation and ReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: convolution kernel size (default 3).
        stride: convolution stride (default 2).
        padding: zero padding applied by the convolution (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        super().__init__()
        # The convolution carries no bias term; BatchNorm2d follows it directly.
        conv_options = dict(kernel_size=kernel_size, stride=stride,
                            padding=padding, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, input):
        """Return ``relu(bn(conv1(input)))``."""
        normalised = self.bn(self.conv1(input))
        return self.relu(normalised)
class Spatial_path(torch.nn.Module):
    """Stack of three ConvBlocks taking 3 input channels to 256 channels.

    Every stage uses ConvBlock's default kernel/stride/padding, so the
    input passes through three stride-2 convolutions in sequence.
    """

    def __init__(self):
        super().__init__()
        # Channel progression: 3 -> 64 -> 128 -> 256.
        self.convblock1 = ConvBlock(in_channels=3, out_channels=64)
        self.convblock2 = ConvBlock(in_channels=64, out_channels=128)
        self.convblock3 = ConvBlock(in_channels=128, out_channels=256)

    def forward(self, input):
        """Apply the three blocks in order and return the final feature map."""
        features = input
        for stage in (self.convblock1, self.convblock2, self.convblock3):
            features = stage(features)
        return features
class AttentionRefinementModule(torch.nn.Module):
    """Re-weight a feature map channel-wise with a globally pooled gate.

    The input is average-pooled to 1x1, passed through a 1x1 convolution,
    batch normalisation and a sigmoid, and the resulting per-channel gate
    is multiplied back onto the input.

    Args:
        in_channels: channel count the input is required to have.
        out_channels: channel count of the gate; it has to be compatible
            with the input's channels for the final multiplication.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
        self.in_channels = in_channels
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input):
        """Return ``input`` scaled by its channel attention gate.

        Raises:
            AssertionError: if the input's channel count differs from the
                ``in_channels`` given at construction time.
        """
        # Global average pooling down to a 1x1 map per channel.
        pooled = self.avgpool(input)
        channels = pooled.size(1)
        assert self.in_channels == channels, \
            'in_channels and out_channels should all be {}'.format(channels)
        gate = self.sigmoid(self.bn(self.conv(pooled)))
        # The gate broadcasts over the spatial dimensions of the input.
        return torch.mul(input, gate)
class FeatureFusionModule(torch.nn.Module):
    """Fuse two feature maps by concatenation and channel re-weighting.

    The inputs are concatenated along the channel axis and reduced to
    ``num_classes`` channels by a stride-1 ConvBlock. A gate computed from
    the globally pooled result (1x1 conv, ReLU, 1x1 conv, sigmoid) scales
    that feature, and the unscaled feature is added back.

    Args:
        num_classes: channel count of the fused output.
        in_channels: total channels of the two concatenated inputs, e.g.
            resnet101: 3328 = 256 (spatial path) + 1024 + 2048 (context path)
            resnet18:  1024 = 256 (spatial path) + 256 + 512 (context path)
    """

    def __init__(self, num_classes, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.convblock = ConvBlock(in_channels=self.in_channels,
                                   out_channels=num_classes, stride=1)
        self.conv1 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(num_classes, num_classes, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

    def forward(self, input_1, input_2):
        """Return the fused feature map of ``input_1`` and ``input_2``.

        Raises:
            AssertionError: if the concatenated channel count differs from
                the ``in_channels`` given at construction time.
        """
        stacked = torch.cat((input_1, input_2), dim=1)
        stacked_channels = stacked.size(1)
        assert self.in_channels == stacked_channels, \
            'in_channels of ConvBlock should be {}'.format(stacked_channels)
        feature = self.convblock(stacked)

        # Channel gate derived from the globally pooled fused feature.
        gate = self.avgpool(feature)
        gate = self.relu(self.conv1(gate))
        gate = self.sigmoid(self.conv2(gate))

        # Gated feature plus the residual (ungated) feature.
        return torch.add(torch.mul(feature, gate), feature)
class BiSeNet(torch.nn.Module):
    """Two-branch segmentation network: spatial path + context path + fusion.

    The spatial path and the context path both consume the input image.
    The two context features are refined by attention modules, upsampled
    to the spatial path's resolution and fused with the spatial feature.
    The fused map is upsampled to the input resolution and passed through
    a final 1x1 convolution.

    Args:
        num_classes: number of channels of every prediction map.
        context_path: backbone name forwarded to ``build_contextpath``;
            must be ``'resnet101'`` or ``'resnet18'``.

    Raises:
        ValueError: if ``context_path`` is not a supported backbone name.
    """

    # Channels produced by Spatial_path (its last ConvBlock outputs 256).
    _SPATIAL_CHANNELS = 256
    # Backbone name -> channels of the two context features (cx1, cx2).
    _CONTEXT_CHANNELS = {
        'resnet101': (1024, 2048),
        'resnet18': (256, 512),
    }

    def __init__(self, num_classes, context_path):
        super().__init__()
        # Fail fast with a clear error instead of printing a message and
        # crashing later on a missing attribute.
        if context_path not in self._CONTEXT_CHANNELS:
            raise ValueError(
                'unsupported context_path network {!r}; expected one of {}'.format(
                    context_path, sorted(self._CONTEXT_CHANNELS)))
        cx1_channels, cx2_channels = self._CONTEXT_CHANNELS[context_path]

        # NOTE: the attribute name is misspelled ("saptial") but is kept
        # because it is part of the parameter names in the state_dict.
        self.saptial_path = Spatial_path()
        self.context_path = build_contextpath(name=context_path)

        # Attention refinement for the two context features.
        self.attention_refinement_module1 = AttentionRefinementModule(cx1_channels, cx1_channels)
        self.attention_refinement_module2 = AttentionRefinementModule(cx2_channels, cx2_channels)

        # Auxiliary (supervision) heads, only evaluated in training mode.
        self.supervision1 = nn.Conv2d(in_channels=cx1_channels, out_channels=num_classes, kernel_size=1)
        self.supervision2 = nn.Conv2d(in_channels=cx2_channels, out_channels=num_classes, kernel_size=1)

        # Fusion input = spatial channels + both context features:
        # resnet101 -> 256 + 1024 + 2048 = 3328, resnet18 -> 256 + 256 + 512 = 1024.
        self.feature_fusion_module = FeatureFusionModule(
            num_classes, self._SPATIAL_CHANNELS + cx1_channels + cx2_channels)

        # Final 1x1 convolution producing the prediction.
        self.conv = nn.Conv2d(in_channels=num_classes, out_channels=num_classes, kernel_size=1)

        self.init_weight()

        # Every sub-module except the context path, collected in a plain list.
        # NOTE(review): presumably consumed by training code to apply a
        # different learning rate to these modules -- confirm against caller.
        self.mul_lr = [
            self.saptial_path,
            self.attention_refinement_module1,
            self.attention_refinement_module2,
            self.supervision1,
            self.supervision2,
            self.feature_fusion_module,
            self.conv,
        ]

    def init_weight(self):
        """Initialise every Conv2d/BatchNorm2d outside the context path.

        Modules whose qualified name contains ``'context_path'`` are left
        untouched. Conv2d weights get Kaiming-normal initialisation;
        BatchNorm2d gets eps=1e-5, momentum=0.1, weight=1 and bias=0.
        """
        for name, m in self.named_modules():
            if 'context_path' in name:
                continue
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.eps = 1e-5
                m.momentum = 0.1
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, input):
        """Run the network on ``input``.

        Returns:
            In training mode a tuple ``(result, cx1_sup, cx2_sup)`` of the
            main prediction and the two auxiliary predictions; otherwise
            only ``result``. All maps are resized to the spatial size of
            ``input``.
        """
        interpolate = torch.nn.functional.interpolate
        input_size = input.size()[-2:]

        # Output of the spatial path.
        sx = self.saptial_path(input)
        feature_size = sx.size()[-2:]

        # Outputs of the context path, refined by attention.
        cx1, cx2, tail = self.context_path(input)
        cx1 = self.attention_refinement_module1(cx1)
        cx2 = self.attention_refinement_module2(cx2)
        cx2 = torch.mul(cx2, tail)

        # Bring both context features to the spatial path's resolution.
        # align_corners=False is the default, stated explicitly.
        cx1 = interpolate(cx1, size=feature_size, mode='bilinear', align_corners=False)
        cx2 = interpolate(cx2, size=feature_size, mode='bilinear', align_corners=False)
        cx = torch.cat((cx1, cx2), dim=1)

        # Fuse, then upsample to the input resolution. Using the input size
        # rather than a fixed scale factor of 8 gives the same result when
        # height and width are multiples of 8, and keeps the main output
        # aligned with the input (and the auxiliary outputs) when they are not.
        result = self.feature_fusion_module(sx, cx)
        result = interpolate(result, size=input_size, mode='bilinear', align_corners=False)
        result = self.conv(result)

        if self.training:
            cx1_sup = interpolate(self.supervision1(cx1), size=input_size,
                                  mode='bilinear', align_corners=False)
            cx2_sup = interpolate(self.supervision2(cx2), size=input_size,
                                  mode='bilinear', align_corners=False)
            return result, cx1_sup, cx2_sup
        return result