| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """TODO: Add a description here.""" |
|
|
| from operator import eq |
| from typing import Callable, Iterable, Union |
|
|
| import evaluate |
| import datasets |
| import numpy as np |
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| |
| _CITATION = """\ |
| @InProceedings{huggingface:module, |
| title = {A great new module}, |
| authors={huggingface, Inc.}, |
| year={2020} |
| } |
| """ |
|
|
| |
| _DESCRIPTION = """\ |
| Computes precision, recall, f1 scores for joint entity-relation extraction task. |
| """ |
|
|
|
|
| |
| _KWARGS_DESCRIPTION = """ |
| Calculates how good are predictions given some references, using certain scores |
| Args: |
| predictions: list of predictions to score. Each predictions |
| should be a string with tokens separated by spaces. |
| references: list of reference for each prediction. Each |
| reference should be a string with tokens separated by spaces. |
| eq_fn: function to compare two items. Defaults to the equality operator. |
| Returns: |
| recall: |
| precision: |
| f1: |
| Examples: |
| >>> jer = evaluate.load("jer") |
| >>> results = jer.compute(references=[["Baris | play | tennis", "Deniz | travel | London"]], predictions=[["Baris | play | tennis"]]) |
| >>> print(results) |
| {'recall': 0.5, 'precision': 1.0, 'f1': 0.6666666666666666} |
| """ |
|
|
| Triplet = Union[str, tuple, int] |
|
|
| @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) |
| class jer(evaluate.Metric): |
| """TODO: Short description of my evaluation module.""" |
|
|
| def _info(self): |
| |
| return evaluate.MetricInfo( |
| |
| module_type="metric", |
| description=_DESCRIPTION, |
| citation=_CITATION, |
| inputs_description=_KWARGS_DESCRIPTION, |
| |
| features=datasets.Features({ |
| 'predictions': datasets.features.Sequence(datasets.Value('string')), |
| 'references': datasets.features.Sequence(datasets.Value('string')), |
| }), |
| |
| homepage="http://module.homepage", |
| |
| codebase_urls=["http://github.com/path/to/codebase/of/new_module"], |
| reference_urls=["http://path.to.reference.url/new_module"] |
| ) |
|
|
| def _download_and_prepare(self, dl_manager): |
| """Optional: download external resources useful to compute the scores""" |
| pass |
|
|
| def _compute(self, predictions, references, eq_fn=eq): |
| """Returns the scores""" |
| score_dicts = [ |
| self._compute_single(prediction=prediction, reference=reference, eq_fn=eq_fn) |
| for prediction, reference in zip(predictions, references) |
| ] |
| return {('mean_' + key): np.mean([scores[key] for scores in score_dicts]) for key in score_dicts[0].keys()} |
| |
| def _compute_single( |
| self, |
| *, |
| prediction: Iterable[Triplet], |
| reference: Iterable[Triplet], |
| eq_fn: Callable[[Triplet, Triplet], bool], |
| ): |
| reference_set = set(reference) |
| if len(reference) != len(reference_set): |
| logger.warn(f"Duplicates found in the reference list {reference}") |
| prediction_set = set(prediction) |
|
|
| tp = sum(int(is_in(item, prediction, eq_fn=eq_fn)) for item in reference) |
| fp = len(prediction_set) - tp |
| fn = len(reference_set) - tp |
| |
| |
| precision = tp / (tp + fp) if tp + fp > 0 else 0 |
| recall = tp / (tp + fn) if tp + fn > 0 else 0 |
| f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 |
| |
| return { |
| 'precision': precision, |
| 'recall': recall, |
| 'f1': f1_score |
| } |
|
|
| def is_in(target, collection: Iterable, eq_fn=eq) -> bool: |
| for item in collection: |
| if eq_fn(item, target): |
| return True |
| return False |