"""Scoring logic for EgoMemReason. Pure stdlib — no Gradio, no HF imports. Returns a flat metrics dict. Raises ValueError with per-example messages on validation failure. """ import json from collections import defaultdict # Order matches the leaderboard column order. QUERY_TYPES = [ "Cumulative State Tracking", "Temporal Counting", "Event Ordering", "Event Linking", "Spatial Preference", "Activity Pattern", ] def _load(path): with open(path) as f: return json.load(f) def _build_gt(ann): """Returns {example_id: (correct_letter, query_type, valid_option_letters)}. Questions have 4-10 options (letters up to J), so the valid answer set is per-question, not a fixed A-D. """ samples = ann["samples"] if isinstance(ann, dict) else ann gt = {} for s in samples: eid = s["example_id"] opts = {str(k).strip().upper() for k in s["options"].keys()} gt[eid] = (s["correct_answer"].strip().upper(), s["query_type"], opts) return gt def _validate(preds, gt): if not isinstance(preds, list): raise ValueError("Submission must be a JSON list of objects.") errors = [] seen = set() for i, item in enumerate(preds): if not isinstance(item, dict): errors.append(f"item {i}: not a JSON object") continue eid = item.get("example_id") ans = item.get("predicted_answer") if not isinstance(eid, int): errors.append(f"item {i}: 'example_id' must be an int, got {type(eid).__name__}") continue if eid in seen: errors.append(f"duplicate example_id: {eid}") seen.add(eid) if eid not in gt: errors.append(f"unknown example_id: {eid}") continue valid = gt[eid][2] if not isinstance(ans, str) or ans.strip().upper() not in valid: errors.append( f"example_id {eid}: 'predicted_answer' must be one of " f"{'/'.join(sorted(valid))}, got {ans!r}" ) missing = set(gt) - seen if missing: errors.append( f"missing {len(missing)} example_ids (e.g. {sorted(missing)[:5]}); " f"submission must cover all {len(gt)} questions" ) if errors: msg = "Submission validation failed:\n - " + "\n - ".join(errors[:20]) if len(errors) > 20: msg += f"\n - ... and {len(errors) - 20} more error(s)" raise ValueError(msg) def _score(preds, gt): correct_total = 0 count_by_qt = defaultdict(int) correct_by_qt = defaultdict(int) for _eid, (_gt_ans, qt, _opts) in gt.items(): count_by_qt[qt] += 1 for item in preds: eid = item["example_id"] ans = item["predicted_answer"].strip().upper() gt_ans, qt, _opts = gt[eid] if ans == gt_ans: correct_total += 1 correct_by_qt[qt] += 1 metrics = {} for qt in QUERY_TYPES: n = count_by_qt.get(qt, 0) metrics[qt] = round(100.0 * correct_by_qt[qt] / n, 2) if n else 0.0 metrics["Overall"] = round(100.0 * correct_total / len(gt), 2) return metrics def score_submission(submission_path, annotation_path="annotations_private.json"): """Returns {"Cumulative State Tracking": ..., ..., "Overall": ...} as percentages.""" gt = _build_gt(_load(annotation_path)) preds = _load(submission_path) _validate(preds, gt) return _score(preds, gt) if __name__ == "__main__": import argparse, pprint p = argparse.ArgumentParser() p.add_argument("--annotation", default="annotations_private.json") p.add_argument("--submission", required=True) args = p.parse_args() pprint.pp(score_submission(args.submission, args.annotation))