| import pytest |
| import torch |
| from mmcv.cnn import ConvModule |
| from mmcv.utils.parrots_wrapper import _BatchNorm |
| from torch import nn |
|
|
| from mmseg.models.backbones.unet import (BasicConvBlock, DeconvModule, |
| InterpConv, UNet, UpConvBlock) |
|
|
|
|
def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state.

    Args:
        modules (Iterable[nn.Module]): Modules to inspect; anything that is
            not a ``_BatchNorm`` instance is ignored.
        train_state (bool): Expected value of ``training`` on every norm
            layer.

    Returns:
        bool: True when every ``_BatchNorm`` in ``modules`` has
        ``training == train_state`` (also True when there are none).
    """
    return all(layer.training == train_state
               for layer in modules
               if isinstance(layer, _BatchNorm))
|
|
|
|
def test_unet_basic_conv_block():
    """Check BasicConvBlock: rejected options, checkpointed and plain
    forward, stride-2 downsampling and per-conv dilation/padding."""
    # DCN and plugins are expected to be rejected with an AssertionError.
    rejected_kwargs = (
        dict(dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False)),
        dict(plugins=[
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]),
        dict(plugins=[
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]),
    )
    for extra_kwargs in rejected_kwargs:
        with pytest.raises(AssertionError):
            BasicConvBlock(64, 64, **extra_kwargs)

    # Forward with gradient checkpointing (input requires grad) and without.
    for use_cp in (True, False):
        block = BasicConvBlock(16, 16, with_cp=use_cp)
        assert bool(block.with_cp) is use_cp
        inputs = torch.randn(1, 16, 64, 64, requires_grad=use_cp)
        assert block(inputs).shape == torch.Size([1, 16, 64, 64])

    # stride=2 halves the spatial size.
    block = BasicConvBlock(16, 16, stride=2)
    inputs = torch.randn(1, 16, 64, 64)
    assert block(inputs).shape == torch.Size([1, 16, 32, 32])

    # With dilation=3 the first conv stays undilated; the remaining convs use
    # dilation 3, and padding always equals dilation for the 3x3 kernels.
    block = BasicConvBlock(16, 64, num_convs=3, dilation=3)
    expected = ((16, 64, 1), (64, 64, 3), (64, 64, 3))
    for idx, (in_ch, out_ch, dil) in enumerate(expected):
        conv = block.convs[idx].conv
        assert conv.in_channels == in_ch
        assert conv.out_channels == out_ch
        assert conv.kernel_size == (3, 3)
        assert conv.dilation == (dil, dil)
        assert conv.padding == (dil, dil)
|
|
|
|
def test_deconv_module():
    """Check DeconvModule: invalid kernel/scale pairs, checkpointed and plain
    forward, and output size for 2x and 4x upsampling."""
    # These (kernel_size, scale_factor) pairs are rejected. Presumably the
    # rule is kernel_size >= scale_factor with an even difference - confirm
    # against DeconvModule's assertion.
    for kernel_size, scale_factor in ((1, 2), (3, 2), (5, 4)):
        with pytest.raises(AssertionError):
            DeconvModule(
                64, 32, kernel_size=kernel_size, scale_factor=scale_factor)

    # Default settings upsample 128 -> 256, with and without checkpointing.
    for use_cp in (True, False):
        block = DeconvModule(64, 32, with_cp=use_cp)
        assert bool(block.with_cp) is use_cp
        inputs = torch.randn(1, 64, 128, 128, requires_grad=use_cp)
        assert block(inputs).shape == torch.Size([1, 32, 256, 256])

    # Accepted pairs scale a 64x64 input by exactly scale_factor.
    inputs = torch.randn(1, 64, 64, 64)
    for kernel_size, scale_factor in ((2, 2), (6, 2), (4, 4), (6, 4)):
        block = DeconvModule(
            64, 32, kernel_size=kernel_size, scale_factor=scale_factor)
        side = 64 * scale_factor
        assert block(inputs).shape == torch.Size([1, 32, side, side])
|
|
|
|
def test_interp_conv():
    """Check InterpConv: checkpointed and plain forward, conv/upsample
    ordering controlled by ``conv_first``, and the interpolation mode."""
    target_shape = torch.Size([1, 32, 256, 256])

    # Forward with gradient checkpointing (input requires grad) and without.
    for use_cp in (True, False):
        block = InterpConv(64, 32, with_cp=use_cp)
        assert bool(block.with_cp) is use_cp
        inputs = torch.randn(1, 64, 128, 128, requires_grad=use_cp)
        assert block(inputs).shape == target_shape

    # conv_first selects whether the ConvModule comes before the Upsample.
    for conv_first in (False, True):
        block = InterpConv(64, 32, conv_first=conv_first)
        out = block(torch.randn(1, 64, 128, 128))
        if conv_first:
            head_type, tail_type = ConvModule, nn.Upsample
        else:
            head_type, tail_type = nn.Upsample, ConvModule
        assert isinstance(block.interp_upsample[0], head_type)
        assert isinstance(block.interp_upsample[1], tail_type)
        assert out.shape == target_shape

    # The Upsample layer uses the mode given in the upsample config.
    # NOTE(review): ``upsampe_cfg`` (sic) is the keyword passed here; confirm
    # against the InterpConv signature before renaming it.
    for mode, extra in (('bilinear', dict(align_corners=False)),
                        ('nearest', dict())):
        block = InterpConv(
            64,
            32,
            conv_first=False,
            upsampe_cfg=dict(scale_factor=2, mode=mode, **extra))
        out = block(torch.randn(1, 64, 128, 128))
        assert isinstance(block.interp_upsample[0], nn.Upsample)
        assert isinstance(block.interp_upsample[1], ConvModule)
        assert out.shape == target_shape
        assert block.interp_upsample[0].mode == mode
|
|
|
|
def test_up_conv_block():
    """Check UpConvBlock: rejected options, checkpointed forward, several
    upsample configurations and per-conv channel/dilation settings."""
    # DCN and plugins are expected to be rejected with an AssertionError.
    rejected_kwargs = (
        dict(dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False)),
        dict(plugins=[
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ]),
        dict(plugins=[
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ]),
    )
    for extra_kwargs in rejected_kwargs:
        with pytest.raises(AssertionError):
            UpConvBlock(BasicConvBlock, 64, 32, 32, **extra_kwargs)

    target_shape = torch.Size([1, 32, 256, 256])

    # Checkpointed forward: both inputs require grad.
    block = UpConvBlock(BasicConvBlock, 64, 32, 32, with_cp=True)
    skip_x = torch.randn(1, 32, 256, 256, requires_grad=True)
    x = torch.randn(1, 64, 128, 128, requires_grad=True)
    assert block(skip_x, x).shape == target_shape

    # (upsample_cfg, side length of x). With upsample_cfg=None, x is fed at
    # the skip resolution (256) instead of half of it.
    upsample_cases = (
        (dict(type='InterpConv'), 128),
        (None, 256),
        (dict(
            type='InterpConv',
            upsampe_cfg=dict(
                scale_factor=2, mode='bilinear', align_corners=False)), 128),
        (dict(type='DeconvModule', kernel_size=4, scale_factor=2), 128),
    )
    for upsample_cfg, side in upsample_cases:
        block = UpConvBlock(
            BasicConvBlock, 64, 32, 32, upsample_cfg=upsample_cfg)
        skip_x = torch.randn(1, 32, 256, 256)
        x = torch.randn(1, 64, side, side)
        assert block(skip_x, x).shape == target_shape

    # Keyword construction with three dilated convs.
    block = UpConvBlock(
        conv_block=BasicConvBlock,
        in_channels=64,
        skip_channels=32,
        out_channels=32,
        num_convs=3,
        dilation=3,
        upsample_cfg=dict(
            type='InterpConv',
            upsampe_cfg=dict(
                scale_factor=2, mode='bilinear', align_corners=False)))
    skip_x = torch.randn(1, 32, 256, 256)
    x = torch.randn(1, 64, 128, 128)
    assert block(skip_x, x).shape == target_shape

    # The first conv takes 64 input channels (presumably 32 skip channels
    # concatenated with 32 upsampled channels) and is undilated. The other
    # convs are 32 -> 32 with dilation 3, and padding always equals dilation.
    expected = ((64, 32, 1), (32, 32, 3), (32, 32, 3))
    for idx, (in_ch, out_ch, dil) in enumerate(expected):
        conv = block.conv_block.convs[idx].conv
        assert conv.in_channels == in_ch
        assert conv.out_channels == out_ch
        assert conv.kernel_size == (3, 3)
        assert conv.dilation == (dil, dil)
        assert conv.padding == (dil, dil)

    # The upsampler's ConvModule (second layer) is a 1x1 conv, 64 -> 32.
    up_conv = block.upsample.interp_upsample[1].conv
    assert up_conv.in_channels == 64
    assert up_conv.out_channels == 32
    assert up_conv.kernel_size == (1, 1)
    assert up_conv.dilation == (1, 1)
    assert up_conv.padding == (0, 0)
|
|
|
|
def _unet_cfg(**overrides):
    """Return keyword arguments for the baseline 5-stage UNet.

    Baseline: 3 input channels, ``base_channels=64``, stride 1 in every
    stage, all four ``downsamples`` enabled, two convs per encoder/decoder
    stage and no dilation. ``overrides`` replace individual entries.
    """
    cfg = dict(
        in_channels=3,
        base_channels=64,
        num_stages=5,
        strides=(1, 1, 1, 1, 1),
        enc_num_convs=(2, 2, 2, 2, 2),
        dec_num_convs=(2, 2, 2, 2),
        downsamples=(True, True, True, True),
        enc_dilations=(1, 1, 1, 1, 1),
        dec_dilations=(1, 1, 1, 1))
    cfg.update(overrides)
    return cfg


def _check_unet_outputs(unet, spatial_sizes):
    """Run ``unet`` on a 2x3x128x128 input and check every output shape.

    ``spatial_sizes[i]`` is the expected height/width of output ``i``. The
    channel counts (1024, 512, 256, 128, 64) follow from the baseline
    ``base_channels=64`` and five stages.
    """
    x_outs = unet(torch.randn(2, 3, 128, 128))
    for idx, channels in enumerate((1024, 512, 256, 128, 64)):
        side = spatial_sizes[idx]
        assert x_outs[idx].shape == torch.Size([2, channels, side, side])


def test_unet():
    """Check UNet: rejected options, argument validation, norm_eval and the
    decoder output shapes for different stride/downsample settings."""
    # DCN and plugins are expected to be rejected with an AssertionError.
    with pytest.raises(AssertionError):
        UNet(3, 64, 5,
             dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False))
    with pytest.raises(AssertionError):
        UNet(3, 64, 5, plugins=[
            dict(
                cfg=dict(type='ContextBlock', ratio=1. / 16),
                position='after_conv3')
        ])
    with pytest.raises(AssertionError):
        UNet(3, 64, 5, plugins=[
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0010',
                    kv_stride=2),
                position='after_conv2')
        ])

    # Each (overrides, input side) pair must raise AssertionError during
    # construction or forward.
    invalid_cases = [
        # 65x65 inputs: presumably rejected because 65 is not divisible by
        # the network's total downsample rate.
        (dict(
            num_stages=4,
            strides=(1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2),
            dec_num_convs=(2, 2, 2),
            downsamples=(True, True, True),
            enc_dilations=(1, 1, 1, 1),
            dec_dilations=(1, 1, 1)), 65),
        (dict(), 65),
        (dict(downsamples=(True, True, True, False)), 65),
        (dict(strides=(1, 2, 2, 2, 1),
              downsamples=(True, True, True, False)), 65),
        (dict(
            num_stages=6,
            strides=(1, 1, 1, 1, 1, 1),
            enc_num_convs=(2, 2, 2, 2, 2, 2),
            dec_num_convs=(2, 2, 2, 2, 2),
            downsamples=(True, True, True, True, True),
            enc_dilations=(1, 1, 1, 1, 1, 1),
            dec_dilations=(1, 1, 1, 1, 1)), 65),
        # 64x64 inputs: one per-stage tuple has the wrong length for
        # num_stages=5 (encoder tuples need 5 entries, decoder ones 4).
        (dict(strides=(1, 1, 1, 1)), 64),
        (dict(enc_num_convs=(2, 2, 2, 2)), 64),
        (dict(dec_num_convs=(2, 2, 2, 2, 2)), 64),
        (dict(downsamples=(True, True, True)), 64),
        (dict(enc_dilations=(1, 1, 1, 1)), 64),
        (dict(dec_dilations=(1, 1, 1, 1, 1)), 64),
    ]
    for overrides, side in invalid_cases:
        with pytest.raises(AssertionError):
            unet = UNet(**_unet_cfg(**overrides))
            unet(torch.randn(2, 3, side, side))

    # norm_eval=True keeps the norm layers out of training mode after
    # .train(); norm_eval=False leaves them in training mode.
    unet = UNet(**_unet_cfg(norm_eval=True))
    unet.train()
    assert check_norm_state(unet.modules(), False)

    unet = UNet(**_unet_cfg(norm_eval=False))
    unet.train()
    assert check_norm_state(unet.modules(), True)

    # Distinct stride/downsample combinations and the expected side length
    # of each of the five outputs for a 128x128 input.
    shape_cases = [
        (dict(), (8, 16, 32, 64, 128)),
        (dict(downsamples=(True, True, True, False)),
         (16, 16, 32, 64, 128)),
        (dict(strides=(1, 2, 2, 2, 1),
              downsamples=(True, True, True, False)),
         (16, 16, 32, 64, 128)),
        (dict(downsamples=(True, True, False, False)),
         (32, 32, 32, 64, 128)),
        (dict(strides=(1, 2, 2, 1, 1),
              downsamples=(True, True, False, False)),
         (32, 32, 32, 64, 128)),
        (dict(downsamples=(True, False, False, False)),
         (64, 64, 64, 64, 128)),
        (dict(downsamples=(False, False, False, False)),
         (128, 128, 128, 128, 128)),
        (dict(strides=(1, 2, 2, 1, 1)), (8, 16, 32, 64, 128)),
        (dict(strides=(1, 2, 2, 1, 1),
              downsamples=(True, True, True, False)),
         (16, 16, 32, 64, 128)),
    ]
    for overrides, spatial_sizes in shape_cases:
        _check_unet_outputs(UNet(**_unet_cfg(**overrides)), spatial_sizes)

    # Explicit weight initialisation must leave the output shapes unchanged.
    unet = UNet(**_unet_cfg(
        strides=(1, 2, 2, 1, 1), downsamples=(True, True, False, False)))
    unet.init_weights(pretrained=None)
    _check_unet_outputs(unet, (32, 32, 32, 64, 128))
|
|