| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
|
|
| import numpy as np |
|
|
| import os |
| import sys |
|
|
# Make this script's directory and its parent directory importable; presumably
# this is what lets the ``ppocr`` / ``tools`` imports below resolve when the file
# is run directly as a script.
# A private name is used on purpose: a module-level ``__dir__`` is the PEP 562
# hook that ``dir(module)`` calls, so binding it to a string makes ``dir()`` raise.
_this_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_this_dir)
sys.path.insert(0, os.path.abspath(os.path.join(_this_dir, "..")))

# Set before ``import paddle`` below; presumably Paddle reads FLAGS_* at import
# time, so keep this ahead of that import.
os.environ["FLAGS_allocator_strategy"] = "auto_growth"
| import cv2 |
| import json |
| import paddle |
| import paddle.distributed as dist |
|
|
| from ppocr.data import create_operators, transform |
| from ppocr.modeling.architectures import build_model |
| from ppocr.postprocess import build_post_process |
| from ppocr.utils.save_load import load_model |
| from ppocr.utils.visual import draw_re_results |
| from ppocr.utils.logging import get_logger |
| from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict |
| from tools.program import ArgsParser, load_config, merge_config |
| from tools.infer_kie_token_ser import SerPredictor |
|
|
|
|
class ReArgsParser(ArgsParser):
    """Argument parser for joint SER + RE inference.

    On top of the arguments ``ArgsParser`` defines, registers a second config
    file (``-c_ser/--config_ser``, required by ``parse_args``) and a second
    override list (``-o_ser/--opt_ser``) for the SER model.
    """

    def __init__(self):
        super(ReArgsParser, self).__init__()
        # (flags, keyword options) for each SER-specific argument, registered in order.
        ser_arguments = (
            (("-c_ser", "--config_ser"), {"help": "ser configuration file to use"}),
            (
                ("-o_ser", "--opt_ser"),
                {"nargs": "+", "help": "set ser configuration options "},
            ),
        )
        for flags, options in ser_arguments:
            self.add_argument(*flags, **options)

    def parse_args(self, argv=None):
        """Parse ``argv`` and convert the ``--opt_ser`` entries with ``_parse_opt``.

        Raises:
            AssertionError: when ``--config_ser`` was not supplied.
        """
        namespace = super(ReArgsParser, self).parse_args(argv)
        ser_config_given = namespace.config_ser is not None
        assert ser_config_given, "Please specify --config_ser=ser_configure_file_path."
        namespace.opt_ser = self._parse_opt(namespace.opt_ser)
        return namespace
