Spaces:
Sleeping
Sleeping
Improve type hints, add early input checks and format code.
Browse files- tests/test_metrics.py +7 -2
- timebench_eval.py +66 -28
tests/test_metrics.py
CHANGED
|
@@ -77,13 +77,18 @@ from conftest import (
|
|
| 77 |
],
|
| 78 |
)
|
| 79 |
def test_eval(prediction, reference, task, expected_metrics):
|
| 80 |
-
metrics = TimebenchEval()._compute(
|
|
|
|
|
|
|
| 81 |
assert metrics == expected_metrics
|
| 82 |
|
| 83 |
|
| 84 |
def test_eval_many():
|
| 85 |
metrics = TimebenchEval()._compute(
|
| 86 |
-
[PREDICTION_3, PREDICTION_4],
|
|
|
|
|
|
|
|
|
|
| 87 |
)
|
| 88 |
assert metrics == {
|
| 89 |
"exact_match": [1, 1],
|
|
|
|
| 77 |
],
|
| 78 |
)
|
| 79 |
def test_eval(prediction, reference, task, expected_metrics):
|
| 80 |
+
metrics = TimebenchEval()._compute(
|
| 81 |
+
[prediction], [reference], task, return_average=False
|
| 82 |
+
)
|
| 83 |
assert metrics == expected_metrics
|
| 84 |
|
| 85 |
|
| 86 |
def test_eval_many():
|
| 87 |
metrics = TimebenchEval()._compute(
|
| 88 |
+
[PREDICTION_3, PREDICTION_4],
|
| 89 |
+
["unanswerable", "Cardiff City"],
|
| 90 |
+
"MenatQA",
|
| 91 |
+
return_average=False,
|
| 92 |
)
|
| 93 |
assert metrics == {
|
| 94 |
"exact_match": [1, 1],
|
timebench_eval.py
CHANGED
|
@@ -15,12 +15,35 @@
|
|
| 15 |
|
| 16 |
import re
|
| 17 |
from datetime import datetime
|
|
|
|
| 18 |
|
| 19 |
import datasets
|
| 20 |
import evaluate
|
| 21 |
from dateutil import parser
|
| 22 |
from dateutil.parser import ParserError
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
_CITATION = """\
|
| 25 |
@software{abbood2026timebench_eval,
|
| 26 |
title={TimeBench Eval},
|
|
@@ -44,16 +67,18 @@ Args:
|
|
| 44 |
should contain the marker "Thus, the correct answer is:" followed by the answer.
|
| 45 |
references: list of reference answer strings.
|
| 46 |
task: the task type, one of "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", or "TimeDial".
|
|
|
|
|
|
|
| 47 |
Returns:
|
| 48 |
-
exact_match: list of exact match scores
|
| 49 |
-
f1: list of F1 scores for each prediction (for applicable tasks).
|
| 50 |
Examples:
|
| 51 |
>>> timebench_eval = evaluate.load("aauss/timebench_eval")
|
| 52 |
>>> predictions = ["Let me think... Thus, the correct answer is: Aug, 1987."]
|
| 53 |
>>> references = ["Aug, 1987"]
|
| 54 |
>>> results = timebench_eval.compute(predictions=predictions, references=references, task="Date Arithmetic")
|
| 55 |
>>> print(results)
|
| 56 |
-
{'exact_match':
|
| 57 |
"""
|
| 58 |
|
| 59 |
|
|
@@ -65,7 +90,7 @@ class TimebenchEval(evaluate.Metric):
|
|
| 65 |
super().__init__(*args, **kwargs)
|
| 66 |
self.squad_metric = evaluate.load("squad")
|
| 67 |
|
| 68 |
-
def _info(self):
|
| 69 |
return evaluate.MetricInfo(
|
| 70 |
module_type="metric",
|
| 71 |
description=_DESCRIPTION,
|
|
@@ -78,13 +103,19 @@ class TimebenchEval(evaluate.Metric):
|
|
| 78 |
}
|
| 79 |
),
|
| 80 |
homepage="https://huggingface.co/spaces/aauss/timebench_eval",
|
| 81 |
-
codebase_urls=[
|
|
|
|
|
|
|
| 82 |
reference_urls=["https://huggingface.co/datasets/ulab-ai/Time-Bench"],
|
| 83 |
)
|
| 84 |
|
| 85 |
def _compute(
|
| 86 |
-
self,
|
| 87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
"""
|
| 89 |
Compute evaluation metrics for the given predictions and references.
|
| 90 |
|
|
@@ -92,25 +123,40 @@ class TimebenchEval(evaluate.Metric):
|
|
| 92 |
predictions: List of prediction strings to evaluate.
|
| 93 |
references: List of reference strings to compare against.
|
| 94 |
task: Task type, one of: "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", "TimeDial".
|
|
|
|
| 95 |
|
| 96 |
Returns:
|
| 97 |
-
Dictionary containing metric scores (exact_match and/or f1) as
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
"""
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
"
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
else:
|
| 110 |
raise ValueError(
|
| 111 |
-
f"Unknown task: {task}. Expected one of:
|
| 112 |
)
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
@staticmethod
|
| 115 |
def _extract_answer(response: str) -> str | None:
|
| 116 |
"""Extract the answer from the response"""
|
|
@@ -140,15 +186,7 @@ class TimebenchEval(evaluate.Metric):
|
|
| 140 |
if not text:
|
| 141 |
return set()
|
| 142 |
|
| 143 |
-
|
| 144 |
-
# 1. At word boundary followed by period, comma, space, &, or end: \b[A-D](?=[.\s,&]|$)
|
| 145 |
-
# 2. This avoids matching letters inside words like "CAD" or "BAD"
|
| 146 |
-
|
| 147 |
-
# Find all A, B, C, D that look like option selections
|
| 148 |
-
# They should be at a word boundary and followed by typical delimiters
|
| 149 |
-
pattern = r"\b([A-D])(?:\.|,|\s|&|$)"
|
| 150 |
-
|
| 151 |
-
matches = re.findall(pattern, text)
|
| 152 |
return set(matches)
|
| 153 |
|
| 154 |
def _call_squad(
|
|
|
|
| 15 |
|
| 16 |
import re
|
| 17 |
from datetime import datetime
|
| 18 |
+
from typing import Literal, TypedDict
|
| 19 |
|
| 20 |
import datasets
|
| 21 |
import evaluate
|
| 22 |
from dateutil import parser
|
| 23 |
from dateutil.parser import ParserError
|
| 24 |
|
| 25 |
+
TASK_TEMPREASON = "TempReason"
|
| 26 |
+
TASK_TIMEQA = "TimeQA"
|
| 27 |
+
TASK_MENATQA = "MenatQA"
|
| 28 |
+
TASK_DATE_ARITHMETIC = "Date Arithmetic"
|
| 29 |
+
TASK_TIMEDIAL = "TimeDial"
|
| 30 |
+
VALID_TASKS = frozenset(
|
| 31 |
+
{TASK_TEMPREASON, TASK_TIMEQA, TASK_MENATQA, TASK_DATE_ARITHMETIC, TASK_TIMEDIAL}
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
SQUAD_TASKS = frozenset({TASK_TEMPREASON, TASK_TIMEQA, TASK_MENATQA})
|
| 35 |
+
|
| 36 |
+
TaskType = Literal["TempReason", "TimeQA", "MenatQA", "Date Arithmetic", "TimeDial"]
|
| 37 |
+
|
| 38 |
+
SELECTED_OPTIONS_PATTERN = r"\b([A-D])(?:\.|,|\s|&|$)"
|
| 39 |
+
SELECTED_OPTIONS_REGEX = re.compile(SELECTED_OPTIONS_PATTERN)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class TimebenchResult(TypedDict, total=False):
|
| 43 |
+
exact_match: float | list[float]
|
| 44 |
+
f1: float | list[float]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
_CITATION = """\
|
| 48 |
@software{abbood2026timebench_eval,
|
| 49 |
title={TimeBench Eval},
|
|
|
|
| 67 |
should contain the marker "Thus, the correct answer is:" followed by the answer.
|
| 68 |
references: list of reference answer strings.
|
| 69 |
task: the task type, one of "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", or "TimeDial".
|
| 70 |
+
return_average: if True (default), returns average scores as floats.
|
| 71 |
+
If False, returns a list of scores for each sample.
|
| 72 |
Returns:
|
| 73 |
+
exact_match: average or list of exact match scores for each prediction.
|
| 74 |
+
f1: average or list of F1 scores for each prediction (for applicable tasks).
|
| 75 |
Examples:
|
| 76 |
>>> timebench_eval = evaluate.load("aauss/timebench_eval")
|
| 77 |
>>> predictions = ["Let me think... Thus, the correct answer is: Aug, 1987."]
|
| 78 |
>>> references = ["Aug, 1987"]
|
| 79 |
>>> results = timebench_eval.compute(predictions=predictions, references=references, task="Date Arithmetic")
|
| 80 |
>>> print(results)
|
| 81 |
+
{'exact_match': 1.0}
|
| 82 |
"""
|
| 83 |
|
| 84 |
|
|
|
|
| 90 |
super().__init__(*args, **kwargs)
|
| 91 |
self.squad_metric = evaluate.load("squad")
|
| 92 |
|
| 93 |
+
def _info(self) -> evaluate.MetricInfo:
|
| 94 |
return evaluate.MetricInfo(
|
| 95 |
module_type="metric",
|
| 96 |
description=_DESCRIPTION,
|
|
|
|
| 103 |
}
|
| 104 |
),
|
| 105 |
homepage="https://huggingface.co/spaces/aauss/timebench_eval",
|
| 106 |
+
codebase_urls=[
|
| 107 |
+
"https://huggingface.co/spaces/aauss/timebench_eval/tree/main"
|
| 108 |
+
],
|
| 109 |
reference_urls=["https://huggingface.co/datasets/ulab-ai/Time-Bench"],
|
| 110 |
)
|
| 111 |
|
| 112 |
def _compute(
|
| 113 |
+
self,
|
| 114 |
+
predictions: list[str],
|
| 115 |
+
references: list[str],
|
| 116 |
+
task: TaskType,
|
| 117 |
+
return_average: bool = True,
|
| 118 |
+
) -> TimebenchResult:
|
| 119 |
"""
|
| 120 |
Compute evaluation metrics for the given predictions and references.
|
| 121 |
|
|
|
|
| 123 |
predictions: List of prediction strings to evaluate.
|
| 124 |
references: List of reference strings to compare against.
|
| 125 |
task: Task type, one of: "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", "TimeDial".
|
| 126 |
+
return_average: If True, returns average scores; if False, returns per-sample scores.
|
| 127 |
|
| 128 |
Returns:
|
| 129 |
+
Dictionary containing metric scores (exact_match and/or f1) as floats or lists.
|
| 130 |
+
|
| 131 |
+
Raises:
|
| 132 |
+
ValueError: If predictions is empty.
|
| 133 |
+
ValueError: If predictions and references have different lengths.
|
| 134 |
+
ValueError: If task is not a valid task type.
|
| 135 |
"""
|
| 136 |
+
# Validate inputs
|
| 137 |
+
if not predictions:
|
| 138 |
+
raise ValueError("predictions cannot be empty")
|
| 139 |
+
if len(predictions) != len(references):
|
| 140 |
+
raise ValueError(
|
| 141 |
+
f"predictions and references must have same length, "
|
| 142 |
+
f"got {len(predictions)} and {len(references)}"
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
if task in SQUAD_TASKS:
|
| 146 |
+
results = self._call_squad(predictions, references)
|
| 147 |
+
elif task == TASK_DATE_ARITHMETIC:
|
| 148 |
+
results = self._compare_dates(predictions, references)
|
| 149 |
+
elif task == TASK_TIMEDIAL:
|
| 150 |
+
results = self._compute_timedial(predictions, references)
|
| 151 |
else:
|
| 152 |
raise ValueError(
|
| 153 |
+
f"Unknown task: {task}. Expected one of: {', '.join(VALID_TASKS)}"
|
| 154 |
)
|
| 155 |
|
| 156 |
+
if return_average:
|
| 157 |
+
return {key: sum(values) / len(values) for key, values in results.items()}
|
| 158 |
+
return results
|
| 159 |
+
|
| 160 |
@staticmethod
|
| 161 |
def _extract_answer(response: str) -> str | None:
|
| 162 |
"""Extract the answer from the response"""
|
|
|
|
| 186 |
if not text:
|
| 187 |
return set()
|
| 188 |
|
| 189 |
+
matches = SELECTED_OPTIONS_REGEX.findall(text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
return set(matches)
|
| 191 |
|
| 192 |
def _call_squad(
|