| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import numpy as np |
| import io |
| import os |
| from paddle.io import Dataset |
| import lmdb |
| import cv2 |
| import string |
| import pickle |
| from PIL import Image |
|
|
| from .imaug import transform, create_operators |
|
|
|
|
class LMDBDataSet(Dataset):
    """Dataset reading OCR samples from one or more LMDB databases.

    Walks ``data_dir`` for leaf directories, opens each one as a read-only
    LMDB environment, and builds a flat ``(lmdb_idx, file_idx)`` index over
    all samples so every sample is addressable by a single integer index.
    """

    def __init__(self, config, mode, logger, seed=None):
        """
        Args:
            config: full config dict; reads ``config["Global"]`` and
                ``config[mode]`` ("dataset" and "loader" sections).
            mode: config section key, e.g. "Train" or "Eval".
            logger: logger used for progress messages.
            seed: unused here; kept for interface compatibility with the
                other dataset classes.
        """
        super(LMDBDataSet, self).__init__()

        global_config = config["Global"]
        dataset_config = config[mode]["dataset"]
        loader_config = config[mode]["loader"]
        data_dir = dataset_config["data_dir"]
        self.do_shuffle = loader_config["shuffle"]

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
        logger.info("Initialize indexes of datasets:%s" % data_dir)
        self.data_idx_order_list = self.dataset_traversal()
        if self.do_shuffle:
            np.random.shuffle(self.data_idx_order_list)
        self.ops = create_operators(dataset_config["transforms"], global_config)
        # Index separating the "load" transforms (used for ext_data) from the rest.
        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 1)

        ratio_list = dataset_config.get("ratio_list", [1.0])
        # Re-shuffling between epochs is only needed when sub-datasets are
        # sampled at a ratio below 1.0.
        self.need_reset = any(x < 1 for x in ratio_list)

    def load_hierarchical_lmdb_dataset(self, data_dir):
        """Open every leaf directory under ``data_dir`` as an LMDB set.

        Returns:
            dict mapping a 0-based set index to
            ``{"dirpath", "env", "txn", "num_samples"}``.
        """
        lmdb_sets = {}
        dataset_idx = 0
        for dirpath, dirnames, filenames in os.walk(data_dir + "/"):
            # A directory with no subdirectories is assumed to be an LMDB env.
            if not dirnames:
                env = lmdb.open(
                    dirpath,
                    max_readers=32,
                    readonly=True,
                    lock=False,
                    readahead=False,
                    meminit=False,
                )
                txn = env.begin(write=False)
                num_samples = int(txn.get("num-samples".encode()))
                lmdb_sets[dataset_idx] = {
                    "dirpath": dirpath,
                    "env": env,
                    "txn": txn,
                    "num_samples": num_samples,
                }
                dataset_idx += 1
        return lmdb_sets

    def dataset_traversal(self):
        """Build the flat sample index over all LMDB sets.

        Returns:
            ``(total_sample_num, 2)`` numpy array; column 0 is the LMDB set
            index, column 1 the 1-based sample key index inside that set.
        """
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = sum(
            self.lmdb_sets[lno]["num_samples"] for lno in range(lmdb_num)
        )
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]["num_samples"]
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            # LMDB keys ("label-%09d"/"image-%09d") are 1-based.
            data_idx_order_list[beg_idx:end_idx, 1] = np.arange(
                1, tmp_sample_num + 1
            )
            beg_idx = end_idx
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode a raw image buffer to a BGR ndarray; None on failure."""
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype="uint8")
        if imgdata is None:
            return None
        imgori = cv2.imdecode(imgdata, 1)
        if imgori is None:
            return None
        return imgori

    def get_ext_data(self):
        """Collect extra random samples required by some transform ops.

        An op advertises its need via an ``ext_data_num`` attribute; only the
        first ``ext_op_transform_idx`` ops (the loading ops) are applied to
        the extra samples.
        """
        ext_data_num = 0
        for op in self.ops:
            if hasattr(op, "ext_data_num"):
                ext_data_num = op.ext_data_num
                break
        load_data_ops = self.ops[: self.ext_op_transform_idx]
        ext_data = []

        while len(ext_data) < ext_data_num:
            lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(len(self))]
            lmdb_idx = int(lmdb_idx)
            file_idx = int(file_idx)
            sample_info = self.get_lmdb_sample_info(
                self.lmdb_sets[lmdb_idx]["txn"], file_idx
            )
            if sample_info is None:
                continue
            img, label = sample_info
            data = {"image": img, "label": label}
            data = transform(data, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    def get_lmdb_sample_info(self, txn, index):
        """Fetch (raw image bytes, label str) for 1-based ``index``; None if absent."""
        label_key = "label-%09d".encode() % index
        label = txn.get(label_key)
        if label is None:
            return None
        label = label.decode("utf-8")
        img_key = "image-%09d".encode() % index
        imgbuf = txn.get(img_key)
        return imgbuf, label

    def __getitem__(self, idx):
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[lmdb_idx]["txn"], file_idx
        )
        if sample_info is None:
            # Missing/corrupt sample: fall back to a random one.
            return self.__getitem__(np.random.randint(self.__len__()))
        img, label = sample_info
        data = {"image": img, "label": label}
        data["ext_data"] = self.get_ext_data()
        outs = transform(data, self.ops)
        if outs is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        return self.data_idx_order_list.shape[0]
|
|
|
|
class LMDBDataSetSR(LMDBDataSet):
    """LMDB dataset for text super-resolution: yields (HR image, LR image, label)."""

    def buf2PIL(self, txn, key, type="RGB"):
        """Read the value at ``key`` from ``txn`` and decode it as a PIL image.

        Args:
            txn: LMDB read transaction.
            key: bytes key of the image buffer.
            type: PIL conversion mode (parameter name kept for compatibility).
        """
        imgbuf = txn.get(key)
        buf = io.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        im = Image.open(buf).convert(type)
        return im

    def str_filt(self, str_, voc_type):
        """Strip characters outside the vocabulary ``voc_type`` from ``str_``.

        ``voc_type`` is one of "digit", "lower", "upper", "all"; "lower"
        additionally lower-cases the string first.
        """
        alpha_dict = {
            "digit": string.digits,
            "lower": string.digits + string.ascii_lowercase,
            "upper": string.digits + string.ascii_letters,
            "all": string.digits + string.ascii_letters + string.punctuation,
        }
        if voc_type == "lower":
            str_ = str_.lower()
        for char in str_:
            if char not in alpha_dict[voc_type]:
                str_ = str_.replace(char, "")
        return str_

    def get_lmdb_sample_info(self, txn, index):
        """Fetch (HR image, LR image, filtered label) for 1-based ``index``.

        Skips to the next index when the image cannot be decoded or the label
        exceeds ``max_len``.
        """
        self.voc_type = "upper"
        self.max_len = 100
        self.test = False
        label_key = b"label-%09d" % index
        word = str(txn.get(label_key).decode())
        img_HR_key = b"image_hr-%09d" % index
        img_lr_key = b"image_lr-%09d" % index
        # BUGFIX: the original clause `except IOError or len(word) > self.max_len`
        # evaluates `IOError or ...` to just `IOError`, so over-long labels were
        # never skipped. Handle the two conditions separately.
        try:
            img_HR = self.buf2PIL(txn, img_HR_key, "RGB")
            img_lr = self.buf2PIL(txn, img_lr_key, "RGB")
        except IOError:
            return self[index + 1]
        if len(word) > self.max_len:
            return self[index + 1]
        label_str = self.str_filt(word, self.voc_type)
        return img_HR, img_lr, label_str

    def __getitem__(self, idx):
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[lmdb_idx]["txn"], file_idx
        )
        if sample_info is None:
            # Missing/corrupt sample: fall back to a random one.
            return self.__getitem__(np.random.randint(self.__len__()))
        img_HR, img_lr, label_str = sample_info
        data = {"image_hr": img_HR, "image_lr": img_lr, "label": label_str}
        outs = transform(data, self.ops)
        if outs is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs
|
|
|
|
class LMDBDataSetTableMaster(LMDBDataSet):
    """LMDB dataset for TableMaster: one pickled record per integer key."""

    def load_hierarchical_lmdb_dataset(self, data_dir):
        """Open ``data_dir`` itself as a single LMDB set.

        Unlike the base class, the sample count is stored pickled under the
        ``__len__`` key rather than under ``num-samples``.
        """
        lmdb_sets = {}
        dataset_idx = 0
        env = lmdb.open(
            data_dir,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        txn = env.begin(write=False)
        num_samples = int(pickle.loads(txn.get(b"__len__")))
        lmdb_sets[dataset_idx] = {
            "dirpath": data_dir,
            "env": env,
            "txn": txn,
            "num_samples": num_samples,
        }
        return lmdb_sets

    def get_img_data(self, value):
        """Decode a raw image buffer to a BGR ndarray; None on failure."""
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype="uint8")
        if imgdata is None:
            return None
        imgori = cv2.imdecode(imgdata, 1)
        if imgori is None:
            return None
        return imgori

    def get_lmdb_sample_info(self, txn, index):
        """Unpickle and parse the record stored under ``str(index)``.

        The record is ``(file_name, image_bytes, info_lines)`` where
        ``info_lines`` holds the file name, the comma-separated structure
        tokens, and one bbox per remaining line.

        Returns:
            dict with "file_name", "structure", "cells", "image";
            None when the key is missing or cannot be unpickled.
        """

        def convert_bbox(bbox_str_list):
            # Each bbox line is a list of integer coordinate strings.
            return [int(bbox_str) for bbox_str in bbox_str_list]

        # NOTE: pickle.loads on the stored value — the LMDB file must be
        # trusted; unpickling untrusted data can execute arbitrary code.
        try:
            data = pickle.loads(txn.get(str(index).encode("utf8")))
        except Exception:
            # Missing key (txn.get -> None) or a corrupt pickle: let the
            # caller resample instead of crashing.
            return None

        file_name = data[0]
        img_bytes = data[1]  # renamed from `bytes` to avoid shadowing the builtin
        info_lines = data[2]

        raw_data = info_lines.strip().split("\n")
        raw_name, text = (
            raw_data[0],
            raw_data[1],
        )
        text = text.split(",")
        bbox_str_list = raw_data[2:]
        bbox_split = ","
        bboxes = [
            {"bbox": convert_bbox(bsl.strip().split(bbox_split)), "tokens": ["1", "2"]}
            for bsl in bbox_str_list
        ]

        line_info = {}
        line_info["file_name"] = file_name
        line_info["structure"] = text
        line_info["cells"] = bboxes
        line_info["image"] = img_bytes
        return line_info

    def __getitem__(self, idx):
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        data = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]["txn"], file_idx)
        if data is None:
            # Missing/corrupt sample: fall back to a random one.
            return self.__getitem__(np.random.randint(self.__len__()))
        outs = transform(data, self.ops)
        if outs is None:
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        return self.data_idx_order_list.shape[0]
|
|