"""Scoring logic for EgoMemReason.

Pure stdlib — no Gradio, no HF imports. Returns a flat metrics dict.
Raises ValueError with per-example messages on validation failure.
"""

import json
from collections import defaultdict

# Order matches the leaderboard column order.
QUERY_TYPES = [
    "Cumulative State Tracking",
    "Temporal Counting",
    "Event Ordering",
    "Event Linking",
    "Spatial Preference",
    "Activity Pattern",
]


def _load(path):
    # Pin the encoding: json files are UTF-8 regardless of the OS locale.
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _build_gt(ann):
    """Returns {example_id: (correct_letter, query_type, valid_option_letters)}.

    Questions have 4-10 options (letters up to J), so the valid answer set
    is per-question, not a fixed A-D.
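
    Example (hypothetical sample): letters are normalised, so a sample with
    correct_answer "b" and option keys {"A", "b"} is stored as
    ("B", <query_type>, {"A", "B"}).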
    """
    samples = ann["samples"] if isinstance(ann, dict) else ann
    gt = {}
    for s in samples:
        eid = s["example_id"]
        opts = {str(k).strip().upper() for k in s["options"].keys()}
        gt[eid] = (s["correct_answer"].strip().upper(), s["query_type"], opts)
    return gt


def _validate(preds, gt):
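    """Raise ValueError listing every problem found (first 20 shown).

    Checks: top-level list of objects, integer example_ids, no duplicates,
    no unknown ids, answers within each question's own option set, and full
    coverage of the ground truth.
    """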
    if not isinstance(preds, list):
        raise ValueError("Submission must be a JSON list of objects.")

    errors = []
    seen = set()
    for i, item in enumerate(preds):
        if not isinstance(item, dict):
            errors.append(f"item {i}: not a JSON object")
            continue
        eid = item.get("example_id")
        ans = item.get("predicted_answer")
        if not isinstance(eid, int):
            errors.append(f"item {i}: 'example_id' must be an int, got {type(eid).__name__}")
            continue
        if eid in seen:
            errors.append(f"duplicate example_id: {eid}")
        seen.add(eid)
        if eid not in gt:
            errors.append(f"unknown example_id: {eid}")
            continue
        valid = gt[eid][2]
        if not isinstance(ans, str) or ans.strip().upper() not in valid:
            errors.append(
                f"example_id {eid}: 'predicted_answer' must be one of "
                f"{'/'.join(sorted(valid))}, got {ans!r}"
            )

    missing = set(gt) - seen
    if missing:
        errors.append(
            f"missing {len(missing)} example_ids (e.g. {sorted(missing)[:5]}); "
            f"submission must cover all {len(gt)} questions"
        )

    if errors:
        msg = "Submission validation failed:\n  - " + "\n  - ".join(errors[:20])
        if len(errors) > 20:
            msg += f"\n  - ... and {len(errors) - 20} more error(s)"
        raise ValueError(msg)


def _score(preds, gt):
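    """Return accuracy (%) per query type plus "Overall", rounded to 2 dp.

    Assumes preds already passed _validate, so every lookup is safe.
    """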
    correct_total = 0
    count_by_qt = defaultdict(int)
    correct_by_qt = defaultdict(int)

    for _eid, (_gt_ans, qt, _opts) in gt.items():
        count_by_qt[qt] += 1

    for item in preds:
        eid = item["example_id"]
        ans = item["predicted_answer"].strip().upper()
        gt_ans, qt, _opts = gt[eid]
        if ans == gt_ans:
            correct_total += 1
            correct_by_qt[qt] += 1

    metrics = {}
    for qt in QUERY_TYPES:
        n = count_by_qt.get(qt, 0)
        metrics[qt] = round(100.0 * correct_by_qt[qt] / n, 2) if n else 0.0
    metrics["Overall"] = round(100.0 * correct_total / len(gt), 2)
    return metrics


def score_submission(submission_path, annotation_path="annotations_private.json"):
    """Returns {"Cumulative State Tracking": ..., ..., "Overall": ...} as percentages."""
    gt = _build_gt(_load(annotation_path))
    preds = _load(submission_path)
    _validate(preds, gt)
    return _score(preds, gt)
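
# Illustrative call (hypothetical file name and scores):
#   score_submission("my_preds.json")
#   -> {"Cumulative State Tracking": 72.5, ..., "Overall": 70.31}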


if __name__ == "__main__":
    import argparse
    import pprint

    p = argparse.ArgumentParser()
    p.add_argument("--annotation", default="annotations_private.json")
    p.add_argument("--submission", required=True)
    args = p.parse_args()
    pprint.pp(score_submission(args.submission, args.annotation))