| import glob |
| import os |
| from os.path import dirname, exists, isdir, join, relpath |
|
|
| from mmcv import Config |
| from torch import nn |
|
|
| from mmseg.models import build_segmentor |
|
|
|
|
| def _get_config_directory(): |
| """Find the predefined segmentor config directory.""" |
| try: |
| |
| repo_dpath = dirname(dirname(__file__)) |
| except NameError: |
| |
| import mmseg |
| repo_dpath = dirname(dirname(mmseg.__file__)) |
| config_dpath = join(repo_dpath, 'configs') |
| if not exists(config_dpath): |
| raise Exception('Cannot find config path') |
| return config_dpath |
|
|
|
|
def test_config_build_segmentor():
    """Test that all segmentation models defined in the configs can be
    initialized.

    One config file is sampled from every sub-directory of the config
    directory (files whose path contains ``_base_`` are skipped). Each
    sampled config is loaded, its ``pretrained`` entry is cleared, the
    segmentor is built, and the built decode head is compared against the
    config with ``_check_decode_head``.
    """
    config_dpath = _get_config_directory()
    print('Found config_dpath = {!r}'.format(config_dpath))

    config_fpaths = []
    # Take one config from each sub folder. Sort everything so the sample
    # is deterministic across file systems.
    for sub_folder in sorted(os.listdir(config_dpath)):
        # The bare name must be joined with config_dpath. Otherwise isdir()
        # resolves it against the current working directory.
        sub_dpath = join(config_dpath, sub_folder)
        if not isdir(sub_dpath):
            continue
        candidates = sorted(glob.glob(join(sub_dpath, '*.py')))
        # Skip folders with no top-level .py file instead of hitting an
        # IndexError.
        if candidates:
            config_fpaths.append(candidates[0])
    config_fpaths = [p for p in config_fpaths if '_base_' not in p]
    config_names = [relpath(p, config_dpath) for p in config_fpaths]

    print('Using {} config files'.format(len(config_names)))

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        config_mod = Config.fromfile(config_fpath)

        print('Building segmentor, config_fpath = {!r}'.format(config_fpath))

        # Clear any pretrained checkpoint reference. This is presumably so
        # the build does not try to load or download weights.
        if 'pretrained' in config_mod.model:
            config_mod.model['pretrained'] = None

        print('building {}'.format(config_fname))
        segmentor = build_segmentor(config_mod.model)
        assert segmentor is not None

        head_config = config_mod.model['decode_head']
        _check_decode_head(head_config, segmentor.decode_head)
|
|
|
|
def test_config_data_pipeline():
    """Test whether the data pipeline is valid and can process corner cases.

    For every config file one directory below the config root (files whose
    path contains ``_base_`` are skipped), the leading loading steps are
    removed from the train and test pipelines. The remaining transforms are
    then run on a random 1024x2048x3 uint8 image, plus a random
    single-channel map as ``gt_semantic_seg`` for the training pipeline.

    CommandLine:
        xdoctest -m tests/test_config.py test_config_data_pipeline
    """
    import numpy as np
    from mmseg.datasets.pipelines import Compose

    config_dpath = _get_config_directory()
    print('Found config_dpath = {!r}'.format(config_dpath))

    # Configs live exactly one directory below the config root. Sort them
    # so the test order is deterministic.
    config_fpaths = sorted(glob.glob(join(config_dpath, '*', '*.py')))
    config_fpaths = [p for p in config_fpaths if '_base_' not in p]
    config_names = [relpath(p, config_dpath) for p in config_fpaths]

    print('Using {} config files'.format(len(config_names)))

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        print(
            'Building data pipeline, config_fpath = {!r}'.format(config_fpath))
        config_mod = Config.fromfile(config_fpath)

        # Remove the loading steps, because the synthetic results dict below
        # supplies their output directly. The first train step's
        # `to_float32` flag is kept so the fake image gets a matching dtype.
        # The second train step and the first test step are assumed to be
        # annotation/image loaders (confirm against the base configs).
        load_img_pipeline = config_mod.train_pipeline.pop(0)
        to_float32 = load_img_pipeline.get('to_float32', False)
        config_mod.train_pipeline.pop(0)
        config_mod.test_pipeline.pop(0)

        train_pipeline = Compose(config_mod.train_pipeline)
        test_pipeline = Compose(config_mod.test_pipeline)

        img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
        if to_float32:
            img = img.astype(np.float32)
        seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)

        results = dict(
            filename='test_img.png',
            ori_filename='test_img.png',
            img=img,
            img_shape=img.shape,
            ori_shape=img.shape,
            gt_semantic_seg=seg)
        results['seg_fields'] = ['gt_semantic_seg']

        print('Test training data pipeline: \n{!r}'.format(train_pipeline))
        output_results = train_pipeline(results)
        assert output_results is not None

        results = dict(
            filename='test_img.png',
            ori_filename='test_img.png',
            img=img,
            img_shape=img.shape,
            ori_shape=img.shape,
        )
        print('Test testing data pipeline: \n{!r}'.format(test_pipeline))
        output_results = test_pipeline(results)
        assert output_results is not None
|
|
|
|
| def _check_decode_head(decode_head_cfg, decode_head): |
| if isinstance(decode_head_cfg, list): |
| assert isinstance(decode_head, nn.ModuleList) |
| assert len(decode_head_cfg) == len(decode_head) |
| num_heads = len(decode_head) |
| for i in range(num_heads): |
| _check_decode_head(decode_head_cfg[i], decode_head[i]) |
| return |
| |
| assert decode_head_cfg['type'] == decode_head.__class__.__name__ |
|
|
| assert decode_head_cfg['type'] == decode_head.__class__.__name__ |
|
|
| in_channels = decode_head_cfg.in_channels |
| input_transform = decode_head.input_transform |
| assert input_transform in ['resize_concat', 'multiple_select', None] |
| if input_transform is not None: |
| assert isinstance(in_channels, (list, tuple)) |
| assert isinstance(decode_head.in_index, (list, tuple)) |
| assert len(in_channels) == len(decode_head.in_index) |
| elif input_transform == 'resize_concat': |
| assert sum(in_channels) == decode_head.in_channels |
| else: |
| assert isinstance(in_channels, int) |
| assert in_channels == decode_head.in_channels |
| assert isinstance(decode_head.in_index, int) |
|
|
| if decode_head_cfg['type'] == 'PointHead': |
| assert decode_head_cfg.channels+decode_head_cfg.num_classes == \ |
| decode_head.fc_seg.in_channels |
| assert decode_head.fc_seg.out_channels == decode_head_cfg.num_classes |
| else: |
| assert decode_head_cfg.channels == decode_head.conv_seg.in_channels |
| assert decode_head.conv_seg.out_channels == decode_head_cfg.num_classes |
|
|