| import datasets
|
| import evaluate
|
| from transformers.trainer_utils import EvalPrediction
|
|
|
# Resolve each metric's ``compute`` callable once at import time so the rest
# of the module can call e.g. ``accuracy(predictions=..., references=...)``
# without re-loading the metric on every evaluation step.
def _compute_fn(metric_name: str):
    """Load an ``evaluate`` metric by name and return its ``compute`` callable."""
    return evaluate.load(metric_name).compute


accuracy = _compute_fn("accuracy")
precision = _compute_fn("precision")
recall = _compute_fn("recall")
f1 = _compute_fn("f1")
squad_v2 = _compute_fn("squad_v2")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_classification_metric(p: EvalPrediction, *, average: str = "binary"):
    """
    Compute classification metrics for a given prediction.

    The predicted class for each example is the argmax over axis 1 of
    ``p.predictions`` (i.e. the highest-scoring logit/probability column).

    Args:
        p (EvalPrediction): The prediction object. ``p.predictions`` is
            assumed to be an array of shape ``(n_examples, n_labels)`` --
            TODO confirm against the caller; ``p.label_ids`` holds the
            integer reference labels.
        average (str): Averaging strategy forwarded to the precision,
            recall, and f1 metrics. Defaults to ``"binary"`` (the underlying
            metrics' own default), so existing callers see unchanged
            behavior; pass e.g. ``"macro"`` or ``"weighted"`` for
            multi-class tasks.

    Returns:
        dict: A single dict merging the results of the four metric
        computations (keys: ``"accuracy"``, ``"precision"``, ``"recall"``,
        ``"f1"``).
    """
    predictions = p.predictions.argmax(axis=1)
    references = p.label_ids

    # Accuracy takes no averaging argument; its result dict seeds the merge.
    metric = accuracy(predictions=predictions, references=references)

    metric.update(
        precision(predictions=predictions, references=references, average=average)
    )
    metric.update(
        recall(predictions=predictions, references=references, average=average)
    )
    metric.update(
        f1(predictions=predictions, references=references, average=average)
    )

    return metric
|
|
|
|
|
def compute_squad_v2(p: EvalPrediction):
    """
    Compute SQuAD v2 metrics for a given prediction.

    Args:
        p (EvalPrediction): The prediction object. ``p.predictions`` and
            ``p.label_ids`` are forwarded verbatim to the metric, which
            expects the squad_v2 format (prediction dicts with ``id``,
            ``prediction_text``, and ``no_answer_probability``; reference
            dicts with ``id`` and ``answers``) -- assumed; confirm against
            the caller.

    Returns:
        dict: The SQuAD v2 results as returned by the metric's ``compute``
        (e.g. ``"exact"``, ``"f1"``, ``"HasAns_f1"``, ``"NoAns_f1"``,
        ``"total"``).
    """
    # Thin pass-through wrapper; all scoring is done by the loaded metric.
    return squad_v2(predictions=p.predictions, references=p.label_ids)
|
|
|
|
|