"""Scoring logic for EgoMemReason.
Pure stdlib — no Gradio, no HF imports. Returns a flat metrics dict.
Raises ValueError with per-example messages on validation failure.
"""
import json
from collections import defaultdict
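
# Expected submission shape, as enforced by _validate below (illustrative values):
#   [{"example_id": 0, "predicted_answer": "A"}, ...]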
# Order matches the leaderboard column order.
QUERY_TYPES = [
"Cumulative State Tracking",
"Temporal Counting",
"Event Ordering",
"Event Linking",
"Spatial Preference",
"Activity Pattern",
]
def _load(path):
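    """Reads and parses a JSON file."""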
    with open(path) as f:
        return json.load(f)
def _build_gt(ann):
"""Returns {example_id: (correct_letter, query_type, valid_option_letters)}.
Questions have 4-10 options (letters up to J), so the valid answer set
is per-question, not a fixed A-D.
"""
    samples = ann["samples"] if isinstance(ann, dict) else ann
    gt = {}
    for s in samples:
        eid = s["example_id"]
        opts = {str(k).strip().upper() for k in s["options"].keys()}
        gt[eid] = (s["correct_answer"].strip().upper(), s["query_type"], opts)
    return gt
def _validate(preds, gt):
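    """Checks structure, types, duplicates, coverage, and answer validity.

    Raises ValueError listing up to 20 per-example problems; returns None on success.
    """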
    if not isinstance(preds, list):
        raise ValueError("Submission must be a JSON list of objects.")
    errors = []
    seen = set()
    for i, item in enumerate(preds):
        if not isinstance(item, dict):
            errors.append(f"item {i}: not a JSON object")
            continue
        eid = item.get("example_id")
        ans = item.get("predicted_answer")
        if not isinstance(eid, int):
            errors.append(f"item {i}: 'example_id' must be an int, got {type(eid).__name__}")
            continue
        if eid in seen:
            errors.append(f"duplicate example_id: {eid}")
        seen.add(eid)
        if eid not in gt:
            errors.append(f"unknown example_id: {eid}")
            continue
        valid = gt[eid][2]
        if not isinstance(ans, str) or ans.strip().upper() not in valid:
            errors.append(
                f"example_id {eid}: 'predicted_answer' must be one of "
                f"{'/'.join(sorted(valid))}, got {ans!r}"
            )
    missing = set(gt) - seen
    if missing:
        errors.append(
            f"missing {len(missing)} example_ids (e.g. {sorted(missing)[:5]}); "
            f"submission must cover all {len(gt)} questions"
        )
    if errors:
        msg = "Submission validation failed:\n - " + "\n - ".join(errors[:20])
        if len(errors) > 20:
            msg += f"\n - ... and {len(errors) - 20} more error(s)"
        raise ValueError(msg)
def _score(preds, gt):
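    """Computes per-query-type and overall accuracy in percent; assumes preds already validated."""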
    correct_total = 0
    count_by_qt = defaultdict(int)
    correct_by_qt = defaultdict(int)
    for _eid, (_gt_ans, qt, _opts) in gt.items():
        count_by_qt[qt] += 1
    for item in preds:
        eid = item["example_id"]
        ans = item["predicted_answer"].strip().upper()
        gt_ans, qt, _opts = gt[eid]
        if ans == gt_ans:
            correct_total += 1
            correct_by_qt[qt] += 1
    metrics = {}
    for qt in QUERY_TYPES:
        n = count_by_qt.get(qt, 0)
        metrics[qt] = round(100.0 * correct_by_qt[qt] / n, 2) if n else 0.0
    metrics["Overall"] = round(100.0 * correct_total / len(gt), 2)
    return metrics
def score_submission(submission_path, annotation_path="annotations_private.json"):
"""Returns {"Cumulative State Tracking": ..., ..., "Overall": ...} as percentages."""
gt = _build_gt(_load(annotation_path))
preds = _load(submission_path)
_validate(preds, gt)
return _score(preds, gt)
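
# Library usage sketch (hypothetical file name; annotation path defaults as above):
#   metrics = score_submission("preds.json")
#   print(metrics["Overall"])  # overall accuracy as a percentage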
if __name__ == "__main__":
    import argparse, pprint

    p = argparse.ArgumentParser()
    p.add_argument("--annotation", default="annotations_private.json")
    p.add_argument("--submission", required=True)
    args = p.parse_args()
    pprint.pp(score_submission(args.submission, args.annotation))