File size: 1,483 Bytes
1b435f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from collections.abc import Callable, Sequence
from typing import Any

from core.model import DEFAULT_MODELS, build_prompt, parse_binary_prediction


# Model adapter used by ``evaluate``: called as ``model_fn(prompt, model_id)``
# and returning the model's raw text output, which is then handed to
# ``parse_binary_prediction``.
ModelFn = Callable[[str, str], str]


def evaluate(
    dataset: Sequence[dict[str, Any]],
    model_fn: ModelFn,
    model_ids: Sequence[str] | None = None,
) -> list[dict[str, Any]]:
    """Run models over the dataset and return only incorrect or unparsable cases.

    Args:
        dataset: Samples to evaluate. Each sample is read for ``"x"`` (passed
            to ``build_prompt``) and ``"y"`` (the expected label, coerced with
            ``int``). ``"reasoning_type"`` is read only for samples that end
            up in a failure record.
        model_fn: Called as ``model_fn(prompt, model_id)``; its return value
            is passed to ``parse_binary_prediction``.
        model_ids: Models to run. ``None`` selects ``DEFAULT_MODELS``; an
            explicitly empty sequence runs no models and yields no failures.

    Returns:
        One record per (sample, model) pair whose parsed prediction does not
        equal the expected label, ordered by dataset position and then by
        model order. A ``prediction`` of ``None`` is tagged
        ``"parse_error"``; any other mismatch is tagged ``"wrong_label"``.

    Raises:
        KeyError: If a sample lacks ``"x"`` or ``"y"``, or a failing sample
            lacks ``"reasoning_type"``.
        Any exception raised by ``model_fn``, ``build_prompt`` or
        ``parse_binary_prediction`` propagates unchanged.
    """
    # Explicit ``is None`` check: the previous ``model_ids or DEFAULT_MODELS``
    # replaced an intentionally empty selection with the defaults, and raised
    # for sequence types whose truth value is ambiguous (e.g. NumPy arrays).
    # ``list(...)`` snapshots the selection so later mutation cannot affect
    # the run.
    selected_model_ids = list(DEFAULT_MODELS if model_ids is None else model_ids)
    failures: list[dict[str, Any]] = []

    for sample_id, sample in enumerate(dataset):
        # Prompt and label depend only on the sample, so they are computed
        # once and shared across all models.
        prompt = build_prompt(sample["x"])
        expected = int(sample["y"])

        for model_id in selected_model_ids:
            raw_output = model_fn(prompt, model_id)
            prediction = parse_binary_prediction(raw_output)
            if prediction == expected:
                continue

            failures.append(
                {
                    "sample_id": sample_id,
                    "x": sample["x"],
                    "y": expected,
                    "reasoning_type": sample["reasoning_type"],
                    "model_id": model_id,
                    "prompt": prompt,
                    "raw_output": raw_output,
                    "prediction": prediction,
                    "failure_kind": "parse_error" if prediction is None else "wrong_label",
                }
            )

    return failures