from collections.abc import Callable, Sequence
from typing import Any
from core.model import DEFAULT_MODELS, build_prompt, parse_binary_prediction
ModelFn = Callable[[str, str], str]
def evaluate(
    dataset: Sequence[dict[str, Any]],
    model_fn: ModelFn,
    model_ids: Sequence[str] | None = None,
) -> list[dict[str, Any]]:
    """Run the selected models over the dataset and return the failure cases.

    Args:
        dataset: Samples, each a dict with keys ``"x"`` (model input),
            ``"y"`` (expected binary label, coercible to ``int``), and
            ``"reasoning_type"``.
        model_fn: Callable invoked as ``model_fn(prompt, model_id)`` returning
            the model's raw text output.
        model_ids: Models to evaluate. ``None`` (the default) means use
            ``DEFAULT_MODELS``; an explicitly-passed empty sequence means
            evaluate no models.

    Returns:
        One record per (sample, model) pair whose prediction was incorrect or
        unparsable. ``"failure_kind"`` is ``"parse_error"`` when the output
        could not be parsed (prediction is None), else ``"wrong_label"``.
    """
    # Check `is None` rather than truthiness: a caller passing an empty
    # sequence should get zero evaluations, not a silent fallback to the
    # full default model list.
    selected_model_ids = list(DEFAULT_MODELS) if model_ids is None else list(model_ids)
    failures: list[dict[str, Any]] = []
    for sample_id, sample in enumerate(dataset):
        prompt = build_prompt(sample["x"])
        expected = int(sample["y"])
        for model_id in selected_model_ids:
            raw_output = model_fn(prompt, model_id)
            prediction = parse_binary_prediction(raw_output)
            if prediction == expected:
                continue  # correct answer — nothing to record
            failures.append(
                {
                    "sample_id": sample_id,
                    "x": sample["x"],
                    "y": expected,
                    "reasoning_type": sample["reasoning_type"],
                    "model_id": model_id,
                    "prompt": prompt,
                    "raw_output": raw_output,
                    "prediction": prediction,
                    # None prediction means the raw output was unparsable.
                    "failure_kind": "parse_error" if prediction is None else "wrong_label",
                }
            )
    return failures