| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
| import os |
| import cv2 |
| import random |
| import numpy as np |
| import paddle |
| import importlib.util |
| import sys |
| import subprocess |
|
|
|
|
def print_dict(d, logger, delimiter=0):
    """
    Log the contents of a (possibly nested) dict, one entry per line.

    Keys are visited in sorted order. A value that is a dict, or a non-empty
    list whose first element is a dict, is logged as a header line followed
    by its contents indented by four additional spaces.

    Args:
        d: Dict to display.
        logger: Object exposing an ``info(message)`` method.
        delimiter: Number of leading spaces used for the current level.
    """
    indent = " " * delimiter
    child_indent = delimiter + 4
    for key, value in sorted(d.items()):
        if isinstance(value, dict):
            nested = [value]
        elif isinstance(value, list) and len(value) >= 1 and isinstance(value[0], dict):
            nested = value
        else:
            # Leaf value: key and value share a single line.
            logger.info("{}{} : {}".format(indent, key, value))
            continue
        # Container value: header line first, then each nested dict below it.
        logger.info("{}{} : ".format(indent, str(key)))
        for item in nested:
            print_dict(item, logger, child_indent)
|
|
|
|
def get_check_global_params(mode):
    """
    Return the names of the global config parameters to check for ``mode``.

    Args:
        mode: ``"train_eval"`` adds both batch-size parameters, ``"test"``
            adds only the test batch size; any other value yields just the
            common parameters.

    Returns:
        A new list of parameter-name strings, each listed once.
    """
    # "image_shape" used to be listed twice; a single entry is sufficient.
    check_params = [
        "use_gpu",
        "max_text_length",
        "image_shape",
        "character_type",
        "loss_type",
    ]
    if mode == "train_eval":
        check_params += [
            "train_batch_size_per_card",
            "test_batch_size_per_card",
        ]
    elif mode == "test":
        check_params += ["test_batch_size_per_card"]
    return check_params
|
|
|
|
| def _check_image_file(path): |
| img_end = {"jpg", "bmp", "png", "jpeg", "rgb", "tif", "tiff", "gif", "pdf"} |
| return any([path.lower().endswith(e) for e in img_end]) |
|
|
|
|
def get_image_file_list(img_file, infer_list=None):
    """
    Collect image paths from a list file, a single file, or a directory.

    Args:
        img_file: Path of an image file or of a directory containing images.
            When ``infer_list`` is given, this is the root that every listed
            path is joined onto.
        infer_list: Optional path of a text file with one entry per line; the
            text before the first tab of each line is used as the image path.

    Returns:
        The collected paths as a sorted list.

    Raises:
        Exception: If ``infer_list`` is given but does not exist, if
            ``img_file`` is None or does not exist (when no list is given),
            or if no image path was collected.
    """
    if infer_list and not os.path.exists(infer_list):
        raise Exception("not found infer list {}".format(infer_list))

    imgs_lists = []
    if infer_list:
        with open(infer_list, "r") as f:
            for line in f:
                image_path = line.strip().split("\t")[0]
                if not image_path:
                    # A blank line would otherwise add the bare root
                    # directory as if it were an image.
                    continue
                imgs_lists.append(os.path.join(img_file, image_path))
    else:
        if img_file is None or not os.path.exists(img_file):
            raise Exception("not found any img file in {}".format(img_file))

        if os.path.isfile(img_file) and _check_image_file(img_file):
            imgs_lists.append(img_file)
        elif os.path.isdir(img_file):
            # Only direct children are considered; sub-directories are not walked.
            for single_file in os.listdir(img_file):
                file_path = os.path.join(img_file, single_file)
                if os.path.isfile(file_path) and _check_image_file(file_path):
                    imgs_lists.append(file_path)

    if not imgs_lists:
        raise Exception("not found any img file in {}".format(img_file))
    return sorted(imgs_lists)
