| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
|
|
| import numpy as np |
|
|
| import os |
| import sys |
|
|
# Make this script's directory and its parent directory importable. The
# parent is inserted at the front of sys.path; presumably it is the repository
# root containing the `ppocr` and `tools` packages imported below — confirm.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))


# Set before `import paddle` below. NOTE(review): presumably Paddle reads this
# flag when it is imported, so the ordering matters — confirm before moving.
os.environ["FLAGS_allocator_strategy"] = "auto_growth"
| import cv2 |
| import json |
| import paddle |
|
|
| from ppocr.data import create_operators, transform |
| from ppocr.modeling.architectures import build_model |
| from ppocr.postprocess import build_post_process |
| from ppocr.utils.save_load import load_model |
| from ppocr.utils.visual import draw_ser_results |
| from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps |
| import tools.program as program |
|
|
|
|
def to_tensor(data):
    """Turn the field list of one sample into a batch of size one.

    Every field is wrapped in a single-element list. When the field is a
    numpy array, a paddle tensor, or a plain number, that one-element list is
    passed through ``paddle.to_tensor``. Any other field stays a one-element
    Python list.

    Args:
        data: iterable of per-field values for a single sample.

    Returns:
        A list with one entry per input field, in the original field order.
    """
    import numbers

    batched = []
    for field in data:
        wrapped = [field]
        if isinstance(field, (np.ndarray, paddle.Tensor, numbers.Number)):
            wrapped = paddle.to_tensor(wrapped)
        batched.append(wrapped)
    return batched
|
|
|
|
class SerPredictor(object):
    """Semantic entity recognition (SER) predictor for single images.

    Bundles four things:
    - the model built from ``config["Architecture"]``;
    - its post-processor built from ``config["PostProcess"]``;
    - a ``PaddleOCR`` engine that is injected into the eval transforms;
    - the transform pipeline built from ``config["Eval"]["dataset"]["transforms"]``.
    """

    # Fields the KeepKeys transform is forced to keep, in order. ``__call__``
    # reads the transformed batch positionally, so the indices it uses are
    # derived from this tuple instead of being hard-coded. NOTE(review):
    # assumes KeepKeys emits fields in keep_keys order; this matches the
    # previously hard-coded indices 6 and 7.
    _KEEP_KEYS = (
        "input_ids",
        "bbox",
        "attention_mask",
        "token_type_ids",
        "image",
        "labels",
        "segment_offset_id",
        "ocr_info",
        "entities",
    )
    _SEGMENT_OFFSET_IDX = _KEEP_KEYS.index("segment_offset_id")
    _OCR_INFO_IDX = _KEEP_KEYS.index("ocr_info")

    def __init__(self, config):
        """Build the model, post-processor, OCR engine and transform pipeline.

        Args:
            config: full config mapping. It is mutated in place:
                - ops in ``config["Eval"]["dataset"]["transforms"]`` get
                  ``ocr_engine`` / ``keep_keys`` injected;
                - ``config["Global"]["infer_mode"]`` is set to True when it
                  is missing or None.
        """
        global_config = config["Global"]
        self.algorithm = config["Architecture"]["algorithm"]

        self.post_process_class = build_post_process(
            config["PostProcess"], global_config
        )

        # Build the network, then load its weights.
        self.model = build_model(config["Architecture"])
        load_model(config, self.model, model_type=config["Architecture"]["model_type"])

        from paddleocr import PaddleOCR

        self.ocr_engine = PaddleOCR(
            use_angle_cls=False,
            show_log=False,
            rec_model_dir=global_config.get("kie_rec_model_dir", None),
            det_model_dir=global_config.get("kie_det_model_dir", None),
            use_gpu=global_config["use_gpu"],
        )

        # Patch the eval transform configs in place before building them:
        # - ops whose name contains "Label" receive the OCR engine;
        # - KeepKeys is forced to emit the fields __call__ reads by position.
        transform_configs = config["Eval"]["dataset"]["transforms"]
        for op in transform_configs:
            op_name = list(op)[0]
            if "Label" in op_name:
                op[op_name]["ocr_engine"] = self.ocr_engine
            elif op_name == "KeepKeys":
                # A fresh list per op, so the class constant is never shared.
                op[op_name]["keep_keys"] = list(self._KEEP_KEYS)
        if global_config.get("infer_mode", None) is None:
            global_config["infer_mode"] = True
        self.ops = create_operators(transform_configs, global_config)
        self.model.eval()

    def __call__(self, data):
        """Run SER on the image file referenced by ``data["img_path"]``.

        Args:
            data: dict with at least ``"img_path"`` (it may also carry
                ``"label"``). It is mutated: the raw file bytes are stored
                under ``"image"`` before the transform pipeline runs.

        Returns:
            A ``(post_result, batch)`` tuple: the post-processor output and
            the batched inputs that were fed to the model.
        """
        with open(data["img_path"], "rb") as f:
            data["image"] = f.read()
        batch = to_tensor(transform(data, self.ops))
        preds = self.model(batch)

        post_result = self.post_process_class(
            preds,
            segment_offset_ids=batch[self._SEGMENT_OFFSET_IDX],
            ocr_infos=batch[self._OCR_INFO_IDX],
        )
        return post_result, batch
|
|
|
|
if __name__ == "__main__":
    config, device, logger, vdl_writer = program.preprocess()
    os.makedirs(config["Global"]["save_res_path"], exist_ok=True)

    ser_engine = SerPredictor(config)

    # Read infer_mode only after SerPredictor() runs: the constructor sets it
    # to True when it is unset. Only an explicit False switches to reading a
    # tab-separated label file with "<image path>\t<label>" per line; the
    # image path is joined onto data_dir.
    infer_from_label_file = config["Global"].get("infer_mode", None) is False

    if infer_from_label_file:
        data_dir = config["Eval"]["dataset"]["data_dir"]
        with open(config["Global"]["infer_img"], "rb") as f:
            # Drop blank lines (e.g. a trailing empty line). They have no
            # label column and would otherwise raise IndexError below.
            infer_imgs = [line for line in f if line.strip()]
    else:
        infer_imgs = get_image_file_list(config["Global"]["infer_img"])

    with open(
        os.path.join(config["Global"]["save_res_path"], "infer_results.txt"),
        "w",
        encoding="utf-8",
    ) as fout:
        for idx, info in enumerate(infer_imgs):
            if infer_from_label_file:
                # rstrip("\r\n") also removes the CR of CRLF line endings.
                substr = info.decode("utf-8").rstrip("\r\n").split("\t")
                img_path = os.path.join(data_dir, substr[0])
                data = {"img_path": img_path, "label": substr[1]}
            else:
                img_path = info
                data = {"img_path": img_path}

            save_img_path = os.path.join(
                config["Global"]["save_res_path"],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg",
            )

            result, _ = ser_engine(data)
            result = result[0]
            fout.write(
                img_path
                + "\t"
                + json.dumps(
                    {
                        "ocr_info": result,
                    },
                    ensure_ascii=False,
                )
                + "\n"
            )
            img_res = draw_ser_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)

            # Report a 1-based position so the final image logs as [N/N].
            logger.info(
                "process: [{}/{}], save result to {}".format(
                    idx + 1, len(infer_imgs), save_img_path
                )
            )
|
|