Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os.path as osp | |
| import shutil | |
| import time | |
| from argparse import ArgumentParser | |
| from itertools import compress | |
| import mmcv | |
| from mmcv.utils import ProgressBar | |
| from mmocr.apis import init_detector, model_inference | |
| from mmocr.core.evaluation.ocr_metric import eval_ocr_metric | |
| from mmocr.datasets import build_dataset # noqa: F401 | |
| from mmocr.models import build_detector # noqa: F401 | |
| from mmocr.utils import get_root_logger, list_from_file, list_to_file | |
def save_results(img_paths, pred_labels, gt_labels, res_dir):
    """Write per-image predictions to text files under ``res_dir``.

    Three files are produced: ``results.txt`` holds one
    ``<img> <pred> <gt>`` line per image, ``correct.txt`` holds only the
    lines whose prediction equals the ground truth, and ``wrong.txt`` holds
    the remaining lines.

    Args:
        img_paths (list[str]): Image paths, one per sample.
        pred_labels (list[str]): Predicted text for each image.
        gt_labels (list[str]): Ground-truth text for each image.
        res_dir (str): Directory in which the three files are written.
    """
    assert len(img_paths) == len(pred_labels) and \
        len(pred_labels) == len(gt_labels)

    # Build all three line lists in a single pass over the samples.
    all_lines, hit_lines, miss_lines = [], [], []
    for img, pred, gt in zip(img_paths, pred_labels, gt_labels):
        entry = f'{img} {pred} {gt}'
        all_lines.append(entry)
        (hit_lines if pred == gt else miss_lines).append(entry)

    for fname, content in (('results.txt', all_lines),
                           ('correct.txt', hit_lines),
                           ('wrong.txt', miss_lines)):
        list_to_file(osp.join(res_dir, fname), content)
def main():
    """Run text recognition over a list of images, then save and evaluate.

    Command-line arguments:
        img_root_path: Root directory joined with every image file name.
        img_list: Text file with one ``<img_file> [<gt_label>]`` entry per
            line (whitespace separated; only the first two fields are used).
        config: Model config file.
        checkpoint: Model checkpoint file.
        --out_dir: Directory for the log, visualizations and result files.
        --show: Show visualizations instead of saving them.
        --device: Device used for inference.

    For each listed image the prediction is visualized via
    ``model.show_result``; when a ground-truth label is present and the
    visualization is saved, it is also copied into the ``correct`` or
    ``wrong`` sub-directory. Evaluation metrics are logged only when every
    processed image has a ground-truth label.

    Raises:
        FileNotFoundError: If an image named in ``img_list`` does not exist.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root_path', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--out_dir', type=str, default='./results', help='Dir to save results')
    parser.add_argument(
        '--show', action='store_true', help='show image or save')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()

    # The log file path points inside out_dir, so make sure the directory
    # exists before the logger is handed that path.
    mmcv.mkdir_or_exist(args.out_dir)

    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(args.out_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    if hasattr(model, 'module'):
        # Unwrap a wrapped model so show_result is called on the inner one.
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    correct_vis_dir = osp.join(args.out_dir, 'correct')
    mmcv.mkdir_or_exist(correct_vis_dir)
    wrong_vis_dir = osp.join(args.out_dir, 'wrong')
    mmcv.mkdir_or_exist(wrong_vis_dir)

    img_paths, pred_labels, gt_labels = [], [], []
    lines = list_from_file(args.img_list)
    progressbar = ProgressBar(task_num=len(lines))
    num_gt_label = 0
    for line in lines:
        progressbar.update()
        item_list = line.strip().split()
        if not item_list:
            # Blank entry in the list file: nothing to run inference on.
            continue
        img_file = item_list[0]
        gt_label = ''
        if len(item_list) >= 2:
            gt_label = item_list[1]
            num_gt_label += 1
        img_path = osp.join(args.img_root_path, img_file)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)

        # Test a single image
        result = model_inference(model, img_path)
        pred_label = result['text']

        # Flatten nested image paths into a single file name.
        out_img_name = '_'.join(img_file.split('/'))
        out_file = osp.join(out_vis_dir, out_img_name)
        kwargs_dict = {
            'gt_label': gt_label,
            'show': args.show,
            'out_file': '' if args.show else out_file
        }
        model.show_result(img_path, result, **kwargs_dict)

        # With --show no output path is handed to show_result, so there is
        # no saved visualization to sort into the correct/wrong folders.
        if gt_label != '' and not args.show:
            if gt_label == pred_label:
                dst_file = osp.join(correct_vis_dir, out_img_name)
            else:
                dst_file = osp.join(wrong_vis_dir, out_img_name)
            shutil.copy(out_file, dst_file)

        img_paths.append(img_path)
        gt_labels.append(gt_label)
        pred_labels.append(pred_label)

    # Save results
    save_results(img_paths, pred_labels, gt_labels, args.out_dir)

    # Metrics are only meaningful when every prediction has a ground truth.
    if num_gt_label == len(pred_labels):
        # eval
        eval_results = eval_ocr_metric(pred_labels, gt_labels)
        logger.info('\n' + '-' * 100)
        info = ('eval on testset with img_root_path '
                f'{args.img_root_path} and img_list {args.img_list}\n')
        logger.info(info)
        logger.info(eval_results)

    print(f'\nInference done, and results saved in {args.out_dir}\n')
# Script entry point: run the CLI only when this file is executed directly.
if __name__ == '__main__':
    main()