|
|
|
|
def binarize_img(img):
    """
    Binarize a 3-channel image with an Otsu threshold.

    Args:
        img: Array-like image exposing ``shape``.

    Returns:
        For input with three dimensions and three channels: the image
        converted to grayscale, thresholded with
        ``THRESH_BINARY + THRESH_OTSU`` and converted back to three channels.
        Any other input is returned unchanged.
    """
    is_three_channel = len(img.shape) == 3 and img.shape[2] == 3
    if not is_three_channel:
        return img

    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # With THRESH_OTSU the threshold is chosen by OpenCV; the 0 passed here is a placeholder.
    _, binary = cv2.threshold(grayscale, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
|
|
|
|
def alpha_to_color(img, alpha_color=(255, 255, 255)):
    """
    Blend a 4-channel image onto a solid background colour.

    Args:
        img: Array-like image exposing ``shape``; the four channels are
            unpacked in the order B, G, R, A.
        alpha_color: Background colour; index 0 is blended into R, index 1
            into G and index 2 into B.

    Returns:
        A 3-channel (B, G, R) uint8 image for 4-channel input; any other
        input is returned unchanged.
    """
    if len(img.shape) != 3 or img.shape[2] != 4:
        return img

    blue, green, red, alpha_channel = cv2.split(img)
    opacity = alpha_channel / 255

    def _blend(fill, channel):
        # Weighted mix of the background value and the pixel value.
        return (fill * (1 - opacity) + channel * opacity).astype(np.uint8)

    red = _blend(alpha_color[0], red)
    green = _blend(alpha_color[1], green)
    blue = _blend(alpha_color[2], blue)

    return cv2.merge((blue, green, red))
|
|
|
|
def check_and_read(img_path):
    """
    Read a GIF or PDF file into image array(s).

    The reader is chosen from the last three characters of the file's
    basename (case-insensitive).

    Args:
        img_path: Path of the file to read.

    Returns:
        Always a 3-tuple ``(data, is_gif, is_pdf)``:

        * GIF: the first frame with its last axis reversed, ``True``, ``False``.
        * PDF: a list with one BGR array per page, ``False``, ``True``.
        * Anything else, or a GIF that cannot be read: ``None``, ``False``,
          ``False``.
    """
    suffix = os.path.basename(img_path)[-3:].lower()

    if suffix == "gif":
        gif = cv2.VideoCapture(img_path)
        try:
            ret, frame = gif.read()
        finally:
            # Only the first frame is needed; free the capture handle right away.
            gif.release()
        if not ret:
            logger = logging.getLogger("ppocr")
            logger.info(
                "Cannot read {}. This gif image maybe corrupted.".format(img_path)
            )
            # Same arity as every other return so callers can always unpack three values.
            return None, False, False
        if len(frame.shape) == 2 or frame.shape[-1] == 1:
            frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        imgvalue = frame[:, :, ::-1]
        return imgvalue, True, False

    if suffix == "pdf":
        from paddle.utils import try_import

        fitz = try_import("fitz")
        from PIL import Image

        imgs = []
        with fitz.open(img_path) as pdf:
            for pg in range(pdf.page_count):
                page = pdf[pg]
                mat = fitz.Matrix(2, 2)
                pm = page.get_pixmap(matrix=mat, alpha=False)

                # If the rendering is wider or taller than 2000 px, render
                # again with the (1, 1) matrix instead of the (2, 2) one.
                if pm.width > 2000 or pm.height > 2000:
                    pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

                img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
                img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
                imgs.append(img)
        return imgs, False, True

    return None, False, False
|
|
|
|
def load_vqa_bio_label_maps(label_map_path):
    """
    Build BIO label <-> id maps from a label file.

    Each line of the file (stripped of surrounding whitespace) is a label
    name. Names equal to OTHER, OTHERS or IGNORE (case-insensitive) are
    skipped. Id 0 is always ``"O"``; every remaining name contributes a
    ``B-`` label followed by an ``I-`` label, in file order.

    Args:
        label_map_path: Path of the UTF-8 encoded label file.

    Returns:
        ``(label2id_map, id2label_map)`` with all labels upper-cased.
    """
    skipped = ["OTHER", "OTHERS", "IGNORE"]

    with open(label_map_path, "r", encoding="utf-8") as fin:
        names = [raw.strip() for raw in fin]

    # "O" occupies index 0 regardless of the file contents.
    labels = ["O"]
    for name in names:
        if name.upper() in skipped:
            continue
        labels.extend(("B-" + name, "I-" + name))

    label2id_map = {}
    id2label_map = {}
    for idx, label in enumerate(labels):
        upper_label = label.upper()
        label2id_map[upper_label] = idx
        id2label_map[idx] = upper_label
    return label2id_map, id2label_map
|
|
|
|
def set_seed(seed=1024):
    """
    Seed Python's ``random``, NumPy and Paddle with the same value.

    Args:
        seed: Value passed to ``random.seed``, ``np.random.seed`` and
            ``paddle.seed``, in that order.
    """
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(seed)
|
|
|
|
def check_install(module_name, install_name):
    """
    Ensure ``module_name`` can be found, installing it with pip if it cannot.

    Args:
        module_name: Import name looked up with ``importlib.util.find_spec``.
        install_name: Package name passed to ``pip install``.

    Raises:
        Exception: If the pip subprocess exits with a non-zero status. The
            original ``CalledProcessError`` is attached as ``__cause__``.
    """
    try:
        spec = importlib.util.find_spec(module_name)
    except ImportError:
        # For a dotted name, find_spec imports the parent package and raises
        # ModuleNotFoundError when that parent is missing; treat it as absent.
        spec = None

    if spec is not None:
        print(f"{module_name} has been installed.")
        return

    print(f"Warning! The {module_name} module is NOT installed")
    print(
        f"Try install {module_name} module automatically. You can also try to install manually by pip install {install_name}."
    )
    # Use the running interpreter so the package goes into the same environment.
    python = sys.executable
    try:
        subprocess.check_call(
            [python, "-m", "pip", "install", install_name],
            stdout=subprocess.DEVNULL,
        )
    except subprocess.CalledProcessError as exc:
        raise Exception(
            f"Install {module_name} failed, please install manually"
        ) from exc
    print(f"The {module_name} module is now installed")
|
|
|
|
class AverageMeter:
    """
    Track the most recent value and the weighted running average of a metric.

    Attributes are plain and public: ``val`` (last value), ``sum`` (weighted
    total), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset every statistic to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record ``val`` observed ``n`` times and refresh the average.

        Args:
            val: Newly observed value.
            n: Weight (number of samples) that ``val`` represents.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(..., n=0) right after reset) would
        # divide by zero; report an average of 0 in that case.
        self.avg = self.sum / self.count if self.count else 0
|
|