from collections import namedtuple
from copy import deepcopy
from typing import Sequence, Optional

import datasets
import evaluate


_CITATION = """\
@misc{nereval,
    title={{NER-Evaluation}: Named Entity Evaluation as in SemEval 2013 task 9.1},
    url={https://github.com/davidsbatista/NER-Evaluation},
    note={Software available from https://github.com/davidsbatista/NER-Evaluation},
    author={Batista David},
    year={2018},
}
"""


_DESCRIPTION = """\
ner-eval is a Python framework for sequence labeling evaluation. It was used in SemEval 2013 task 9.1.
It supports strict, exact, partial, and entity-type matching, and reports missed and spurious entities.
"""


_KWARGS_DESCRIPTION = """
Computes precision, recall, and F1 scores for the predictions against the references
under the strict, entity-type, partial, and exact matching schemes.
Args:
    predictions: List of lists of predicted labels (estimated targets as returned by a tagger)
    references: List of lists of reference labels (ground truth (correct) target values)
    tags: List of entity tags to evaluate. If None, the tags are inferred from the
        predictions and references. default: None
    modes: List of evaluation modes out of "strict", "ent_type", "partial", and "exact".
        If None, all four modes are reported. default: None
Returns:
    'scores' dict. Summary of the scores for overall and each tag.
    {
        "overall": {
            "strict_precision": 0.0,
            "strict_recall": 0.0,
            "strict_f1": 0,
            "ent_type_precision": 0.0,
            "ent_type_recall": 0.0,
            "ent_type_f1": 0,
            "partial_precision": 0.0,
            "partial_recall": 0.0,
            "partial_f1": 0,
            "exact_precision": 0.0,
            "exact_recall": 0.0,
            "exact_f1": 0,
        },
        "ORG": {
            "strict_precision": 0.0,
            "strict_recall": 0.0,
            "strict_f1": 0,
            "ent_type_precision": 0.0,
            "ent_type_recall": 0.0,
            "ent_type_f1": 0,
            "partial_precision": 0.0,
            "partial_recall": 0.0,
            "partial_f1": 0,
            "exact_precision": 0.0,
            "exact_recall": 0.0,
            "exact_f1": 0,
        },
        "PER": {
            "strict_precision": 0.0,
            "strict_recall": 0.0,
            "strict_f1": 0,
            "ent_type_precision": 0.0,
            "ent_type_recall": 0.0,
            "ent_type_f1": 0,
            "partial_precision": 0.0,
            "partial_recall": 0.0,
            "partial_f1": 0,
            "exact_precision": 0.0,
            "exact_recall": 0.0,
            "exact_f1": 0,
        },
        "LOC": {
            "strict_precision": 0.0,
            "strict_recall": 0.0,
            "strict_f1": 0,
            "ent_type_precision": 0.0,
            "ent_type_recall": 0.0,
            "ent_type_f1": 0,
            "partial_precision": 0.0,
            "partial_recall": 0.0,
            "partial_f1": 0,
            "exact_precision": 0.0,
            "exact_recall": 0.0,
            "exact_f1": 0,
        },
    }
Examples:
    >>> ner_eval = evaluate.load("fschlatt/ner_eval")
    >>> results = ner_eval.compute(
    ...     references=[["B-LOC", "I-LOC", "I-LOC", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "I-PER", "O"]],
    ...     predictions=[["B-LOC", "I-LOC", "O", "O", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "O"]]
    ... )
    >>> print(results)
    {
        "overall": {
            "strict_precision": 0.0,
            "strict_recall": 0.0,
            "strict_f1": 0,
            "ent_type_precision": 2 / 3,
            "ent_type_recall": 2 / 3,
            "ent_type_f1": 2 / 3,
            "partial_precision": 1 / 3,
            "partial_recall": 1 / 3,
            "partial_f1": 1 / 3,
            "exact_precision": 0.0,
            "exact_recall": 0.0,
            "exact_f1": 0,
        },
        "ORG": {
            "strict_precision": 0.0,
            "strict_recall": 0.0,
            "strict_f1": 0,
            "ent_type_precision": 0.0,
            "ent_type_recall": 0.0,
            "ent_type_f1": 0,
            "partial_precision": 0.0,
            "partial_recall": 0.0,
            "partial_f1": 0,
            "exact_precision": 0.0,
            "exact_recall": 0.0,
            "exact_f1": 0,
        },
        "PER": {
            "strict_precision": 0.0,
            "strict_recall": 0.0,
            "strict_f1": 0,
            "ent_type_precision": 0.5,
            "ent_type_recall": 1.0,
            "ent_type_f1": 2 / 3,
            "partial_precision": 0.25,
            "partial_recall": 0.5,
            "partial_f1": 1 / 3,
            "exact_precision": 0.0,
            "exact_recall": 0.0,
            "exact_f1": 0,
        },
        "LOC": {
            "strict_precision": 0.0,
            "strict_recall": 0.0,
            "strict_f1": 0,
            "ent_type_precision": 0.5,
            "ent_type_recall": 1.0,
            "ent_type_f1": 2 / 3,
            "partial_precision": 0.25,
            "partial_recall": 0.5,
            "partial_f1": 1 / 3,
            "exact_precision": 0.0,
            "exact_recall": 0.0,
            "exact_f1": 0,
        }
    }
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class NEREval(evaluate.Metric):
    """Named entity evaluation as in SemEval 2013 task 9.1, with strict, entity-type, partial, and exact matching."""

    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/davidsbatista/NER-Evaluation",
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(
                        datasets.Value("string", id="label"), id="sequence"
                    ),
                    "references": datasets.Sequence(
                        datasets.Value("string", id="label"), id="sequence"
                    ),
                }
            ),
            codebase_urls=["https://github.com/davidsbatista/NER-Evaluation"],
            reference_urls=[
                "https://github.com/davidsbatista/NER-Evaluation",
                "https://www.davidsbatista.net/blog/2018/05/09/Named_Entity_Evaluation/",
            ],
        )

    def _download_and_prepare(self, dl_manager):
        """Optional: download external resources useful to compute the scores."""
        pass

    def _compute(
        self,
        predictions: Sequence[Sequence[str]],
        references: Sequence[Sequence[str]],
        tags: Optional[Sequence[str]] = None,
        modes: Optional[Sequence[str]] = None,
    ):
        if tags is None:
            tags = list(parse_tags(predictions).union(parse_tags(references)))

        # The Evaluator expects the ground-truth sequences first and the predictions second.
        evaluator = Evaluator(references, predictions, tags)
        results, agg_results = evaluator.evaluate()

        out = {"overall": parse_results(results, modes)}
        for tag, tag_result in agg_results.items():
            out = {**out, tag: parse_results(tag_result, modes)}

        return out


def parse_results(results, modes: Optional[Sequence[str]] = None):
    if modes is None:
        modes = ["strict", "ent_type", "partial", "exact"]

    out = {}
    for mode in modes:
        out[f"{mode}_precision"] = results[mode]["precision"]
        out[f"{mode}_recall"] = results[mode]["recall"]
        out[f"{mode}_f1"] = results[mode]["f1"]
    return out

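# Illustrative example (not part of the original module): parse_results flattens the
# nested per-mode dicts, e.g.
#   parse_results({"strict": {"precision": 1.0, "recall": 0.5, "f1": 2 / 3}}, modes=["strict"])
# returns {"strict_precision": 1.0, "strict_recall": 0.5, "strict_f1": 2 / 3}.
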
def parse_tags(tokens: Sequence[Sequence[str]]):
    tags = set()
    for seq in tokens:
        for t in seq:
            tags.add(t.split("-")[-1])
    tags.discard("O")
    return tags

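# Illustrative example (not part of the original module): for BIO-tagged input,
#   parse_tags([["B-PER", "I-PER", "O", "B-LOC"]])
# returns {"PER", "LOC"}; the outside tag "O" is discarded.
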
Entity = namedtuple("Entity", "e_type start_offset end_offset")

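# Note (added for clarity): offsets are inclusive token indices, so
# Entity(e_type="PER", start_offset=7, end_offset=8) spans tokens 7 and 8,
# matching the offsets produced by collect_named_entities below.
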
class Evaluator:
    def __init__(self, true, pred, tags):
        """Evaluates predicted entity sequences (`pred`) against ground-truth sequences (`true`) for the given tags."""

        if len(true) != len(pred):
            raise ValueError(
                "Number of predicted documents does not equal the number of true documents"
            )

        self.true = true
        self.pred = pred
        self.tags = tags

        # Setup dict into which metric counts are accumulated.
        self.metrics_results = {
            "correct": 0,
            "incorrect": 0,
            "partial": 0,
            "missed": 0,
            "spurious": 0,
            "possible": 0,
            "actual": 0,
            "precision": 0,
            "recall": 0,
            "f1": 0,
        }

        # One copy of the counts per evaluation schema.
        self.results = {
            "strict": deepcopy(self.metrics_results),
            "ent_type": deepcopy(self.metrics_results),
            "partial": deepcopy(self.metrics_results),
            "exact": deepcopy(self.metrics_results),
        }

        # Accumulator for results per entity type.
        self.evaluation_agg_entities_type = {e: deepcopy(self.results) for e in tags}

    def evaluate(self):
        for true_ents, pred_ents in zip(self.true, self.pred):
            # Check that the true and predicted sequences have the same length.
            if len(true_ents) != len(pred_ents):
                raise ValueError("Prediction length does not match true example length")

            # Compute results for one document.
            tmp_results, tmp_agg_results = compute_metrics(
                collect_named_entities(true_ents),
                collect_named_entities(pred_ents),
                self.tags,
            )

            # Accumulate the overall counts across documents.
            for eval_schema in self.results:
                for metric in self.results[eval_schema]:
                    self.results[eval_schema][metric] += tmp_results[eval_schema][
                        metric
                    ]

            # Calculate global precision, recall and F1.
            self.results = compute_precision_recall_f1_wrapper(self.results)

            # Aggregate results by entity type.
            for e_type in self.tags:
                for eval_schema in tmp_agg_results[e_type]:
                    for metric in tmp_agg_results[e_type][eval_schema]:
                        self.evaluation_agg_entities_type[e_type][eval_schema][
                            metric
                        ] += tmp_agg_results[e_type][eval_schema][metric]

                # Calculate precision, recall and F1 at the individual entity level.
                self.evaluation_agg_entities_type[
                    e_type
                ] = compute_precision_recall_f1_wrapper(
                    self.evaluation_agg_entities_type[e_type]
                )

        return self.results, self.evaluation_agg_entities_type


def collect_named_entities(tokens):
    """
    Creates a list of Entity named-tuples, storing the entity type and the start and end
    offsets of the entity.

    :param tokens: a list of tags
    :return: a list of Entity named-tuples
    """

    named_entities = []
    start_offset = None
    end_offset = None
    ent_type = None

    for offset, token_tag in enumerate(tokens):
        if token_tag == "O":
            if ent_type is not None and start_offset is not None:
                end_offset = offset - 1
                named_entities.append(Entity(ent_type, start_offset, end_offset))
                start_offset = None
                end_offset = None
                ent_type = None

        elif ent_type is None:
            ent_type = token_tag[2:]
            start_offset = offset

        elif ent_type != token_tag[2:] or (
            ent_type == token_tag[2:] and token_tag[:1] == "B"
        ):
            end_offset = offset - 1
            named_entities.append(Entity(ent_type, start_offset, end_offset))

            # start of a new entity
            ent_type = token_tag[2:]
            start_offset = offset
            end_offset = None

    # catches an entity that runs up until the last token
    if ent_type is not None and start_offset is not None and end_offset is None:
        named_entities.append(Entity(ent_type, start_offset, len(tokens) - 1))

    return named_entities

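# Illustrative example (not part of the original module):
#   collect_named_entities(["O", "B-PER", "I-PER", "O"])
# returns [Entity(e_type="PER", start_offset=1, end_offset=2)].
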
def compute_metrics(true_named_entities, pred_named_entities, tags):
    eval_metrics = {
        "correct": 0,
        "incorrect": 0,
        "partial": 0,
        "missed": 0,
        "spurious": 0,
        "precision": 0,
        "recall": 0,
        "f1": 0,
    }

    # overall results
    evaluation = {
        "strict": deepcopy(eval_metrics),
        "ent_type": deepcopy(eval_metrics),
        "partial": deepcopy(eval_metrics),
        "exact": deepcopy(eval_metrics),
    }

    # results aggregated by entity type
    evaluation_agg_entities_type = {e: deepcopy(evaluation) for e in tags}

    # keep track of the true entities that overlapped with some prediction
    true_which_overlapped_with_pred = []

    # only consider entities whose type is in `tags`, on both the true and the
    # predicted side
    true_named_entities = [ent for ent in true_named_entities if ent.e_type in tags]
    pred_named_entities = [ent for ent in pred_named_entities if ent.e_type in tags]

    # go through each predicted named entity
    for pred in pred_named_entities:
        found_overlap = False

        # exact match of boundaries and entity type
        if pred in true_named_entities:
            true_which_overlapped_with_pred.append(pred)
            evaluation["strict"]["correct"] += 1
            evaluation["ent_type"]["correct"] += 1
            evaluation["exact"]["correct"] += 1
            evaluation["partial"]["correct"] += 1

            # aggregated by entity type results
            evaluation_agg_entities_type[pred.e_type]["strict"]["correct"] += 1
            evaluation_agg_entities_type[pred.e_type]["ent_type"]["correct"] += 1
            evaluation_agg_entities_type[pred.e_type]["exact"]["correct"] += 1
            evaluation_agg_entities_type[pred.e_type]["partial"]["correct"] += 1

        else:
            # check for overlaps with any of the true entities
            for true in true_named_entities:
                pred_range = range(pred.start_offset, pred.end_offset)
                true_range = range(true.start_offset, true.end_offset)

                # boundaries match, but the entity type is wrong
                if (
                    true.start_offset == pred.start_offset
                    and pred.end_offset == true.end_offset
                    and true.e_type != pred.e_type
                ):
                    # overall results
                    evaluation["strict"]["incorrect"] += 1
                    evaluation["ent_type"]["incorrect"] += 1
                    evaluation["partial"]["correct"] += 1
                    evaluation["exact"]["correct"] += 1

                    # aggregated by entity type results
                    evaluation_agg_entities_type[true.e_type]["strict"][
                        "incorrect"
                    ] += 1
                    evaluation_agg_entities_type[true.e_type]["ent_type"][
                        "incorrect"
                    ] += 1
                    evaluation_agg_entities_type[true.e_type]["partial"]["correct"] += 1
                    evaluation_agg_entities_type[true.e_type]["exact"]["correct"] += 1

                    true_which_overlapped_with_pred.append(true)
                    found_overlap = True

                    break

                # partial boundary overlap, i.e. not an exact boundary match
                elif find_overlap(true_range, pred_range):
                    true_which_overlapped_with_pred.append(true)

                    # boundaries overlap and the entity type is the same
                    if pred.e_type == true.e_type:
                        # overall results
                        evaluation["strict"]["incorrect"] += 1
                        evaluation["ent_type"]["correct"] += 1
                        evaluation["partial"]["partial"] += 1
                        evaluation["exact"]["incorrect"] += 1

                        # aggregated by entity type results
                        evaluation_agg_entities_type[true.e_type]["strict"][
                            "incorrect"
                        ] += 1
                        evaluation_agg_entities_type[true.e_type]["ent_type"][
                            "correct"
                        ] += 1
                        evaluation_agg_entities_type[true.e_type]["partial"][
                            "partial"
                        ] += 1
                        evaluation_agg_entities_type[true.e_type]["exact"][
                            "incorrect"
                        ] += 1

                        found_overlap = True

                        break

                    # boundaries overlap, but the entity type is different
                    else:
                        # overall results
                        evaluation["strict"]["incorrect"] += 1
                        evaluation["ent_type"]["incorrect"] += 1
                        evaluation["partial"]["partial"] += 1
                        evaluation["exact"]["incorrect"] += 1

                        # aggregated by entity type results
                        evaluation_agg_entities_type[true.e_type]["strict"][
                            "incorrect"
                        ] += 1
                        evaluation_agg_entities_type[true.e_type]["partial"][
                            "partial"
                        ] += 1
                        evaluation_agg_entities_type[true.e_type]["ent_type"][
                            "incorrect"
                        ] += 1
                        evaluation_agg_entities_type[true.e_type]["exact"][
                            "incorrect"
                        ] += 1

                        found_overlap = True

                        break

            # spurious prediction: no overlap with any true entity
            if not found_overlap:
                # overall results
                evaluation["strict"]["spurious"] += 1
                evaluation["ent_type"]["spurious"] += 1
                evaluation["partial"]["spurious"] += 1
                evaluation["exact"]["spurious"] += 1

                # aggregated by entity type results: since it is unclear which tag a
                # spurious prediction should be assigned to, a spurious count is added
                # for every tag; the per-tag sums therefore need not equal the overall
                # counts

                for true in tags:
                    evaluation_agg_entities_type[true]["strict"]["spurious"] += 1
                    evaluation_agg_entities_type[true]["ent_type"]["spurious"] += 1
                    evaluation_agg_entities_type[true]["partial"]["spurious"] += 1
                    evaluation_agg_entities_type[true]["exact"]["spurious"] += 1

    # missed: true entities with no overlapping prediction
    for true in true_named_entities:
        if true in true_which_overlapped_with_pred:
            continue
        else:
            # overall results
            evaluation["strict"]["missed"] += 1
            evaluation["ent_type"]["missed"] += 1
            evaluation["partial"]["missed"] += 1
            evaluation["exact"]["missed"] += 1

            # aggregated by entity type results
            evaluation_agg_entities_type[true.e_type]["strict"]["missed"] += 1
            evaluation_agg_entities_type[true.e_type]["ent_type"]["missed"] += 1
            evaluation_agg_entities_type[true.e_type]["partial"]["missed"] += 1
            evaluation_agg_entities_type[true.e_type]["exact"]["missed"] += 1

    # compute 'possible' and 'actual' on the overall results, according to
    # SemEval-2013 task 9.1; precision and recall are derived from them later
    for eval_type in evaluation:
        evaluation[eval_type] = compute_actual_possible(evaluation[eval_type])

    # compute 'possible' and 'actual' on the per-entity-type results
    for entity_type, entity_level in evaluation_agg_entities_type.items():
        for eval_type in entity_level:
            evaluation_agg_entities_type[entity_type][
                eval_type
            ] = compute_actual_possible(entity_level[eval_type])

    return evaluation, evaluation_agg_entities_type

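# Illustrative example (not part of the original module): with a single overlapping
# entity of the same type,
#   compute_metrics([Entity("PER", 0, 1)], [Entity("PER", 0, 2)], ["PER"])
# counts it as "correct" under ent_type, "partial" under partial, and "incorrect"
# under strict and exact, with possible == actual == 1 for every schema.
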
def find_overlap(true_range, pred_range):
    """Find the overlap between two ranges.

    Return the overlapping values if present, else return an empty set().

    Examples:

    >>> find_overlap((1, 2), (2, 3))
    {2}
    >>> find_overlap((1, 2), (3, 4))
    set()
    """

    true_set = set(true_range)
    pred_set = set(pred_range)

    overlaps = true_set.intersection(pred_set)

    return overlaps

def compute_actual_possible(results):
    """
    Takes a result dict that has been output by compute_metrics.
    Returns the results dict with the 'actual' and 'possible' counts populated.
    """

    correct = results["correct"]
    incorrect = results["incorrect"]
    partial = results["partial"]
    missed = results["missed"]
    spurious = results["spurious"]

    # possible: number of gold-standard annotations that contribute to the score
    possible = correct + incorrect + partial + missed

    # actual: number of annotations produced by the system
    actual = correct + incorrect + partial + spurious

    results["actual"] = actual
    results["possible"] = possible

    return results

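# Illustrative example (not part of the original module): for counts
#   correct=3, incorrect=1, partial=1, missed=2, spurious=2
# this yields possible = 3 + 1 + 1 + 2 = 7 and actual = 3 + 1 + 1 + 2 = 7.
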
def compute_precision_recall_f1(results, partial_or_type=False):
    """
    Takes a result dict that has been output by compute_metrics.
    Returns the results dict with precision, recall, and F1 populated.

    When the results dict is from the partial or ent_type schema, pass
    partial_or_type=True so that partial matches count half towards
    precision and recall.
    """

    actual = results["actual"]
    possible = results["possible"]
    partial = results["partial"]
    correct = results["correct"]

    if partial_or_type:
        precision = (correct + 0.5 * partial) / actual if actual > 0 else 0
        recall = (correct + 0.5 * partial) / possible if possible > 0 else 0

    else:
        precision = correct / actual if actual > 0 else 0
        recall = correct / possible if possible > 0 else 0

    results["precision"] = precision
    results["recall"] = recall
    results["f1"] = (
        precision * recall * 2 / (precision + recall) if precision + recall > 0 else 0
    )

    return results

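# Illustrative example (not part of the original module): with correct=2, partial=2,
# actual=5 and possible=4, partial_or_type=True gives
#   precision = (2 + 0.5 * 2) / 5 = 0.6, recall = 3 / 4 = 0.75,
#   f1 = 2 * 0.6 * 0.75 / (0.6 + 0.75) = 2 / 3.
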
def compute_precision_recall_f1_wrapper(results):
    """
    Wraps the compute_precision_recall_f1 function and runs it on a dict of results.
    """

    results_a = {
        key: compute_precision_recall_f1(value, True)
        for key, value in results.items()
        if key in ["partial", "ent_type"]
    }
    results_b = {
        key: compute_precision_recall_f1(value)
        for key, value in results.items()
        if key in ["strict", "exact"]
    }

    results = {**results_a, **results_b}

    return results

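
if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): run the helpers in this module
    # directly on one document, without loading the metric through `evaluate.load`.
    references = [
        ["B-LOC", "I-LOC", "I-LOC", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "I-PER", "O"]
    ]
    predictions = [
        ["B-LOC", "I-LOC", "O", "O", "B-ORG", "I-ORG", "O", "B-PER", "I-PER", "O"]
    ]
    tags = sorted(parse_tags(references).union(parse_tags(predictions)))

    # Ground-truth sequences first, predicted sequences second.
    evaluator = Evaluator(references, predictions, tags)
    overall, per_tag = evaluator.evaluate()

    print("overall:", parse_results(overall))
    for tag, tag_results in per_tag.items():
        print(tag + ":", parse_results(tag_results))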