Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import argparse | |
| import ast | |
| import os | |
| import os.path as osp | |
| import mmcv | |
| import numpy as np | |
| import torch | |
| from mmcv import Config | |
| from mmcv.image import tensor2imgs | |
| from mmcv.parallel import MMDataParallel | |
| from mmcv.runner import load_checkpoint | |
| from mmocr.datasets import build_dataloader, build_dataset | |
| from mmocr.models import build_detector | |
| def save_results(model, img_meta, gt_bboxes, result, out_dir): | |
| assert 'filename' in img_meta, ('Please add "filename" ' | |
| 'to "meta_keys" in config.') | |
| assert 'ori_texts' in img_meta, ('Please add "ori_texts" ' | |
| 'to "meta_keys" in config.') | |
| out_json_file = osp.join(out_dir, | |
| osp.basename(img_meta['filename']) + '.json') | |
| idx_to_cls = {} | |
| if model.module.class_list is not None: | |
| for line in mmcv.list_from_file(model.module.class_list): | |
| class_idx, class_label = line.strip().split() | |
| idx_to_cls[int(class_idx)] = class_label | |
| json_result = [{ | |
| 'text': | |
| text, | |
| 'box': | |
| box, | |
| 'pred': | |
| idx_to_cls.get( | |
| pred.argmax(-1).cpu().item(), | |
| pred.argmax(-1).cpu().item()), | |
| 'conf': | |
| pred.max(-1)[0].cpu().item() | |
| } for text, box, pred in zip(img_meta['ori_texts'], gt_bboxes, | |
| result['nodes'])] | |
| mmcv.dump(json_result, out_json_file) | |
| def test(model, data_loader, show=False, out_dir=None): | |
| model.eval() | |
| results = [] | |
| dataset = data_loader.dataset | |
| prog_bar = mmcv.ProgressBar(len(dataset)) | |
| for i, data in enumerate(data_loader): | |
| with torch.no_grad(): | |
| result = model(return_loss=False, rescale=True, **data) | |
| batch_size = len(result) | |
| if show or out_dir: | |
| img_tensor = data['img'].data[0] | |
| img_metas = data['img_metas'].data[0] | |
| if np.prod(img_tensor.shape) == 0: | |
| imgs = [mmcv.imread(m['filename']) for m in img_metas] | |
| else: | |
| imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg']) | |
| assert len(imgs) == len(img_metas) | |
| gt_bboxes = [data['gt_bboxes'].data[0][0].numpy().tolist()] | |
| for i, (img, img_meta) in enumerate(zip(imgs, img_metas)): | |
| if 'img_shape' in img_meta: | |
| h, w, _ = img_meta['img_shape'] | |
| img_show = img[:h, :w, :] | |
| else: | |
| img_show = img | |
| if out_dir: | |
| out_file = osp.join(out_dir, | |
| osp.basename(img_meta['filename'])) | |
| else: | |
| out_file = None | |
| model.module.show_result( | |
| img_show, | |
| result[i], | |
| gt_bboxes[i], | |
| show=show, | |
| out_file=out_file) | |
| if out_dir: | |
| save_results(model, img_meta, gt_bboxes[i], result[i], | |
| out_dir) | |
| for _ in range(batch_size): | |
| prog_bar.update() | |
| return results | |
| def parse_args(): | |
| parser = argparse.ArgumentParser( | |
| description='MMOCR visualize for kie model.') | |
| parser.add_argument('config', help='Test config file path.') | |
| parser.add_argument('checkpoint', help='Checkpoint file.') | |
| parser.add_argument('--show', action='store_true', help='Show results.') | |
| parser.add_argument( | |
| '--out-dir', | |
| help='Directory where the output images and results will be saved.') | |
| parser.add_argument('--local_rank', type=int, default=0) | |
| parser.add_argument( | |
| '--device', | |
| help='Use int or int list for gpu. Default is cpu', | |
| default=None) | |
| args = parser.parse_args() | |
| if 'LOCAL_RANK' not in os.environ: | |
| os.environ['LOCAL_RANK'] = str(args.local_rank) | |
| return args | |
| def main(): | |
| args = parse_args() | |
| assert args.show or args.out_dir, ('Please specify at least one ' | |
| 'operation (show the results / save )' | |
| 'the results with the argument ' | |
| '"--show" or "--out-dir".') | |
| device = args.device | |
| if device is not None: | |
| device = ast.literal_eval(f'[{device}]') | |
| cfg = Config.fromfile(args.config) | |
| # import modules from string list. | |
| if cfg.get('custom_imports', None): | |
| from mmcv.utils import import_modules_from_strings | |
| import_modules_from_strings(**cfg['custom_imports']) | |
| # set cudnn_benchmark | |
| if cfg.get('cudnn_benchmark', False): | |
| torch.backends.cudnn.benchmark = True | |
| distributed = False | |
| # build the dataloader | |
| dataset = build_dataset(cfg.data.test) | |
| data_loader = build_dataloader( | |
| dataset, | |
| samples_per_gpu=1, | |
| workers_per_gpu=cfg.data.workers_per_gpu, | |
| dist=distributed, | |
| shuffle=False) | |
| # build the model and load checkpoint | |
| cfg.model.train_cfg = None | |
| model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) | |
| load_checkpoint(model, args.checkpoint, map_location='cpu') | |
| model = MMDataParallel(model, device_ids=device) | |
| test(model, data_loader, args.show, args.out_dir) | |
| if __name__ == '__main__': | |
| main() | |