|
|
|
|
def make_input(ser_inputs, ser_results):
    """Convert SER model inputs and SER predictions into RE model inputs.

    Only sample 0 of ``ser_inputs[8]`` and of ``ser_results`` is read. The
    entity and relation tables built from it are tiled across the batch
    dimension.

    Args:
        ser_inputs: list of SER inputs. ``ser_inputs[0].shape[:2]`` supplies
            (batch_size, max_seq_len). ``ser_inputs[8][0]`` is a sequence of
            entity dicts carrying "start" and "end" keys.
        ser_results: SER predictions. ``ser_results[0]`` is a sequence of dicts
            with a "pred" key, either "O" or one of "HEADER", "QUESTION",
            "ANSWER". It must have the same length as the entity sequence.

    Returns:
        tuple: ``(re_inputs, entity_idx_dict_batch)``.

        ``re_inputs`` is ``ser_inputs[:5]`` followed by two int64 tables:

        * entities, shape (batch_size, max_seq_len + 1, 3). Row 0 holds the
          number of non-"O" entities in every column. Rows 1..count hold
          (start, end, label id). All other rows are -1.
        * relations, shape (batch_size, pairs + 1, 2). Row 0 holds the pair
          count. Rows 1..pairs hold (question, answer) indices into the kept
          entities, covering every QUESTION x ANSWER combination.

        Both tables go through ``paddle.to_tensor`` when ``ser_inputs[0]`` is
        a ``paddle.Tensor``.

        ``entity_idx_dict_batch`` holds batch_size references to one dict. That
        dict maps each kept-entity index to the entity's position in the
        original SER entity sequence.

    Raises:
        AssertionError: entity / prediction counts differ (unless run with -O).
        KeyError: a non-"O" prediction is not one of the three known labels.
    """
    label_to_id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
    batch_size, max_seq_len = ser_inputs[0].shape[:2]
    sample_entities = ser_inputs[8][0]
    sample_preds = ser_results[0]
    assert len(sample_entities) == len(sample_preds)

    # (original position, start, end, label id) for every entity not tagged "O".
    # Tuple items evaluate left to right, matching the order of the lookups.
    kept = [
        (pos, ent["start"], ent["end"], label_to_id[pred["pred"]])
        for pos, (pred, ent) in enumerate(zip(sample_preds, sample_entities))
        if pred["pred"] != "O"
    ]
    entity_idx_dict = {new_pos: item[0] for new_pos, item in enumerate(kept)}
    start = [item[1] for item in kept]
    end = [item[2] for item in kept]
    label = [item[3] for item in kept]

    # Row 0 stores the count; rows 1..count store the values; the rest stay -1.
    entity_table = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
    for col, values in enumerate((start, end, label)):
        entity_table[0, col] = len(values)
        entity_table[1 : len(values) + 1, col] = values

    # Pair every QUESTION (id 1) with every ANSWER (id 2), ordered by the
    # question's index first and the answer's index second.
    questions = [k for k, lab in enumerate(label) if lab == 1]
    answers = [k for k, lab in enumerate(label) if lab == 2]
    head = [q for q in questions for _ in answers]
    tail = [a for _ in questions for a in answers]

    relation_table = np.full([len(head) + 1, 2], fill_value=-1, dtype=np.int64)
    for col, values in enumerate((head, tail)):
        relation_table[0, col] = len(values)
        relation_table[1 : len(values) + 1, col] = values

    # Every batch item receives the same tables.
    entity_table = np.repeat(entity_table[np.newaxis, ...], batch_size, axis=0)
    relation_table = np.repeat(relation_table[np.newaxis, ...], batch_size, axis=0)

    if isinstance(ser_inputs[0], paddle.Tensor):
        entity_table = paddle.to_tensor(entity_table)
        relation_table = paddle.to_tensor(relation_table)
    re_inputs = ser_inputs[:5] + [entity_table, relation_table]

    return re_inputs, [entity_idx_dict for _ in range(batch_size)]
|
|
|
|
class SerRePredictor(object):
    """Predictor that chains a SER model and an RE model.

    The SER predictor runs first. ``make_input`` converts its inputs and
    predictions into RE model inputs. The RE post-processor then produces the
    final result.
    """

    def __init__(self, config, ser_config):
        """Build the SER engine, the RE post-processor and the RE model.

        Args:
            config: RE configuration dict. Reads its "Global", "PostProcess"
                and "Architecture" sections.
            ser_config: SER configuration dict. NOTE: mutated in place; its
                ``Global.infer_mode`` is overwritten when ``config["Global"]``
                defines one.
        """
        global_config = config["Global"]
        # Keep the SER stage in the same infer_mode as the RE stage.
        if "infer_mode" in global_config:
            ser_config["Global"]["infer_mode"] = global_config["infer_mode"]

        self.ser_engine = SerPredictor(ser_config)

        self.post_process_class = build_post_process(
            config["PostProcess"], global_config
        )

        self.model = build_model(config["Architecture"])
        load_model(config, self.model, model_type=config["Architecture"]["model_type"])
        self.model.eval()

    def __call__(self, data):
        """Run SER then RE on one input and return the RE post-processed result.

        Args:
            data: input dict passed unchanged to the SER predictor.

        Returns:
            The output of the RE post-processor.
        """
        # Pure inference: disable autograd so neither forward pass records a
        # backward graph (and keeps its activations alive) for nothing.
        with paddle.no_grad():
            ser_results, ser_inputs = self.ser_engine(data)
            re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
            if self.model.backbone.use_visual_backbone is False:
                # Input slot 4 is dropped when the backbone has no visual
                # branch; presumably that slot is the image -- confirm
                # against the SER transforms.
                re_input.pop(4)
            preds = self.model(re_input)
            post_result = self.post_process_class(
                preds,
                ser_results=ser_results,
                entity_idx_dict_batch=entity_idx_dict_batch,
            )
        return post_result
