| """Scoring logic for EgoMemReason. | |
| Pure stdlib — no Gradio, no HF imports. Returns a flat metrics dict. | |
| Raises ValueError with per-example messages on validation failure. | |
| """ | |
import json
from collections import defaultdict

# Order matches the leaderboard column order.
QUERY_TYPES = [
    "Cumulative State Tracking",
    "Temporal Counting",
    "Event Ordering",
    "Event Linking",
    "Spatial Preference",
    "Activity Pattern",
]


def _load(path):
    with open(path) as f:
        return json.load(f)


def _build_gt(ann):
    """Returns {example_id: (correct_letter, query_type, valid_option_letters)}.

    Questions have 4-10 options (letters up to J), so the valid answer set
    is per-question, not a fixed A-D.
    """
    samples = ann["samples"] if isinstance(ann, dict) else ann
    gt = {}
    for s in samples:
        eid = s["example_id"]
        opts = {str(k).strip().upper() for k in s["options"].keys()}
        gt[eid] = (s["correct_answer"].strip().upper(), s["query_type"], opts)
    return gt
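
# A resulting ground-truth entry then looks like (values illustrative):
#   gt[17] == ("B", "Event Ordering", {"A", "B", "C", "D"})
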
def _validate(preds, gt):
    if not isinstance(preds, list):
        raise ValueError("Submission must be a JSON list of objects.")
    errors = []
    seen = set()
    for i, item in enumerate(preds):
        if not isinstance(item, dict):
            errors.append(f"item {i}: not a JSON object")
            continue
        eid = item.get("example_id")
        ans = item.get("predicted_answer")
        if not isinstance(eid, int):
            errors.append(f"item {i}: 'example_id' must be an int, got {type(eid).__name__}")
            continue
        if eid in seen:
            errors.append(f"duplicate example_id: {eid}")
        seen.add(eid)
        if eid not in gt:
            errors.append(f"unknown example_id: {eid}")
            continue
        valid = gt[eid][2]
        if not isinstance(ans, str) or ans.strip().upper() not in valid:
            errors.append(
                f"example_id {eid}: 'predicted_answer' must be one of "
                f"{'/'.join(sorted(valid))}, got {ans!r}"
            )
    missing = set(gt) - seen
    if missing:
        errors.append(
            f"missing {len(missing)} example_ids (e.g. {sorted(missing)[:5]}); "
            f"submission must cover all {len(gt)} questions"
        )
    if errors:
        msg = "Submission validation failed:\n - " + "\n - ".join(errors[:20])
        if len(errors) > 20:
            msg += f"\n - ... and {len(errors) - 20} more error(s)"
        raise ValueError(msg)


def _score(preds, gt):
    correct_total = 0
    count_by_qt = defaultdict(int)
    correct_by_qt = defaultdict(int)
    for _eid, (_gt_ans, qt, _opts) in gt.items():
        count_by_qt[qt] += 1
    for item in preds:
        eid = item["example_id"]
        ans = item["predicted_answer"].strip().upper()
        gt_ans, qt, _opts = gt[eid]
        if ans == gt_ans:
            correct_total += 1
            correct_by_qt[qt] += 1
    metrics = {}
    for qt in QUERY_TYPES:
        n = count_by_qt.get(qt, 0)
        metrics[qt] = round(100.0 * correct_by_qt[qt] / n, 2) if n else 0.0
    metrics["Overall"] = round(100.0 * correct_total / len(gt), 2)
    return metrics


def score_submission(submission_path, annotation_path="annotations_private.json"):
    """Returns {"Cumulative State Tracking": ..., ..., "Overall": ...} as percentages."""
    gt = _build_gt(_load(annotation_path))
    preds = _load(submission_path)
    _validate(preds, gt)
    return _score(preds, gt)
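
# Example return value (numbers illustrative):
#   {"Cumulative State Tracking": 71.43, "Temporal Counting": 64.0,
#    "Event Ordering": 58.33, "Event Linking": 66.67,
#    "Spatial Preference": 70.0, "Activity Pattern": 61.54, "Overall": 65.2}
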
if __name__ == "__main__":
    import argparse
    import pprint

    p = argparse.ArgumentParser()
    p.add_argument("--annotation", default="annotations_private.json")
    p.add_argument("--submission", required=True)
    args = p.parse_args()
    pprint.pp(score_submission(args.submission, args.annotation))
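
# Usage sketch (file names illustrative; save this module as e.g. score.py):
#   python score.py --submission my_predictions.json --annotation annotations_private.json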