| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| This code is refer from: |
| https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/match.py |
| """ |
|
|
| import os |
| import re |
| import cv2 |
| import glob |
| import copy |
| import math |
| import pickle |
| import numpy as np |
|
|
| from shapely.geometry import Polygon, MultiPoint |
|
|
| """ |
| Useful function in matching. |
| """ |
|
|
|
|
def remove_empty_bboxes(bboxes):
    """
    Drop all-zero placeholder rows ([0., 0., 0., 0.]) from structure master bboxes.
    len(bboxes.shape) must be 2.
    :param bboxes: iterable of 4-element boxes.
    :return: np.ndarray of the non-empty boxes, in their original order.
    """
    kept = [bbox for bbox in bboxes if sum(bbox) != 0.0]
    return np.array(kept)
|
|
|
|
def xywh2xyxy(bboxes):
    """
    Convert bboxes from (center_x, center_y, w, h) to (x1, y1, x2, y2).

    Supports a single box (1-D array of 4) or a batch (2-D array of shape [N, 4]).
    :param bboxes: np.ndarray in xywh format.
    :return: np.ndarray of the same shape in xyxy format.
    :raises ValueError: if bboxes is not 1-D or 2-D.
    """
    if len(bboxes.shape) == 1:
        new_bboxes = np.empty_like(bboxes)
        new_bboxes[0] = bboxes[0] - bboxes[2] / 2
        new_bboxes[1] = bboxes[1] - bboxes[3] / 2
        new_bboxes[2] = bboxes[0] + bboxes[2] / 2
        new_bboxes[3] = bboxes[1] + bboxes[3] / 2
        return new_bboxes
    elif len(bboxes.shape) == 2:
        new_bboxes = np.empty_like(bboxes)
        new_bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] / 2
        new_bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] / 2
        new_bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2] / 2
        new_bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3] / 2
        return new_bboxes
    else:
        # A message-less ValueError made failures hard to diagnose.
        raise ValueError(
            "bboxes must be a 1-D or 2-D array, got shape {}".format(bboxes.shape)
        )
|
|
|
|
def xyxy2xywh(bboxes):
    """
    Convert bboxes from (x1, y1, x2, y2) to (center_x, center_y, w, h).

    Supports a single box (1-D array of 4) or a batch (2-D array of shape [N, 4]).
    :param bboxes: np.ndarray in xyxy format.
    :return: np.ndarray of the same shape in xywh format.
    :raises ValueError: if bboxes is not 1-D or 2-D.
    """
    if len(bboxes.shape) == 1:
        new_bboxes = np.empty_like(bboxes)
        new_bboxes[0] = bboxes[0] + (bboxes[2] - bboxes[0]) / 2
        new_bboxes[1] = bboxes[1] + (bboxes[3] - bboxes[1]) / 2
        new_bboxes[2] = bboxes[2] - bboxes[0]
        new_bboxes[3] = bboxes[3] - bboxes[1]
        return new_bboxes
    elif len(bboxes.shape) == 2:
        new_bboxes = np.empty_like(bboxes)
        new_bboxes[:, 0] = bboxes[:, 0] + (bboxes[:, 2] - bboxes[:, 0]) / 2
        new_bboxes[:, 1] = bboxes[:, 1] + (bboxes[:, 3] - bboxes[:, 1]) / 2
        new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
        new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
        return new_bboxes
    else:
        # A message-less ValueError made failures hard to diagnose.
        raise ValueError(
            "bboxes must be a 1-D or 2-D array, got shape {}".format(bboxes.shape)
        )
|
|
|
|
| def pickle_load(path, prefix="end2end"): |
| if os.path.isfile(path): |
| data = pickle.load(open(path, "rb")) |
| elif os.path.isdir(path): |
| data = dict() |
| search_path = os.path.join(path, "{}_*.pkl".format(prefix)) |
| pkls = glob.glob(search_path) |
| for pkl in pkls: |
| this_data = pickle.load(open(pkl, "rb")) |
| data.update(this_data) |
| else: |
| raise ValueError |
| return data |
|
|
|
|
def convert_coord(xyxy):
    """
    Convert a two-point bbox (x1, y1, x2, y2) into its four corner points,
    ordered clockwise starting at the top-left corner.
    :param xyxy: sequence of 4 coordinates.
    :return: np.ndarray of shape [4, 2], dtype float32.
    """
    x1, y1, x2, y2 = xyxy[0], xyxy[1], xyxy[2], xyxy[3]
    return np.array(
        [[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32
    )
|
|
|
|
def cal_iou(bbox1, bbox2):
    """
    Compute the IoU of two 4-point boxes using shapely convex hulls.

    The union area is approximated by the convex hull of all eight points.
    :param bbox1: array-like of shape [4, 2].
    :param bbox2: array-like of shape [4, 2].
    :return: IoU value; 0 when the hulls do not intersect or the union is degenerate.
    """
    hull1 = Polygon(bbox1).convex_hull
    hull2 = Polygon(bbox2).convex_hull
    if not hull1.intersects(hull2):
        return 0
    inter_area = hull1.intersection(hull2).area
    union_area = MultiPoint(np.concatenate((bbox1, bbox2))).convex_hull.area
    if union_area == 0:
        return 0
    return float(inter_area) / union_area
|
|
|
|
def cal_distance(p1, p2):
    """
    Euclidean distance between two 2-D points.
    :param p1: (x, y)
    :param p2: (x, y)
    :return: the distance as a float.
    """
    dx, dy = p1[0] - p2[0], p1[1] - p2[1]
    return math.sqrt(dx**2 + dy**2)
|
|
|
|
def is_inside(center_point, corner_point):
    """
    Check whether center_point lies inside (or on the border of) the
    axis-aligned box given by corner_point.
    :param center_point: center point (x, y)
    :param corner_point: corner points ((x1, y1), (x2, y2)) — top-left, bottom-right.
    :return: True if inside, else False.
    """
    (x1, y1), (x2, y2) = corner_point
    x, y = center_point[0], center_point[1]
    return bool(x1 <= x <= x2 and y1 <= y <= y2)
|
|
|
|
def find_no_match(match_list, all_end2end_nums, type="end2end"):
    """
    Find the bbox indexes that appear in no pair of a previous match list.
    :param match_list: matching pairs [[end2end_idx, master_idx], ...].
    :param all_end2end_nums: total number of bboxes on the chosen side.
    :param type: 'end2end' reads pair index 0, 'master' reads pair index 1.
    :return: sorted list of unmatched bbox indexes.
    """
    if type == "end2end":
        idx = 0
    elif type == "master":
        idx = 1
    else:
        raise ValueError

    # Set membership makes the scan O(n) instead of O(n * len(match_list)).
    matched = {pair[idx] for pair in match_list}
    return [n for n in range(all_end2end_nums) if n not in matched]
|
|
|
|
def is_abs_lower_than_threshold(this_bbox, target_bbox, threshold=3):
    """
    True when the absolute difference between the two boxes' y coordinates
    (index 1) is strictly below threshold, i.e. they sit on the same text line.
    :param this_bbox: bbox whose index 1 is the y coordinate.
    :param target_bbox: bbox to compare against.
    :param threshold: maximum allowed y difference (exclusive).
    :return: True / False.
    """
    return bool(abs(this_bbox[1] - target_bbox[1]) < threshold)
|
|
|
|
def sort_line_bbox(g, bg):
    """
    Sort the indexes and bboxes of one line (group) by the bbox x coordinate.

    The previous implementation located each bbox with ``list.index`` on the
    sorted x values, which collapses duplicate x coordinates and leaves ``None``
    holes in the output. A stable argsort handles ties correctly.
    :param g: indexes in the same group.
    :param bg: bboxes in the same group (item[0] is the x coordinate).
    :return: (g_sorted, bg_sorted) sorted together by ascending x.
    """
    # sorted() is stable, so boxes sharing the same x keep their input order.
    order = sorted(range(len(bg)), key=lambda i: bg[i][0])
    g_sorted = [g[i] for i in order]
    bg_sorted = [bg[i] for i in order]
    return g_sorted, bg_sorted
|
|
|
|
def flatten(sorted_groups, sorted_bbox_groups):
    """
    Flatten grouped indexes and grouped bboxes into two flat, parallel lists,
    preserving group order.
    :param sorted_groups: list of index groups.
    :param sorted_bbox_groups: list of bbox groups (parallel structure).
    :return: (idxs, bboxes)
    """
    pairs = [
        (i, box)
        for group, bbox_group in zip(sorted_groups, sorted_bbox_groups)
        for i, box in zip(group, bbox_group)
    ]
    idxs = [p[0] for p in pairs]
    bboxes = [p[1] for p in pairs]
    return idxs, bboxes
|
|
|
|
def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
    """
    Group the still-unmatched end2end bboxes into text rows and sort them in
    reading order (rows top-to-bottom, boxes left-to-right inside each row).
    :param end2end_xywh_bboxes: xywh bboxes of the unmatched end2end results.
    :param no_match_end2end_indexes: their original end2end indexes (parallel).
    :return: (flat sorted index list, flat sorted bbox list,
              sorted index groups, sorted bbox groups) — each group is one row.
    """
    # Greedy row clustering: a box joins the first existing row whose first
    # bbox has a y coordinate within the default threshold.
    groups = []
    bbox_groups = []
    for index, end2end_xywh_bbox in zip(no_match_end2end_indexes, end2end_xywh_bboxes):
        this_bbox = end2end_xywh_bbox
        if len(groups) == 0:
            groups.append([index])
            bbox_groups.append([this_bbox])
        else:
            flag = False
            for g, bg in zip(groups, bbox_groups):
                # Same row when y is close to the row's first bbox.
                if is_abs_lower_than_threshold(this_bbox, bg[0]):
                    g.append(index)
                    bg.append(this_bbox)
                    flag = True
                    break
            if not flag:
                # No existing row is close enough: start a new row.
                groups.append([index])
                bbox_groups.append([this_bbox])

    # Sort the boxes inside each row by x coordinate.
    tmp_groups, tmp_bbox_groups = [], []
    for g, bg in zip(groups, bbox_groups):
        g_sorted, bg_sorted = sort_line_bbox(g, bg)
        tmp_groups.append(g_sorted)
        tmp_bbox_groups.append(bg_sorted)

    # Sort the rows themselves by the y coordinate of each row's first bbox.
    # NOTE(review): list.index on the sorted y values collapses duplicates —
    # if two rows share the same first-bbox y, one output slot stays None.
    # Presumably row ys are distinct after clustering; verify.
    sorted_groups = [None] * len(tmp_groups)
    sorted_bbox_groups = [None] * len(tmp_bbox_groups)
    ys = [bg[0][1] for bg in tmp_bbox_groups]
    sorted_ys = sorted(ys)
    for g, bg in zip(tmp_groups, tmp_bbox_groups):
        idx = sorted_ys.index(bg[0][1])
        sorted_groups[idx] = g
        sorted_bbox_groups[idx] = bg

    # Flatten the rows into flat lists in reading order.
    end2end_sorted_idx_list, end2end_sorted_bbox_list = flatten(
        sorted_groups, sorted_bbox_groups
    )

    return (
        end2end_sorted_idx_list,
        end2end_sorted_bbox_list,
        sorted_groups,
        sorted_bbox_groups,
    )
|
|
|
|
def get_bboxes_list(end2end_result, structure_master_result):
    """
    Convert end2end results and structure master results into four bbox arrays,
    in both xyxy and xywh formats.
    :param end2end_result: list of dicts, each carrying an xyxy 'bbox'.
    :param structure_master_result: dict whose 'bbox' ndarray is used as xyxy
        here (all-zero placeholder rows are removed first).
    :return: (end2end_xyxy, end2end_xywh, master_xywh, master_xyxy) ndarrays.
    """
    # end2end side: collect xyxy boxes, derive xywh per box.
    end2end_xyxy_list = [item["bbox"] for item in end2end_result]
    end2end_xyxy_bboxes = np.array(end2end_xyxy_list)
    end2end_xywh_bboxes = np.array([xyxy2xywh(b) for b in end2end_xyxy_list])

    # structure master side: strip empty rows, then derive xywh in one call.
    structure_master_xyxy_bboxes = remove_empty_bboxes(structure_master_result["bbox"])
    structure_master_xywh_bboxes = xyxy2xywh(structure_master_xyxy_bboxes)

    return (
        end2end_xyxy_bboxes,
        end2end_xywh_bboxes,
        structure_master_xywh_bboxes,
        structure_master_xyxy_bboxes,
    )
|
|
|
|
def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes):
    """
    Match each end2end bbox to every structure master bbox containing its
    center point.
    :param end2end_xywh_bboxes: end2end boxes in xywh (x, y are the center).
    :param structure_master_xyxy_bboxes: master boxes in xyxy corners.
    :return: match pairs list, e.g. [[0, 1], [1, 2], ...]
    """
    match_pairs_list = []
    for i, end2end_xywh in enumerate(end2end_xywh_bboxes):
        center_point_end2end = (end2end_xywh[0], end2end_xywh[1])
        for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
            corner_point_master = (
                (master_xyxy[0], master_xyxy[1]),
                (master_xyxy[2], master_xyxy[3]),
            )
            if is_inside(center_point_end2end, corner_point_master):
                match_pairs_list.append([i, j])
    return match_pairs_list
|
|
|
|
def iou_rule_match(
    end2end_xyxy_bboxes, end2end_xyxy_indexes, structure_master_xyxy_bboxes
):
    """
    Match remaining end2end bboxes to master bboxes by maximum IoU.
    :param end2end_xyxy_bboxes: unmatched end2end boxes (xyxy).
    :param end2end_xyxy_indexes: their original end2end indexes (parallel list).
    :param structure_master_xyxy_bboxes: master boxes (xyxy).
    :return: match pairs list, e.g. [[0, 1], [1, 2], ...]; boxes with zero IoU
        against every master box are left unmatched.
    """
    match_pair_list = []
    for end2end_xyxy_index, end2end_xyxy in zip(
        end2end_xyxy_indexes, end2end_xyxy_bboxes
    ):
        max_iou = 0
        max_match = [None, None]
        # The end2end polygon is loop-invariant: convert once rather than
        # once per master box (was recomputed inside the inner loop).
        end2end_4xy = convert_coord(end2end_xyxy)
        for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
            master_4xy = convert_coord(master_xyxy)
            iou = cal_iou(end2end_4xy, master_4xy)
            if iou > max_iou:
                max_match[0], max_match[1] = end2end_xyxy_index, j
                max_iou = iou

        if max_match[0] is None:
            # No master box overlaps this end2end box at all.
            continue
        match_pair_list.append(max_match)
    return match_pair_list
|
|
|
|
def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes, master_bboxes):
    """
    Match every no-match master bbox to the nearest no-match end2end bbox by
    center-point distance, producing one pair per master bbox.
    This rule only runs when both no-match sets are non-empty.
    :param end2end_indexes: original indexes of the unmatched end2end boxes.
    :param end2end_bboxes: their xywh boxes (parallel list).
    :param master_indexes: original indexes of the unmatched master boxes.
    :param master_bboxes: their xywh boxes (parallel list).
    :return: match pairs list, e.g. [[0, 1], [1, 2], ...]
    """
    min_match_list = []
    for j, master_bbox in zip(master_indexes, master_bboxes):
        master_point = (master_bbox[0], master_bbox[1])
        min_distance = np.inf
        min_match = [0, 0]
        for i, end2end_bbox in zip(end2end_indexes, end2end_bboxes):
            end2end_point = (end2end_bbox[0], end2end_bbox[1])
            dist = cal_distance(master_point, end2end_point)
            if dist < min_distance:
                min_match[0], min_match[1] = i, j
                min_distance = dist
        min_match_list.append(min_match)
    return min_match_list
|
|
|
|
def extra_match(no_match_end2end_indexes, master_bbox_nums):
    """
    Pair each leftover end2end index with a virtual master index appended
    after the real master bboxes.
    :param no_match_end2end_indexes: end2end indexes with no master match.
    :param master_bbox_nums: number of real master bboxes.
    :return: match pairs [[end2end_idx, virtual_master_idx], ...]
    """
    return [
        [end2end_index, master_bbox_nums + offset]
        for offset, end2end_index in enumerate(no_match_end2end_indexes)
    ]
|
|
|
|
def get_match_dict(match_list):
    """
    Convert match pairs into a dict keyed by master bbox index, mapping to the
    list of end2end bbox indexes matched to that cell.
    :param match_list: [[end2end_index, master_index], ...]
    :return: {master_index: [end2end_index, ...]}
    """
    match_dict = dict()
    for end2end_index, master_index in match_list:
        match_dict.setdefault(master_index, []).append(end2end_index)
    return match_dict
|
|
|
|
def deal_successive_space(text):
    """
    Clean up space characters in text:
    1. Runs of exactly three spaces mark a real space -> protect as '<space>'.
    2. Remaining single spaces are split tokens, not real spaces -> drop them.
    3. Restore '<space>' back to a single real space.
    :param text: raw text mixing split-token spaces and real spaces.
    :return: cleaned text.
    """
    protected = text.replace(" " * 3, "<space>")
    without_split_tokens = protected.replace(" ", "")
    return without_split_tokens.replace("<space>", " ")
|
|
|
|
def reduce_repeat_bb(text_list, break_token):
    """
    Merge a run of per-word bold fragments into one bold span, e.g.
    ['<b>Local</b>', '<b>government</b>', '<b>unit</b>'] -> ['<b>Local government unit</b>'].
    Only applies when every fragment starts with '<b>'; otherwise the input
    list is returned unchanged.
    :param text_list: list of text fragments.
    :param break_token: separator inserted between the merged fragments.
    :return: a single-element merged list, or the original list.
    """
    if all(text.startswith("<b>") for text in text_list):
        stripped = [
            text.replace("<b>", "").replace("</b>", "") for text in text_list
        ]
        return ["<b>" + break_token.join(stripped) + "</b>"]
    return text_list
|
|
|
|
def get_match_text_dict(match_dict, end2end_info, break_token=" "):
    """
    Build the text for each matched master cell by joining the texts of its
    matched end2end boxes (merging per-word bold fragments first).
    :param match_dict: {master_index: [end2end_index, ...]}
    :param end2end_info: end2end results; each item carries a 'text' field.
    :param break_token: separator between the joined texts.
    :return: {master_index: joined_text}
    """
    match_text_dict = dict()
    for master_index, end2end_index_list in match_dict.items():
        texts = [end2end_info[i]["text"] for i in end2end_index_list]
        texts = reduce_repeat_bb(texts, break_token)
        match_text_dict[master_index] = break_token.join(texts)
    return match_text_dict
|
|
|
|
def merge_span_token(master_token_list):
    """
    Merge split span tokens ('<td', ' rowspan=...', ' colspan=...', '>')
    back into single '<td ...></td>' tokens.

    Appends a '</tbody>' terminator to the input if missing — note this
    mutates the caller's list.
    :param master_token_list: structure tokens predicted by the master model.
    :return: new token list with span tokens merged.
    """
    new_master_token_list = []
    pointer = 0
    if master_token_list[-1] != "</tbody>":
        master_token_list.append("</tbody>")
    while master_token_list[pointer] != "</tbody>":
        try:
            if master_token_list[pointer] == "<td":
                if master_token_list[pointer + 1].startswith(
                    " colspan="
                ) or master_token_list[pointer + 1].startswith(" rowspan="):
                    # pattern <td colspan="3">:
                    # '<td' + ' colspan=" "' + '>' + '</td>'
                    tmp = "".join(master_token_list[pointer : pointer + 3 + 1])
                    pointer += 4
                    new_master_token_list.append(tmp)

                elif master_token_list[pointer + 2].startswith(
                    " colspan="
                ) or master_token_list[pointer + 2].startswith(" rowspan="):
                    # pattern <td rowspan="2" colspan="3">:
                    # '<td' + ' rowspan=" "' + ' colspan=" "' + '>' + '</td>'
                    tmp = "".join(master_token_list[pointer : pointer + 4 + 1])
                    pointer += 5
                    new_master_token_list.append(tmp)

                else:
                    new_master_token_list.append(master_token_list[pointer])
                    pointer += 1
            else:
                new_master_token_list.append(master_token_list[pointer])
                pointer += 1
        except IndexError:
            # A truncated prediction can run past the end of the list; stop
            # merging rather than crash. (Was a bare 'except:', which also
            # swallowed KeyboardInterrupt/SystemExit.)
            print("Break in merge...")
            break
    new_master_token_list.append("</tbody>")

    return new_master_token_list
|
|
|
|
def deal_eb_token(master_token):
    """
    Expand empty-bbox placeholder tokens (<eb></eb>, <eb1></eb1>, ...) back
    into the empty-cell HTML they stand for, e.g. '<eb2></eb2>' encodes
    '<td><b> </b></td>'.
    :param master_token: token string possibly containing <ebN> placeholders.
    :return: token string with every placeholder expanded.
    """
    # Placeholder -> original empty-cell HTML (no placeholder is a substring
    # of another, so replacement order does not matter).
    replacements = (
        ("<eb></eb>", "<td></td>"),
        ("<eb1></eb1>", "<td> </td>"),
        ("<eb2></eb2>", "<td><b> </b></td>"),
        ("<eb3></eb3>", "<td>\u2028\u2028</td>"),
        ("<eb4></eb4>", "<td><sup> </sup></td>"),
        ("<eb5></eb5>", "<td><b></b></td>"),
        ("<eb6></eb6>", "<td><i> </i></td>"),
        ("<eb7></eb7>", "<td><b><i></i></b></td>"),
        ("<eb8></eb8>", "<td><b><i> </i></b></td>"),
        ("<eb9></eb9>", "<td><i></i></td>"),
        ("<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"),
    )
    for placeholder, html in replacements:
        master_token = master_token.replace(placeholder, html)
    return master_token
|
|
|
|
def insert_text_to_token(master_token_list, match_text_dict):
    """
    Insert matched OCR text into the structure tokens and render the result.
    '<td' tokens with no matched text are dropped from the output entirely.
    :param master_token_list: structure tokens (span tokens may still be split).
    :param match_text_dict: {td_index: text} from the matching stage.
    :return: concatenated HTML string.
    """
    master_token_list = merge_span_token(master_token_list)
    merged_result_list = []
    text_count = 0
    for master_token in master_token_list:
        if master_token.startswith("<td"):
            if (
                text_count > len(match_text_dict) - 1
                or text_count not in match_text_dict
            ):
                # No text for this cell: skip the token, dropping the cell.
                text_count += 1
                continue
            master_token = master_token.replace(
                "><", ">{}<".format(match_text_dict[text_count])
            )
            text_count += 1
        merged_result_list.append(deal_eb_token(master_token))

    return "".join(merged_result_list)
|
|
|
|
def deal_isolate_span(thead_part):
    """
    Repair isolated span fragments caused by structure-model mis-prediction,
    e.g. '<td></td> rowspan="2"></b></td>' -> '<td rowspan="2"></td>'.
    :param thead_part: the '<thead>...</thead>' substring.
    :return: thead_part with every isolated span fragment rewritten.
    """
    # Fragments where the span attributes leaked outside the <td> tag.
    isolate_pattern = (
        r'<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
        r'<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
        r'<td></td> rowspan="(\d)+"></b></td>|'
        r'<td></td> colspan="(\d)+"></b></td>'
    )
    isolate_list = [m.group() for m in re.finditer(isolate_pattern, thead_part)]

    # The span-attribute substring inside each fragment.
    span_pattern = (
        r' rowspan="(\d)+" colspan="(\d)+"|'
        r' colspan="(\d)+" rowspan="(\d)+"|'
        r' rowspan="(\d)+"|'
        r' colspan="(\d)+"'
    )
    for isolate_item in isolate_list:
        span_part = re.search(span_pattern, isolate_item)
        # Guard the match object itself: the old code called .group() first
        # and only then compared the result to None, so the guard could never
        # fire — a missing span raised AttributeError instead.
        if span_part is None:
            continue
        corrected_item = "<td{}></td>".format(span_part.group())
        thead_part = thead_part.replace(isolate_item, corrected_item)
    return thead_part
|
|
|
|
def deal_duplicate_bb(thead_part):
    """
    Deal duplicate <b> or </b> after replace.
    Keep one <b></b> in a <td></td> token.
    :param thead_part: the '<thead>...</thead>' substring.
    :return: thead_part with duplicated bold tags collapsed per cell.
    """
    # Find every <td>...</td> cell, with or without span attributes.
    td_pattern = (
        '<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
        '<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
        '<td rowspan="(\d)+">(.+?)</td>|'
        '<td colspan="(\d)+">(.+?)</td>|'
        "<td>(.*?)</td>"
    )
    td_iter = re.finditer(td_pattern, thead_part)
    td_list = [t.group() for t in td_iter]

    # For any cell containing more than one <b> or </b>, strip all bold tags
    # and re-wrap the whole cell content in a single <b>...</b>.
    new_td_list = []
    for td_item in td_list:
        if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
            # Remove all duplicated bold tags first.
            td_item = td_item.replace("<b>", "").replace("</b>", "")
            # Re-insert one bold span around the cell content.
            # NOTE(review): only rewrites plain '<td>' opening tags; a span
            # cell like '<td rowspan="2">' keeps its opening tag unwrapped
            # here — confirm that is the intended behavior.
            td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
            new_td_list.append(td_item)
        else:
            new_td_list.append(td_item)

    # Write the corrected cells back. str.replace substitutes every occurrence
    # of an identical cell string; identical cells get identical replacements,
    # so this is consistent.
    for td_item, new_td_item in zip(td_list, new_td_list):
        thead_part = thead_part.replace(td_item, new_td_item)
    return thead_part
|
|
|
|
def deal_bb(result_token):
    """
    In our opinion, <b></b> always occurs in <thead></thead> text's context.
    This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
    :param result_token: full result HTML/token string.
    :return: result_token with the thead cell contents bold-wrapped.
    """
    # Locate the <thead>...</thead> section; nothing to do without one.
    thead_pattern = "<thead>(.*?)</thead>"
    if re.search(thead_pattern, result_token) is None:
        return result_token
    thead_part = re.search(thead_pattern, result_token).group()
    origin_thead_part = copy.deepcopy(thead_part)

    # Detect span cells: their opening tags are not the plain '<td>' and need
    # separate handling below.
    span_pattern = '<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan="(\d)+">|<td colspan="(\d)+">'
    span_iter = re.finditer(span_pattern, thead_part)
    span_list = [s.group() for s in span_iter]
    has_span_in_head = True if len(span_list) > 0 else False

    if not has_span_in_head:
        # Only plain <td></td> cells: wrap each cell's content in <b></b>,
        # then collapse tags doubled by text that was already bold.
        thead_part = (
            thead_part.replace("<td>", "<td><b>")
            .replace("</td>", "</b></td>")
            .replace("<b><b>", "<b>")
            .replace("</b></b>", "</b>")
        )
    else:
        # Span cells present: open a bold tag right after each span opening
        # tag first...
        replaced_span_list = []
        for sp in span_list:
            replaced_span_list.append(sp.replace(">", "><b>"))
        for sp, rsp in zip(span_list, replaced_span_list):
            thead_part = thead_part.replace(sp, rsp)

        # ...then close the bold span before every closing </td>.
        thead_part = thead_part.replace("</td>", "</b></td>")

        # Collapse runs of repeated <b> tags...
        mb_pattern = "(<b>)+"
        single_b_string = "<b>"
        thead_part = re.sub(mb_pattern, single_b_string, thead_part)

        # ...and runs of repeated </b> tags.
        mgb_pattern = "(</b>)+"
        single_gb_string = "</b>"
        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)

        # Plain cells in the same head also get the opening <b>.
        thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")

    # An empty bold cell carries no content: drop the redundant <b></b>.
    thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
    # Collapse any remaining duplicated bold tags inside individual cells.
    thead_part = deal_duplicate_bb(thead_part)
    # Repair isolated span fragments caused by model mis-prediction.
    thead_part = deal_isolate_span(thead_part)
    # Substitute the fixed head back into the full result string.
    result_token = result_token.replace(origin_thead_part, thead_part)
    return result_token
|
|
|
|
class Matcher:
    def __init__(self, end2end_file, structure_master_file):
        """
        This class process the end2end results and structure recognition results.
        :param end2end_file: end2end results predict by end2end inference.
        :param structure_master_file: structure recognition results predict by structure master inference.
        """
        self.end2end_file = end2end_file
        self.structure_master_file = structure_master_file
        self.end2end_results = pickle_load(end2end_file, prefix="end2end")
        self.structure_master_results = pickle_load(
            structure_master_file, prefix="structure"
        )

    def match(self):
        """
        Match process:
        pre-process : convert end2end and structure master results to xyxy, xywh ndnarray format.
        1. Use pseBbox is inside masterBbox judge rule
        2. Use iou between pseBbox and masterBbox rule
        3. Use min distance of center point rule
        :return: dict mapping file_name -> formatted match result dict.
        """
        match_results = dict()
        for idx, (file_name, end2end_result) in enumerate(self.end2end_results.items()):
            match_list = []
            if file_name not in self.structure_master_results:
                continue
            structure_master_result = self.structure_master_results[file_name]
            (
                end2end_xyxy_bboxes,
                end2end_xywh_bboxes,
                structure_master_xywh_bboxes,
                structure_master_xyxy_bboxes,
            ) = get_bboxes_list(end2end_result, structure_master_result)

            # Rule 1: match by center-point containment.
            center_rule_match_list = center_rule_match(
                end2end_xywh_bboxes, structure_master_xyxy_bboxes
            )
            match_list.extend(center_rule_match_list)

            # Rule 2: for end2end boxes still unmatched, match by maximum IoU.
            center_no_match_end2end_indexs = find_no_match(
                match_list, len(end2end_xywh_bboxes), type="end2end"
            )
            if len(center_no_match_end2end_indexs) > 0:
                center_no_match_end2end_xyxy = end2end_xyxy_bboxes[
                    center_no_match_end2end_indexs
                ]
                iou_rule_match_list = iou_rule_match(
                    center_no_match_end2end_xyxy,
                    center_no_match_end2end_indexs,
                    structure_master_xyxy_bboxes,
                )
                match_list.extend(iou_rule_match_list)

            # Rule 3: if both sides still have unmatched boxes, pair each
            # leftover master box with its nearest leftover end2end box.
            centerIou_no_match_end2end_indexs = find_no_match(
                match_list, len(end2end_xywh_bboxes), type="end2end"
            )
            centerIou_no_match_master_indexs = find_no_match(
                match_list, len(structure_master_xywh_bboxes), type="master"
            )
            if (
                len(centerIou_no_match_master_indexs) > 0
                and len(centerIou_no_match_end2end_indexs) > 0
            ):
                centerIou_no_match_end2end_xywh = end2end_xywh_bboxes[
                    centerIou_no_match_end2end_indexs
                ]
                centerIou_no_match_master_xywh = structure_master_xywh_bboxes[
                    centerIou_no_match_master_indexs
                ]
                distance_match_list = distance_rule_match(
                    centerIou_no_match_end2end_indexs,
                    centerIou_no_match_end2end_xywh,
                    centerIou_no_match_master_indexs,
                    centerIou_no_match_master_xywh,
                )
                match_list.extend(distance_match_list)

            # Any end2end boxes that remain unmatched get paired with virtual
            # master cells appended after the real ones (see extra_match).
            no_match_end2end_indexes = find_no_match(
                match_list, len(end2end_xywh_bboxes), type="end2end"
            )
            if len(no_match_end2end_indexes) > 0:
                no_match_end2end_xywh = end2end_xywh_bboxes[no_match_end2end_indexes]
                # Sort the leftovers into reading-order rows first; the rows
                # also drive the virtual <tr> insertion in _format().
                (
                    end2end_sorted_indexes_list,
                    end2end_sorted_bboxes_list,
                    sorted_groups,
                    sorted_bboxes_groups,
                ) = sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes)
                extra_match_list = extra_match(
                    end2end_sorted_indexes_list, len(structure_master_xywh_bboxes)
                )
                match_list_add_extra_match = copy.deepcopy(match_list)
                match_list_add_extra_match.extend(extra_match_list)
            else:
                # Nothing left over: no virtual rows needed.
                match_list_add_extra_match = copy.deepcopy(match_list)
                sorted_groups = []
                sorted_bboxes_groups = []

            match_result_dict = {
                "match_list": match_list,
                "match_list_add_extra_match": match_list_add_extra_match,
                "sorted_groups": sorted_groups,
                "sorted_bboxes_groups": sorted_bboxes_groups,
            }

            # Extend the master token list with virtual cells for the extras.
            match_result_dict = self._format(match_result_dict, file_name)

            match_results[file_name] = match_result_dict

        return match_results

    def _format(self, match_result, file_name):
        """
        Extend the master token(insert virtual master token), and format matching result.
        :param match_result: dict produced in match().
        :param file_name: key into the end2end / master result dicts.
        :return: match_result with 'matched_master_token_list' added.
        """
        end2end_info = self.end2end_results[file_name]
        master_info = self.structure_master_results[file_name]
        master_token = master_info["text"]
        sorted_groups = match_result["sorted_groups"]

        # One virtual <tr> of empty <td></td> cells per leftover row; these
        # back the virtual master indexes created by extra_match().
        virtual_master_token_list = []
        for line_group in sorted_groups:
            tmp_list = ["<tr>"]
            item_nums = len(line_group)
            for _ in range(item_nums):
                tmp_list.append("<td></td>")
            tmp_list.append("</tr>")
            virtual_master_token_list.extend(tmp_list)

        # Insert the virtual tokens before the closing </tbody>.
        master_token_list = master_token.split(",")
        if master_token_list[-1] == "</tbody>":
            # Bug fix: the old code did `master_token_list[:-1].extend(...)`,
            # which extends a throwaway slice copy and silently drops the
            # virtual tokens. Rebuild the list with them inserted before the
            # closing tag instead.
            master_token_list = (
                master_token_list[:-1] + virtual_master_token_list + ["</tbody>"]
            )
        elif master_token_list[-1] == "<td></td>":
            # Truncated prediction: close the dangling row before appending.
            master_token_list.append("</tr>")
            master_token_list.extend(virtual_master_token_list)
            master_token_list.append("</tbody>")
        else:
            master_token_list.extend(virtual_master_token_list)
            master_token_list.append("</tbody>")

        match_result.setdefault("matched_master_token_list", master_token_list)
        return match_result

    def get_merge_result(self, match_results):
        """
        Merge the OCR result into structure token to get final results.
        :param match_results: output of match().
        :return: dict mapping file_name -> merged HTML string.
        """
        merged_results = dict()

        # Separator placed between multiple end2end texts merged into one cell.
        break_token = " "

        for idx, (file_name, match_info) in enumerate(match_results.items()):
            end2end_info = self.end2end_results[file_name]
            master_token_list = match_info["matched_master_token_list"]
            match_list = match_info["match_list_add_extra_match"]

            match_dict = get_match_dict(match_list)
            match_text_dict = get_match_text_dict(match_dict, end2end_info, break_token)
            merged_result = insert_text_to_token(master_token_list, match_text_dict)
            merged_result = deal_bb(merged_result)

            merged_results[file_name] = merged_result

        return merged_results
|
|
|
|
class TableMasterMatcher(Matcher):
    def __init__(self):
        # Intentionally does NOT call Matcher.__init__: the end2end and
        # structure results are supplied per call in __call__ instead of
        # being loaded from pickle files.
        pass

    def __call__(self, structure_res, dt_boxes, rec_res, img_name=1):
        """
        Run the full match-and-merge pipeline for a single image.
        :param structure_res: (structure_tokens, structure_bboxes) from the
            table structure model.
        :param dt_boxes: detected text boxes from the OCR detector.
        :param rec_res: recognition results; each res[0] is the text string.
        :param img_name: key used internally to index the single image.
        :return: predicted HTML string wrapped in <html><body><table>.
        """
        # Build the end2end results dict the base class expects.
        end2end_results = {img_name: []}
        for dt_box, res in zip(dt_boxes, rec_res):
            d = dict(
                bbox=np.array(dt_box),
                text=res[0],
            )
            end2end_results[img_name].append(d)

        self.end2end_results = end2end_results

        # Build the structure master results dict the base class expects.
        structure_master_result_dict = {img_name: {}}
        pred_structures, pred_bboxes = structure_res
        # NOTE(review): [3:-3] presumably strips leading/trailing special
        # tokens (e.g. SOS/EOS markers) from the structure prediction —
        # confirm against the structure model's output format.
        pred_structures = ",".join(pred_structures[3:-3])
        structure_master_result_dict[img_name]["text"] = pred_structures
        structure_master_result_dict[img_name]["bbox"] = pred_bboxes
        self.structure_master_results = structure_master_result_dict

        # Delegate to the base-class matching + merging pipeline.
        match_results = self.match()
        merged_results = self.get_merge_result(match_results)
        pred_html = merged_results[img_name]
        pred_html = "<html><body><table>" + pred_html + "</table></body></html>"
        return pred_html
|
|