aauss commited on
Commit
eca4af7
·
1 Parent(s): 71222fd

Improve type hints, add early input checks and format code.

Browse files
Files changed (2) hide show
  1. tests/test_metrics.py +7 -2
  2. timebench_eval.py +66 -28
tests/test_metrics.py CHANGED
@@ -77,13 +77,18 @@ from conftest import (
77
  ],
78
  )
79
  def test_eval(prediction, reference, task, expected_metrics):
80
- metrics = TimebenchEval()._compute([prediction], [reference], task)
 
 
81
  assert metrics == expected_metrics
82
 
83
 
84
  def test_eval_many():
85
  metrics = TimebenchEval()._compute(
86
- [PREDICTION_3, PREDICTION_4], ["unanswerable", "Cardiff City"], "MenatQA"
 
 
 
87
  )
88
  assert metrics == {
89
  "exact_match": [1, 1],
 
77
  ],
78
  )
79
  def test_eval(prediction, reference, task, expected_metrics):
80
+ metrics = TimebenchEval()._compute(
81
+ [prediction], [reference], task, return_average=False
82
+ )
83
  assert metrics == expected_metrics
84
 
85
 
86
  def test_eval_many():
87
  metrics = TimebenchEval()._compute(
88
+ [PREDICTION_3, PREDICTION_4],
89
+ ["unanswerable", "Cardiff City"],
90
+ "MenatQA",
91
+ return_average=False,
92
  )
93
  assert metrics == {
94
  "exact_match": [1, 1],
timebench_eval.py CHANGED
@@ -15,12 +15,35 @@
15
 
16
  import re
17
  from datetime import datetime
 
18
 
19
  import datasets
20
  import evaluate
21
  from dateutil import parser
22
  from dateutil.parser import ParserError
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  _CITATION = """\
25
  @software{abbood2026timebench_eval,
26
  title={TimeBench Eval},
@@ -44,16 +67,18 @@ Args:
44
  should contain the marker "Thus, the correct answer is:" followed by the answer.
45
  references: list of reference answer strings.
46
  task: the task type, one of "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", or "TimeDial".
 
 
47
  Returns:
48
- exact_match: list of exact match scores (0 or 1) for each prediction.
49
- f1: list of F1 scores for each prediction (for applicable tasks).
50
  Examples:
51
  >>> timebench_eval = evaluate.load("aauss/timebench_eval")
52
  >>> predictions = ["Let me think... Thus, the correct answer is: Aug, 1987."]
53
  >>> references = ["Aug, 1987"]
54
  >>> results = timebench_eval.compute(predictions=predictions, references=references, task="Date Arithmetic")
55
  >>> print(results)
56
- {'exact_match': [1]}
57
  """
58
 
59
 
@@ -65,7 +90,7 @@ class TimebenchEval(evaluate.Metric):
65
  super().__init__(*args, **kwargs)
66
  self.squad_metric = evaluate.load("squad")
67
 
68
- def _info(self):
69
  return evaluate.MetricInfo(
70
  module_type="metric",
71
  description=_DESCRIPTION,
@@ -78,13 +103,19 @@ class TimebenchEval(evaluate.Metric):
78
  }
79
  ),
80
  homepage="https://huggingface.co/spaces/aauss/timebench_eval",
81
- codebase_urls=["https://huggingface.co/spaces/aauss/timebench_eval/tree/main"],
 
 
82
  reference_urls=["https://huggingface.co/datasets/ulab-ai/Time-Bench"],
83
  )
84
 
85
  def _compute(
86
- self, predictions: list[str], references: list[str], task: str
87
- ) -> dict[str, list[float]]:
 
 
 
 
88
  """
89
  Compute evaluation metrics for the given predictions and references.
90
 
@@ -92,25 +123,40 @@ class TimebenchEval(evaluate.Metric):
92
  predictions: List of prediction strings to evaluate.
93
  references: List of reference strings to compare against.
94
  task: Task type, one of: "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", "TimeDial".
 
95
 
96
  Returns:
97
- Dictionary containing metric scores (exact_match and/or f1) as lists of floats.
 
 
 
 
 
98
  """
99
- if task in [
100
- "TempReason",
101
- "TimeQA",
102
- "MenatQA",
103
- ]:
104
- return self._call_squad(predictions, references)
105
- elif task == "Date Arithmetic":
106
- return self._compare_dates(predictions, references)
107
- elif task == "TimeDial":
108
- return self._compute_timedial(predictions, references)
 
 
 
 
 
109
  else:
110
  raise ValueError(
111
- f"Unknown task: {task}. Expected one of: TempReason, TimeQA, MenatQA, Date Arithmetic, TimeDial"
112
  )
113
 
 
 
 
 
114
  @staticmethod
115
  def _extract_answer(response: str) -> str | None:
116
  """Extract the answer from the response"""
@@ -140,15 +186,7 @@ class TimebenchEval(evaluate.Metric):
140
  if not text:
141
  return set()
142
 
143
- # Pattern matches option letters that appear:
144
- # 1. At word boundary followed by period, comma, space, &, or end: \b[A-D](?=[.\s,&]|$)
145
- # 2. This avoids matching letters inside words like "CAD" or "BAD"
146
-
147
- # Find all A, B, C, D that look like option selections
148
- # They should be at a word boundary and followed by typical delimiters
149
- pattern = r"\b([A-D])(?:\.|,|\s|&|$)"
150
-
151
- matches = re.findall(pattern, text)
152
  return set(matches)
