from collections.abc import Callable, Sequence
from typing import Any

from core.model import DEFAULT_MODELS, build_prompt, parse_binary_prediction

# A model callable takes (prompt, model_id) and returns the raw text output.
ModelFn = Callable[[str, str], str]


def evaluate(
    dataset: Sequence[dict[str, Any]],
    model_fn: ModelFn,
    model_ids: Sequence[str] | None = None,
) -> list[dict[str, Any]]:
    """Run models over the dataset and return only incorrect or unparsable cases."""
    failures: list[dict[str, Any]] = []
    selected_model_ids = list(model_ids or DEFAULT_MODELS)

    for sample_id, sample in enumerate(dataset):
        prompt = build_prompt(sample["x"])
        expected = int(sample["y"])

        for model_id in selected_model_ids:
            raw_output = model_fn(prompt, model_id)
            prediction = parse_binary_prediction(raw_output)
            is_correct = prediction == expected

            if not is_correct:
                failures.append(
                    {
                        "sample_id": sample_id,
                        "x": sample["x"],
                        "y": expected,
                        "reasoning_type": sample["reasoning_type"],
                        "model_id": model_id,
                        "prompt": prompt,
                        "raw_output": raw_output,
                        "prediction": prediction,
                        "failure_kind": "parse_error" if prediction is None else "wrong_label",
                    }
                )

    return failures
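

# Minimal usage sketch (illustrative only): `_stub_model_fn` and `demo_dataset`
# are hypothetical stand-ins, not part of core.model; a real caller would pass
# a model_fn that queries actual models. This assumes parse_binary_prediction
# maps the raw string "1" to the integer label 1.
def _stub_model_fn(prompt: str, model_id: str) -> str:
    # Always answers "1", so any sample labelled 0 shows up as a failure.
    return "1"


if __name__ == "__main__":
    demo_dataset = [
        {"x": "2 + 2 = 4", "y": 1, "reasoning_type": "arithmetic"},
        {"x": "2 + 2 = 5", "y": 0, "reasoning_type": "arithmetic"},
    ]
    for failure in evaluate(demo_dataset, _stub_model_fn, model_ids=["stub-model"]):
        print(failure["sample_id"], failure["model_id"], failure["failure_kind"])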