| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import paddle |
|
|
|
|
class VQAReTokenLayoutLMPostProcess:
    """Post-process relation-extraction predictions.

    The raw ``pred_relations`` array is decoded into per-sample lists of
    relation dicts. Depending on whether a label is supplied, the decoded
    relations are either returned together with label entries (metric
    path) or resolved to pairs of SER results (inference path).
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted for config compatibility but unused.
        super().__init__()

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode ``preds["pred_relations"]`` and dispatch on ``label``.

        Args:
            preds: Mapping holding the key ``"pred_relations"``; the value
                is either a ``paddle.Tensor`` or an already array-like
                object supporting multi-dimensional indexing.
            label: When not ``None``, the metric path is taken.
            *args: Forwarded to ``_infer`` on the inference path.
            **kwargs: Forwarded to ``_infer`` on the inference path.

        Returns:
            The result of ``_metric`` if ``label`` is given, otherwise the
            result of ``_infer``.
        """
        pred_relations = preds["pred_relations"]
        if isinstance(pred_relations, paddle.Tensor):
            pred_relations = pred_relations.numpy()
        pred_relations = self.decode_pred(pred_relations)

        if label is not None:
            return self._metric(pred_relations, label)
        return self._infer(pred_relations, *args, **kwargs)

    def _metric(self, pred_relations, label):
        """Return the decoded relations with the last two label entries.

        The result is ``(pred_relations, label[-1], label[-2])``.
        NOTE(review): the meaning of the two trailing label entries is
        defined by the caller and is not visible here.
        """
        return pred_relations, label[-1], label[-2]

    def _infer(self, pred_relations, *args, **kwargs):
        """Map decoded relations to pairs of SER results.

        Args:
            pred_relations: Output of ``decode_pred``.
            **kwargs: Must contain ``"ser_results"`` and
                ``"entity_idx_dict_batch"``, each with one entry per
                sample.

        Returns:
            One list per sample of ``(head_info, tail_info)`` tuples. Only
            the first relation for any given tail id is kept.

        Raises:
            KeyError: If either required keyword argument is missing.
        """
        ser_results = kwargs["ser_results"]
        entity_idx_dict_batch = kwargs["entity_idx_dict_batch"]

        results = []
        for pred_relation, ser_result, entity_idx_dict in zip(
            pred_relations, ser_results, entity_idx_dict_batch
        ):
            result = []
            # A set gives O(1) membership tests; a tail may only be linked once.
            used_tail_ids = set()
            for relation in pred_relation:
                tail_id = relation["tail_id"]
                if tail_id in used_tail_ids:
                    continue
                used_tail_ids.add(tail_id)
                ocr_info_head = ser_result[entity_idx_dict[relation["head_id"]]]
                ocr_info_tail = ser_result[entity_idx_dict[tail_id]]
                result.append((ocr_info_head, ocr_info_tail))
            results.append(result)
        return results

    def decode_pred(self, pred_relations):
        """Unpack the relation array into lists of relation dicts.

        For each sample, the element at ``[0, 0, 0]`` is read as the number
        of valid relations; the following that many entries are decoded.
        Each relation entry is indexed by row: 0 head id, 1 head span,
        2 head type, 3 tail id, 4 tail span, 5 tail type, 6 relation type.

        Args:
            pred_relations: Iterable of per-sample arrays supporting
                multi-dimensional indexing (e.g. NumPy arrays).

        Returns:
            One list per sample of dicts with the keys ``head_id``,
            ``head``, ``head_type``, ``tail_id``, ``tail``, ``tail_type``
            and ``type``.
        """
        pred_relations_new = []
        for pred_relation in pred_relations:
            # Slot 0 stores the count; valid relations occupy slots 1..count.
            num_relations = pred_relation[0, 0, 0]
            pred_relation_new = []
            for relation in pred_relation[1 : num_relations + 1]:
                pred_relation_new.append(
                    {
                        "head_id": relation[0, 0],
                        "head": tuple(relation[1]),
                        "head_type": relation[2, 0],
                        "tail_id": relation[3, 0],
                        "tail": tuple(relation[4]),
                        "tail_type": relation[5, 0],
                        "type": relation[6, 0],
                    }
                )
            pred_relations_new.append(pred_relation_new)
        return pred_relations_new
|
|
|
|
class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess):
    """Apply the RE post-process to each named sub-model's predictions."""

    def __init__(self, model_name=["Student"], key=None, **kwargs):
        """Store which sub-model outputs to post-process.

        Args:
            model_name: A single model name or a list of names used as
                keys into ``preds``. A non-list value is wrapped in a list.
            key: Optional key selecting a nested entry inside each
                model's output before post-processing.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        # Copy the list so instances never share (or mutate) the default
        # argument object or a list owned by the caller.
        if isinstance(model_name, list):
            self.model_name = list(model_name)
        else:
            self.model_name = [model_name]
        self.key = key

    def __call__(self, preds, *args, **kwargs):
        """Run the parent post-process for every configured model name.

        Args:
            preds: Mapping from model name to that model's output.
            *args: Forwarded to the parent ``__call__``.
            **kwargs: Forwarded to the parent ``__call__``.

        Returns:
            Dict mapping each configured model name to its post-processed
            result.
        """
        output = {}
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            output[name] = super().__call__(pred, *args, **kwargs)
        return output
|
|