| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import numpy as np |
| import os |
| import random |
| from paddle.io import Dataset |
| import json |
| from copy import deepcopy |
|
|
| from .imaug import transform, create_operators |
|
|
|
|
class PubTabDataSet(Dataset):
    """Dataset over JSON-lines label files that describe table images.

    Every label line is a JSON object from which ``filename``,
    ``html["cells"]`` and ``html["structure"]["tokens"]`` are read. Lines are
    stored as raw bytes and only decoded/parsed when a sample is requested.
    """

    def __init__(self, config, mode, logger, seed=None):
        """Load the label lines and build the transform operators.

        Args:
            config: Nested mapping. Reads ``config["Global"]``,
                ``config[mode]["dataset"]`` (keys ``label_file_list``,
                ``data_dir``, ``transforms`` and the optional ``ratio_list``)
                and ``config[mode]["loader"]["shuffle"]``. Note that
                ``label_file_list`` is popped from the dataset section.
            mode: Name of the config section to use; its lower-cased form is
                compared against ``"train"`` to enable train-only behavior.
            logger: Object providing ``info``, ``warning`` and ``error``.
            seed: Value handed to ``random.seed`` before label lines are
                sampled and before they are shuffled.
        """
        super().__init__()
        self.logger = logger

        global_config = config["Global"]
        dataset_config = config[mode]["dataset"]
        loader_config = config[mode]["loader"]

        label_file_list = dataset_config.pop("label_file_list")
        # A single path given as a plain string is one data source; without
        # this, len() would count its characters instead of files.
        if isinstance(label_file_list, str):
            label_files = [label_file_list]
        else:
            label_files = label_file_list
        data_source_num = len(label_files)

        ratio_list = dataset_config.get("ratio_list", [1.0])
        if isinstance(ratio_list, (float, int)):
            # A scalar ratio applies to every label file.
            ratio_list = [float(ratio_list)] * data_source_num

        assert (
            len(ratio_list) == data_source_num
        ), "The length of ratio_list should be the same as the file_list."

        self.data_dir = dataset_config["data_dir"]
        self.do_shuffle = loader_config["shuffle"]

        self.seed = seed
        self.mode = mode.lower()
        logger.info("Initialize indexes of datasets:%s" % label_file_list)
        self.data_lines = self.get_image_info_list(label_files, ratio_list)

        if self.mode == "train" and self.do_shuffle:
            self.shuffle_data_random()
        self.ops = create_operators(dataset_config["transforms"], global_config)
        # True when at least one label file is only partially sampled.
        self.need_reset = any(ratio < 1 for ratio in ratio_list)

    def get_image_info_list(self, file_list, ratio_list):
        """Read all label files and return their (possibly sampled) lines.

        Args:
            file_list: One path or a list of paths to label files.
            ratio_list: Per-file fraction of lines to keep, indexed like
                ``file_list``.

        Returns:
            A list of raw ``bytes`` lines. In train mode, or when a file's
            ratio is below 1.0, that file's lines are drawn with
            ``random.sample`` after re-seeding the global RNG with
            ``self.seed``; otherwise they are kept in file order.
        """
        if isinstance(file_list, str):
            file_list = [file_list]

        data_lines = []
        for idx, label_path in enumerate(file_list):
            with open(label_path, "rb") as label_file:
                lines = label_file.readlines()
            ratio = ratio_list[idx]
            if self.mode == "train" or ratio < 1.0:
                random.seed(self.seed)
                lines = random.sample(lines, round(len(lines) * ratio))
            data_lines.extend(lines)
        return data_lines

    def check(self, max_text_length):
        """Drop label lines that cannot yield a usable sample.

        A line is removed when its image file does not exist under
        ``self.data_dir`` (a warning is logged) or when its structure token
        list is empty or longer than ``max_text_length``. ``self.data_lines``
        is replaced by the filtered list.

        Args:
            max_text_length: Maximum allowed number of structure tokens.
        """
        kept_lines = []
        for line in self.data_lines:
            info = json.loads(line.decode("utf-8").strip("\n"))
            file_name = info["filename"]
            html = info["html"]
            # Lookup retained so a record without "cells" still raises
            # KeyError here; the value itself is not needed for filtering.
            html["cells"]
            structure = html["structure"]["tokens"]

            img_path = os.path.join(self.data_dir, file_name)
            if not os.path.exists(img_path):
                self.logger.warning("{} does not exist!".format(img_path))
                continue
            if len(structure) == 0 or len(structure) > max_text_length:
                continue

            kept_lines.append(line)
        self.data_lines = kept_lines

    def shuffle_data_random(self):
        """Shuffle ``self.data_lines`` in place when shuffling is enabled.

        The global RNG is re-seeded with ``self.seed`` first.
        """
        if self.do_shuffle:
            random.seed(self.seed)
            random.shuffle(self.data_lines)

    def __getitem__(self, idx):
        """Return the transformed sample for label line ``idx``.

        The line is parsed, the image bytes are read from disk and the result
        is passed through ``transform`` with ``self.ops``. If any of these
        steps raises, the error is logged; if that happens or the transform
        result is ``None``, a different sample is returned instead: a random
        index in train mode, otherwise the next index (wrapping around).

        Raises:
            IndexError: If ``idx`` is outside ``self.data_lines``.
        """
        # Fetched outside the ``try`` so an out-of-range index surfaces as
        # IndexError rather than failing inside the error handler on an
        # unbound ``data_line``.
        data_line = self.data_lines[idx]
        try:
            data_line = data_line.decode("utf-8").strip("\n")
            info = json.loads(data_line)
            file_name = info["filename"]
            cells = info["html"]["cells"].copy()
            structure = info["html"]["structure"]["tokens"].copy()

            img_path = os.path.join(self.data_dir, file_name)
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            data = {
                "img_path": img_path,
                "cells": cells,
                "structure": structure,
                "file_name": file_name,
            }

            with open(data["img_path"], "rb") as f:
                img = f.read()
            data["image"] = img
            outs = transform(data, self.ops)
        except Exception:
            # ``Exception`` (not a bare except) so KeyboardInterrupt and
            # SystemExit are not turned into a silent retry.
            import traceback

            err = traceback.format_exc()
            self.logger.error(
                "When parsing line {}, error happened with msg: {}".format(
                    data_line, err
                )
            )
            outs = None
        if outs is None:
            rnd_idx = (
                np.random.randint(self.__len__())
                if self.mode == "train"
                else (idx + 1) % self.__len__()
            )
            return self.__getitem__(rnd_idx)
        return outs

    def __len__(self):
        """Return the number of label lines currently held."""
        return len(self.data_lines)
|
|