153
 
154
  def _call_squad(
 
15
 
16
  import re
17
  from datetime import datetime
18
+ from typing import Literal, TypedDict
19
 
20
  import datasets
21
  import evaluate
22
  from dateutil import parser
23
  from dateutil.parser import ParserError
24
 
25
+ TASK_TEMPREASON = "TempReason"
26
+ TASK_TIMEQA = "TimeQA"
27
+ TASK_MENATQA = "MenatQA"
28
+ TASK_DATE_ARITHMETIC = "Date Arithmetic"
29
+ TASK_TIMEDIAL = "TimeDial"
30
+ VALID_TASKS = frozenset(
31
+ {TASK_TEMPREASON, TASK_TIMEQA, TASK_MENATQA, TASK_DATE_ARITHMETIC, TASK_TIMEDIAL}
32
+ )
33
+
34
+ SQUAD_TASKS = frozenset({TASK_TEMPREASON, TASK_TIMEQA, TASK_MENATQA})
35
+
36
+ TaskType = Literal["TempReason", "TimeQA", "MenatQA", "Date Arithmetic", "TimeDial"]
37
+
38
+ SELECTED_OPTIONS_PATTERN = r"\b([A-D])(?:\.|,|\s|&|$)"
39
+ SELECTED_OPTIONS_REGEX = re.compile(SELECTED_OPTIONS_PATTERN)
40
+
41
+
42
+ class TimebenchResult(TypedDict, total=False):
43
+ exact_match: float | list[float]
44
+ f1: float | list[float]
45
+
46
+
47
  _CITATION = """\
48
  @software{abbood2026timebench_eval,
49
  title={TimeBench Eval},
 
67
  should contain the marker "Thus, the correct answer is:" followed by the answer.
68
  references: list of reference answer strings.
69
  task: the task type, one of "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", or "TimeDial".
70
+ return_average: if True (default), returns average scores as floats.
71
+ If False, returns a list of scores for each sample.
72
  Returns:
73
+ exact_match: average or list of exact match scores for each prediction.
74
+ f1: average or list of F1 scores for each prediction (for applicable tasks).
75
  Examples:
76
  >>> timebench_eval = evaluate.load("aauss/timebench_eval")
77
  >>> predictions = ["Let me think... Thus, the correct answer is: Aug, 1987."]
78
  >>> references = ["Aug, 1987"]
79
  >>> results = timebench_eval.compute(predictions=predictions, references=references, task="Date Arithmetic")
80
  >>> print(results)
81
+ {'exact_match': 1.0}
82
  """
83
 
84
 
 
90
  super().__init__(*args, **kwargs)
91
  self.squad_metric = evaluate.load("squad")
92
 
93
+ def _info(self) -> evaluate.MetricInfo:
94
  return evaluate.MetricInfo(
95
  module_type="metric",
96
  description=_DESCRIPTION,
 
103
  }
104
  ),
105
  homepage="https://huggingface.co/spaces/aauss/timebench_eval",
106
+ codebase_urls=[
107
+ "https://huggingface.co/spaces/aauss/timebench_eval/tree/main"
108
+ ],
109
  reference_urls=["https://huggingface.co/datasets/ulab-ai/Time-Bench"],
110
  )
111
 
112
  def _compute(
113
+ self,
114
+ predictions: list[str],
115
+ references: list[str],
116
+ task: TaskType,
117
+ return_average: bool = True,
118
+ ) -> TimebenchResult:
119
  """
120
  Compute evaluation metrics for the given predictions and references.
121
 
 
123
  predictions: List of prediction strings to evaluate.
124
  references: List of reference strings to compare against.
125
  task: Task type, one of: "TempReason", "TimeQA", "MenatQA", "Date Arithmetic", "TimeDial".
126
+ return_average: If True, returns average scores; if False, returns per-sample scores.
127
 
128
  Returns:
129
+ Dictionary containing metric scores (exact_match and/or f1) as floats or lists.
130
+
131
+ Raises:
132
+ ValueError: If predictions is empty.
133
+ ValueError: If predictions and references have different lengths.
134
+ ValueError: If task is not a valid task type.
135
  """
136
+ # Validate inputs
137
+ if not predictions:
138
+ raise ValueError("predictions cannot be empty")
139
+ if len(predictions) != len(references):
140
+ raise ValueError(
141
+ f"predictions and references must have same length, "
142
+ f"got {len(predictions)} and {len(references)}"
143
+ )
144
+
145
+ if task in SQUAD_TASKS:
146
+ results = self._call_squad(predictions, references)
147
+ elif task == TASK_DATE_ARITHMETIC:
148
+ results = self._compare_dates(predictions, references)
149
+ elif task == TASK_TIMEDIAL:
150
+ results = self._compute_timedial(predictions, references)
151
  else:
152
  raise ValueError(
153
+ f"Unknown task: {task}. Expected one of: {', '.join(VALID_TASKS)}"
154
  )
155
 
156
+ if return_average:
157
+ return {key: sum(values) / len(values) for key, values in results.items()}
158
+ return results
159
+
160
  @staticmethod
161
  def _extract_answer(response: str) -> str | None:
162
  """Extract the answer from the response"""
 
186
  if not text:
187
  return set()
188
 
189
+ matches = SELECTED_OPTIONS_REGEX.findall(text)
 
 
 
 
 
 
 
 
190
  return set(matches)
191
 
192
  def _call_squad(