|
|
|
|
def preprocess():
    """Parse CLI arguments, load both configs, select the device and log the setup.

    Returns:
        tuple: ``(config, ser_config, device, logger)``, which are:

        * the RE config after ``--opt`` overrides;
        * the SER config after ``--opt_ser`` overrides;
        * the value returned by ``paddle.set_device``;
        * the logger from ``get_logger()``.
    """
    flags = ReArgsParser().parse_args()
    config = merge_config(load_config(flags.config), flags.opt)
    ser_config = merge_config(load_config(flags.config_ser), flags.opt_ser)

    logger = get_logger()

    # Use this process's GPU (ParallelEnv().dev_id) when Global.use_gpu is
    # truthy; otherwise run on CPU.
    if config["Global"]["use_gpu"]:
        device_name = "gpu:{}".format(dist.ParallelEnv().dev_id)
    else:
        device_name = "cpu"
    device = paddle.set_device(device_name)

    banner = "*" * 10
    logger.info("{} re config {}".format(banner, banner))
    print_dict(config, logger)
    logger.info("\n")
    logger.info("{} ser config {}".format(banner, banner))
    print_dict(ser_config, logger)
    # This entry point only runs inference, so say "infer" rather than "train".
    logger.info("infer with paddle {} and device {}".format(paddle.__version__, device))
    return config, ser_config, device, logger
|
|
|
|
if __name__ == "__main__":
    config, ser_config, device, logger = preprocess()
    global_config = config["Global"]
    save_res_path = global_config["save_res_path"]
    os.makedirs(save_res_path, exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

    # infer_mode explicitly False: "infer_img" is a file of
    # "<path relative to Eval.dataset.data_dir>\t<label>" lines.
    # Any other value (including unset): "infer_img" goes through
    # get_image_file_list.
    use_label_file = global_config.get("infer_mode", None) is False
    if use_label_file:
        data_dir = config["Eval"]["dataset"]["data_dir"]
        with open(global_config["infer_img"], "rb") as f:
            infer_imgs = f.readlines()
    else:
        infer_imgs = get_image_file_list(global_config["infer_img"])

    with open(
        os.path.join(save_res_path, "infer_results.txt"),
        "w",
        encoding="utf-8",
    ) as fout:
        for idx, info in enumerate(infer_imgs):
            if use_label_file:
                data_line = info.decode("utf-8")
                if not data_line.strip():
                    # A blank line has no "\t<label>" part; indexing substr[1]
                    # below would raise IndexError and abort the whole run.
                    continue
                substr = data_line.strip("\n").split("\t")
                img_path = os.path.join(data_dir, substr[0])
                data = {"img_path": img_path, "label": substr[1]}
            else:
                img_path = info
                data = {"img_path": img_path}

            save_img_path = os.path.join(
                save_res_path,
                os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg",
            )

            result = ser_re_engine(data)[0]
            fout.write(img_path + "\t" + json.dumps(result, ensure_ascii=False) + "\n")
            img_res = draw_re_results(img_path, result)
            # cv2.imwrite reports failure through its boolean return value,
            # not by raising, so check it explicitly.
            if not cv2.imwrite(save_img_path, img_res):
                logger.warning(
                    "failed to write visualization to {}".format(save_img_path)
                )

            logger.info(
                "process: [{}/{}], save result to {}".format(
                    idx, len(infer_imgs), save_img_path
                )
            )
|
|