# EgoMemReason / evaluator.py
# Author: Ziyang Wang
# initial Space
# commit 1bf5b23
"""Scoring logic for EgoMemReason.
Pure stdlib — no Gradio, no HF imports. Returns a flat metrics dict.
Raises ValueError with per-example messages on validation failure.
"""
import json
from collections import defaultdict
# The six question categories of the benchmark. Order matches the
# leaderboard column order; _score() emits one metric per entry plus
# an "Overall" accuracy across all questions.
QUERY_TYPES = [
    "Cumulative State Tracking",
    "Temporal Counting",
    "Event Ordering",
    "Event Linking",
    "Spatial Preference",
    "Activity Pattern",
]
def _load(path):
with open(path) as f:
return json.load(f)
def _build_gt(ann):
"""Returns {example_id: (correct_letter, query_type, valid_option_letters)}.
Questions have 4-10 options (letters up to J), so the valid answer set
is per-question, not a fixed A-D.
"""
samples = ann["samples"] if isinstance(ann, dict) else ann
gt = {}
for s in samples:
eid = s["example_id"]
opts = {str(k).strip().upper() for k in s["options"].keys()}
gt[eid] = (s["correct_answer"].strip().upper(), s["query_type"], opts)
return gt
def _validate(preds, gt):
if not isinstance(preds, list):
raise ValueError("Submission must be a JSON list of objects.")
errors = []
seen = set()
for i, item in enumerate(preds):
if not isinstance(item, dict):
errors.append(f"item {i}: not a JSON object")
continue
eid = item.get("example_id")
ans = item.get("predicted_answer")
if not isinstance(eid, int):
errors.append(f"item {i}: 'example_id' must be an int, got {type(eid).__name__}")
continue
if eid in seen:
errors.append(f"duplicate example_id: {eid}")
seen.add(eid)
if eid not in gt:
errors.append(f"unknown example_id: {eid}")
continue
valid = gt[eid][2]
if not isinstance(ans, str) or ans.strip().upper() not in valid:
errors.append(
f"example_id {eid}: 'predicted_answer' must be one of "
f"{'/'.join(sorted(valid))}, got {ans!r}"
)
missing = set(gt) - seen
if missing:
errors.append(
f"missing {len(missing)} example_ids (e.g. {sorted(missing)[:5]}); "
f"submission must cover all {len(gt)} questions"
)
if errors:
msg = "Submission validation failed:\n - " + "\n - ".join(errors[:20])
if len(errors) > 20:
msg += f"\n - ... and {len(errors) - 20} more error(s)"
raise ValueError(msg)
def _score(preds, gt):
    """Compute accuracy metrics for a validated submission.

    *preds* is the submission list (already passed through _validate, so it
    covers each ground-truth id exactly once with an answer from that
    question's option set). *gt* maps example_id ->
    (correct_letter, query_type, valid_letters).

    Returns a flat dict: one percentage per QUERY_TYPES entry plus
    "Overall", each rounded to 2 decimals. Types with no questions — and
    an empty ground truth — score 0.0 instead of dividing by zero.
    """
    count_by_qt = defaultdict(int)
    for _gt_ans, qt, _opts in gt.values():
        count_by_qt[qt] += 1
    correct_total = 0
    correct_by_qt = defaultdict(int)
    for item in preds:
        gt_ans, qt, _opts = gt[item["example_id"]]
        # Same normalisation _validate applied when checking the answer.
        if item["predicted_answer"].strip().upper() == gt_ans:
            correct_total += 1
            correct_by_qt[qt] += 1
    metrics = {}
    for qt in QUERY_TYPES:
        n = count_by_qt[qt]
        metrics[qt] = round(100.0 * correct_by_qt[qt] / n, 2) if n else 0.0
    # Guard the empty-ground-truth edge case (previously ZeroDivisionError).
    metrics["Overall"] = round(100.0 * correct_total / len(gt), 2) if gt else 0.0
    return metrics
def score_submission(submission_path, annotation_path="annotations_private.json"):
    """Score a submission file against the private annotations.

    Returns {"Cumulative State Tracking": ..., ..., "Overall": ...} as
    percentages. Raises ValueError (via _validate) when the submission is
    malformed or incomplete.
    """
    ground_truth = _build_gt(_load(annotation_path))
    predictions = _load(submission_path)
    _validate(predictions, ground_truth)
    return _score(predictions, ground_truth)
if __name__ == "__main__":
    # Minimal CLI wrapper: score one submission and pretty-print the metrics.
    import argparse
    import pprint

    parser = argparse.ArgumentParser()
    parser.add_argument("--annotation", default="annotations_private.json")
    parser.add_argument("--submission", required=True)
    cli_args = parser.parse_args()
    pprint.pp(score_submission(cli_args.submission, cli_args.annotation))