|
|
| import torch |
|
|
| try: |
| import mmcv as mmcv |
| from mmcv.parallel import collate, scatter |
| from mmcv.runner import load_checkpoint |
| from mmseg.datasets.pipelines import Compose |
| from mmseg.models import build_segmentor |
| except ImportError: |
| import annotator.mmpkg.mmcv as mmcv |
| from annotator.mmpkg.mmcv.parallel import collate, scatter |
| from annotator.mmpkg.mmcv.runner import load_checkpoint |
| from annotator.mmpkg.mmseg.datasets.pipelines import Compose |
| from annotator.mmpkg.mmseg.models import build_segmentor |
| |
def init_segmentor(config, checkpoint=None, device='cuda:0'):
    """Build a segmentor from a config and optionally load weights into it.

    Args:
        config (str or :obj:`mmcv.Config`): Path to a config file, or an
            already-parsed config object. A ``Config`` object passed in is
            modified in place (``model.pretrained`` and ``model.train_cfg``
            are set to None).
        checkpoint (str, optional): Path of a checkpoint to load. When None,
            no weights are loaded. Default: None.
        device (str, optional): Device the model is moved to, e.g. 'cuda:0'
            or 'cpu'. Default: 'cuda:0'.

    Returns:
        nn.Module: The segmentor, moved to ``device`` and switched to eval
        mode, with the config attached as ``model.cfg``. When a checkpoint
        is loaded, ``model.CLASSES`` and ``model.PALETTE`` are taken from the
        checkpoint's ``meta`` entry.

    Raises:
        TypeError: If ``config`` is neither a str nor an ``mmcv.Config``.
    """
    if isinstance(config, str):
        cfg = mmcv.Config.fromfile(config)
    elif isinstance(config, mmcv.Config):
        cfg = config
    else:
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))

    # Clear `pretrained` and `train_cfg` before building; this helper only
    # builds models for inference (the model is put in eval mode below).
    model_cfg = cfg.model
    model_cfg.pretrained = None
    model_cfg.train_cfg = None
    model = build_segmentor(model_cfg, test_cfg=cfg.get('test_cfg'))

    if checkpoint is not None:
        # Weights are loaded onto the CPU; the model is moved to `device` below.
        ckpt = load_checkpoint(model, checkpoint, map_location='cpu')
        meta = ckpt['meta']
        model.CLASSES = meta['CLASSES']
        model.PALETTE = meta['PALETTE']

    model.cfg = cfg
    model.to(device)
    model.eval()
    return model
|
|
|
|
class LoadImage:
    """Minimal pipeline step that reads an image into the results dict."""

    def __call__(self, results):
        """Load ``results['img']`` and record filename and shape metadata.

        Args:
            results (dict): Must contain the key ``'img'``, holding either a
                file path (str) or an already-loaded image that is passed
                through ``mmcv.imread``.

        Returns:
            dict: The same ``results`` dict, updated in place with
            ``'filename'`` and ``'ori_filename'`` (the path, or None when
            ``'img'`` was not a str), ``'img'`` (the loaded image), and
            ``'img_shape'`` / ``'ori_shape'`` (both the loaded image's shape).
        """
        source = results['img']
        name = source if isinstance(source, str) else None
        results['filename'] = name
        results['ori_filename'] = name

        image = mmcv.imread(source)
        shape = image.shape
        results['img'] = image
        results['img_shape'] = shape
        results['ori_shape'] = shape
        return results
|
|
|
|
def inference_segmentor(model, img):
    """Run the segmentor on a single image.

    Args:
        model (nn.Module): A segmentor as returned by ``init_segmentor``; it
            must carry its config as ``model.cfg``.
        img (str or ndarray): One image, given as a file path or as an
            already-loaded image. Lists of images are not handled here.

    Returns:
        The segmentation result returned by the model's test-mode forward,
        called as ``model(return_loss=False, rescale=True, **data)``
        (presumably a list with one entry per image — TODO confirm).
    """
    cfg = model.cfg
    device = next(model.parameters()).device  # model device

    # Replace the first step of the configured test pipeline (presumably its
    # own image loader) with LoadImage, which accepts paths or arrays.
    test_pipeline = Compose([LoadImage()] + cfg.data.test.pipeline[1:])

    data = test_pipeline(dict(img=img))
    data = collate([data], samples_per_gpu=1)
    if device.type == 'cuda':
        # Scatter the batch to the model's GPU.
        data = scatter(data, [device])[0]
    else:
        # scatter is not used on CPU, so take `.data[0]` of each img_metas
        # entry by hand.
        data['img_metas'] = [i.data[0] for i in data['img_metas']]

    # Make sure image tensors live on the model's device (presumably already
    # there after scatter on the CUDA path).
    data['img'] = [x.to(device) for x in data['img']]

    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result
|
|
|
|
def show_result_pyplot(model,
                       img,
                       result,
                       palette=None,
                       fig_size=(15, 10),
                       opacity=0.5,
                       title='',
                       block=True):
    """Render the segmentation result over the image and return it as RGB.

    Despite its name, this function does not open a pyplot figure. It draws
    the result with ``model.show_result(..., show=False)`` and returns the
    rendered image.

    Args:
        model (nn.Module): The loaded segmentor. If it has a ``module``
            attribute, that wrapped module is used instead.
        img (str or np.ndarray): Image filename or loaded image.
        result (list): The segmentation result, e.g. from
            ``inference_segmentor``.
        palette (list[list[int]] | None): The palette of the segmentation
            map, forwarded to ``model.show_result``. When None, the choice
            is left to ``model.show_result``. Default: None.
        fig_size (tuple): Unused; kept for API compatibility.
        opacity (float): Opacity of the painted segmentation map, forwarded
            to ``model.show_result``. Expected in the (0, 1] range.
            Default: 0.5.
        title (str): Unused; kept for API compatibility.
        block (bool): Unused; kept for API compatibility.

    Returns:
        The image returned by ``model.show_result``, with its channel order
        converted from BGR to RGB by ``mmcv.bgr2rgb``.
    """
    # Unwrap wrappers (e.g. DataParallel) that expose the real model as
    # `.module`, since `show_result` is defined on the inner model.
    if hasattr(model, 'module'):
        model = model.module
    img = model.show_result(
        img, result, palette=palette, show=False, opacity=opacity)
    return mmcv.bgr2rgb(img)
|
|