# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metric for the TimeBench temporal reasoning benchmark."""

import re
from datetime import datetime
from typing import Literal, TypedDict

import datasets
import evaluate
from dateutil import parser
from dateutil.parser import ParserError

TASK_TEMPREASON = "TempReason"
TASK_TIMEQA = "TimeQA"
TASK_MENATQA = "MenatQA"
TASK_DATE_ARITHMETIC = "Date Arithmetic"
TASK_TIMEDIAL = "TimeDial"

VALID_TASKS = frozenset(
    {TASK_TEMPREASON, TASK_TIMEQA, TASK_MENATQA, TASK_DATE_ARITHMETIC, TASK_TIMEDIAL}
)
SQUAD_TASKS = frozenset({TASK_TEMPREASON, TASK_TIMEQA, TASK_MENATQA})

TaskType = Literal["TempReason", "TimeQA", "MenatQA", "Date Arithmetic", "TimeDial"]

SELECTED_OPTIONS_PATTERN = r"\b([A-D])(?:\.|,|\s|&|$)"
SELECTED_OPTIONS_REGEX = re.compile(SELECTED_OPTIONS_PATTERN)
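
# Illustration of the option-letter pattern on made-up strings (doctest-style,
# not benchmark data):
#
#   >>> SELECTED_OPTIONS_REGEX.findall("The answer is B and C")
#   ['B', 'C']
#   >>> SELECTED_OPTIONS_REGEX.findall("B. No more than ten minutes && C. No more than five minutes")
#   ['B', 'C']
#
# A capital A-D only counts when preceded by a word boundary and followed by
# ".", ",", whitespace, "&", or the end of the string, so letters inside
# ordinary words do not match.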


class TimebenchResult(TypedDict, total=False):
    exact_match: float | list[float]
    f1: float | list[float]


_CITATION = """\
@software{abbood2026timebench_eval,
  title={TimeBench Eval},
  author={Abbood, Auss},
  year={2026},
  url={https://huggingface.co/spaces/aauss/timebench_eval}
}
"""

_DESCRIPTION = """\
Evaluation metric for the TimeBench benchmark, which assesses temporal reasoning
abilities in large language models. Supports multiple task types including
TempReason, TimeQA, MenatQA, Date Arithmetic, and TimeDial.
"""

_KWARGS_DESCRIPTION = """
Calculates evaluation metrics for temporal reasoning tasks.
Args:
    predictions: list of prediction strings from the model. Each prediction
        should contain the marker "Thus, the correct answer is:" followed by
        the answer.
    references: list of reference answer strings.
    task: the task type, one of "TempReason", "TimeQA", "MenatQA",
        "Date Arithmetic", or "TimeDial".
    return_average: if True (default), returns average scores as floats.
        If False, returns a list of scores for each sample.
Returns:
    exact_match: average or list of exact match scores for each prediction.
    f1: average or list of F1 scores for each prediction (for applicable tasks).
Examples:
    >>> timebench_eval = evaluate.load("aauss/timebench_eval")
    >>> predictions = ["Let me think... \\nThus, the correct answer is: Aug, 1987."]
    >>> references = ["Aug, 1987"]
    >>> results = timebench_eval.compute(predictions=predictions, references=references, task="Date Arithmetic")
    >>> print(results)
    {'exact_match': 1.0}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TimebenchEval(evaluate.Metric):
    """Evaluation metric for TimeBench temporal reasoning tasks."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.squad_metric = evaluate.load("squad")

    def _info(self) -> evaluate.MetricInfo:
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            homepage="https://huggingface.co/spaces/aauss/timebench_eval",
            codebase_urls=[
                "https://huggingface.co/spaces/aauss/timebench_eval/tree/main"
            ],
            reference_urls=["https://huggingface.co/datasets/ulab-ai/Time-Bench"],
        )

    def _compute(
        self,
        predictions: list[str],
        references: list[str],
        task: TaskType,
        return_average: bool = True,
    ) -> TimebenchResult:
        """
        Compute evaluation metrics for the given predictions and references.

        Args:
            predictions: List of prediction strings to evaluate.
            references: List of reference strings to compare against.
            task: Task type, one of: "TempReason", "TimeQA", "MenatQA",
                "Date Arithmetic", "TimeDial".
            return_average: If True, returns average scores; if False, returns
                per-sample scores.

        Returns:
            Dictionary containing metric scores (exact_match and/or f1) as
            floats or lists.

        Raises:
            ValueError: If predictions is empty.
            ValueError: If predictions and references have different lengths.
            ValueError: If task is not a valid task type.
        """
        # Validate inputs
        if not predictions:
            raise ValueError("predictions cannot be empty")
        if len(predictions) != len(references):
            raise ValueError(
                f"predictions and references must have same length, "
                f"got {len(predictions)} and {len(references)}"
            )

        if task in SQUAD_TASKS:
            results = self._call_squad(predictions, references)
        elif task == TASK_DATE_ARITHMETIC:
            results = self._compare_dates(predictions, references)
        elif task == TASK_TIMEDIAL:
            results = self._compute_timedial(predictions, references)
        else:
            raise ValueError(
                f"Unknown task: {task}. Expected one of: {', '.join(VALID_TASKS)}"
            )

        if return_average:
            return {key: sum(values) / len(values) for key, values in results.items()}
        return results

    @staticmethod
    def _extract_answer(response: str) -> str | None:
        """Extract the answer from the response."""
        marker = "Thus, the correct answer is:"
        if marker not in response:
            return None
        answer = response.split(marker)[-1]
        # Take only the first line (stops at newlines if model continues)
        answer = answer.strip().split("\n")[0]
        answer = answer.rstrip(".!?").strip()
        if "unanswerable" in answer.lower():
            return "unanswerable"
        return answer or None
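
    # Illustration of the marker-based extraction on a hypothetical model
    # response (made-up string, not benchmark data):
    #
    #   >>> TimebenchEval._extract_answer(
    #   ...     "Let me think.\nThus, the correct answer is: 1999.\nHope that helps!"
    #   ... )
    #   '1999'
    #
    # Everything before the marker is discarded, only the first line after it
    # is kept, and trailing ".", "!", "?" characters are stripped.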

    def _extract_selected_options(self, text: str) -> set[str]:
        """
        Extract selected option letters (A, B, C, D) from various formats:
        - "B, C"
        - "B and C"
        - "B & C"
        - "B && C"
        - "B. No more than ten minutes && C. No more than five minutes"
        - "Options B and C"
        - "The answer is B, C"
        """
        if not text:
            return set()
        matches = SELECTED_OPTIONS_REGEX.findall(text)
        return set(matches)

    def _call_squad(
        self, predictions: list[str], references: list[str]
    ) -> dict[str, list[float]]:
        """
        Compute SQuAD metrics (Exact Match and F1) for predictions and references.

        Args:
            predictions: List of prediction strings.
            references: List of reference answer strings.

        Returns:
            Dictionary with "exact_match" and "f1" keys, each containing a list
            of scores.
        """
        exact_matches = []
        f1_scores = []
        for pred, ref in zip(predictions, references):
            formatted_pred = [
                {"id": "0", "prediction_text": self._extract_answer(pred) or ""}
            ]
            formatted_ref = [
                {"id": "0", "answers": {"text": [ref], "answer_start": [0]}}
            ]
            results = self.squad_metric.compute(
                predictions=formatted_pred, references=formatted_ref
            )
            # The squad metric reports scores on a 0-100 scale; rescale to [0, 1].
            exact_matches.append(results["exact_match"] / 100)
            f1_scores.append(results["f1"] / 100)
        return {
            "exact_match": exact_matches,
            "f1": f1_scores,
        }

    def _compare_dates(
        self, predictions: list[str], references: list[str]
    ) -> dict[str, list[int]]:
        """
        Parse and compare dates in predictions and references for exact match.

        Args:
            predictions: List of prediction strings containing dates.
            references: List of reference date strings.

        Returns:
            Dictionary with "exact_match" key containing a list of 0/1 scores.
        """
        parsed_predictions = [
            self._parse_historical_date(self._extract_answer(pred))
            for pred in predictions
        ]
        parsed_references = [self._parse_historical_date(ref) for ref in references]
        # Unparseable values become None; a None prediction only scores when
        # the reference also failed to parse.
        return {
            "exact_match": [
                1 if pred == ref else 0
                for pred, ref in zip(parsed_predictions, parsed_references)
            ],
        }

    def _compute_timedial(
        self, predictions: list[str], references: list[str]
    ) -> dict[str, list[float]]:
        """
        Compute TimeDial metrics (Exact Match and F1) using set-based comparison
        of selected options.

        Args:
            predictions: List of prediction strings.
            references: List of reference strings containing selected options.

        Returns:
            Dictionary with "exact_match" and "f1" keys, each containing a list
            of scores.
        """
        exact_matches = []
        f1_scores = []
        for pred, ref in zip(predictions, references):
            pred_answer = self._extract_answer(pred)  # Text after the marker
            pred_options = (
                self._extract_selected_options(pred_answer) if pred_answer else set()
            )
            ref_options = self._extract_selected_options(ref)

            # Exact match: sets must be identical
            em = 1 if pred_options == ref_options else 0
            exact_matches.append(em)

            # F1: set-based
            if not pred_options and not ref_options:
                f1 = 1.0  # Both empty = perfect match
            elif not pred_options or not ref_options:
                f1 = 0.0  # One empty, one not
            else:
                tp = len(pred_options & ref_options)
                precision = tp / len(pred_options)
                recall = tp / len(ref_options)
                f1 = (
                    2 * precision * recall / (precision + recall)
                    if (precision + recall) > 0
                    else 0.0
                )
            f1_scores.append(f1)

        return {"exact_match": exact_matches, "f1": f1_scores}

    @staticmethod
    def _parse_historical_date(date_str: str | None) -> datetime | None:
        """
        Parse a date string and return a datetime object with the day set to 1.

        Args:
            date_str: String representation of a date, or None.

        Returns:
            datetime object with day set to 1, or None if parsing fails or the
            input is None.
        """
        if date_str is None:
            return None
        try:
            # Normalize to the first of the month so answers at month/year
            # granularity compare equal regardless of dateutil's default day.
            return parser.parse(date_str).replace(day=1)
        except (ParserError, OverflowError):
            return None
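

# A minimal smoke test, runnable as a script. This is an illustrative sketch:
# it assumes network (or cached) access so that evaluate.load("squad") in
# __init__ succeeds, and the inputs below are made up, not benchmark data.
if __name__ == "__main__":
    metric = TimebenchEval()

    # Date Arithmetic: dates are normalized to the first of the month, so
    # "Aug, 1987" and "August 1987" compare equal.
    print(
        metric.compute(
            predictions=["Let me think... Thus, the correct answer is: Aug, 1987."],
            references=["August 1987"],
            task=TASK_DATE_ARITHMETIC,
        )
    )  # {'exact_match': 1.0}

    # TimeDial: option letters are compared as sets, so "B and C" matches "B && C".
    print(
        metric.compute(
            predictions=["Thus, the correct answer is: B and C"],
            references=["B && C"],
            task=TASK_TIMEDIAL,
        )
    )  # {'exact_match': 1.0, 'f1': 1.0}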