File size: 1,939 Bytes
5c1bb37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import math
from pathlib import Path

from _common import candidate_by_id, pearson, read_json, read_jsonl, selected_ids, write_csv


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the oracle comparison script.

    Returns:
        A namespace with ``candidates``, ``selected`` and ``output``
        attributes, each converted to a ``Path``. All three options are
        required; argparse exits with a usage error if one is missing.
    """
    cli = argparse.ArgumentParser(
        description="Compare selected-view marginal gains against oracle annotations.",
    )
    # The three options share the same shape: a mandatory filesystem path.
    for flag in ("--candidates", "--selected", "--output"):
        cli.add_argument(flag, type=Path, required=True)
    return cli.parse_args()


def main() -> None:
    """Write a one-row CSV comparing selected viewpoints against oracle gains.

    Loads the candidate records (``read_jsonl``) and the selection document
    (``read_json``), then pairs every id returned by ``selected_ids`` with its
    candidate record and its ``selected_viewpoints`` entry. A pair is kept
    only when both records are present and truthy and the candidate has an
    ``oracle_gain`` key. A missing ``score`` or ``marginal_gain`` on the
    selected entry falls back to ``0.0``.

    The output row holds the scene id (``"unknown"`` if absent), the number
    of kept pairs, the ``pearson`` results for score-vs-oracle and
    marginal-vs-oracle, and the mean and max of ``|oracle - marginal|``
    (``nan`` when no pair was kept). The row is passed to ``write_csv`` and a
    confirmation line is printed.
    """
    args = parse_args()
    candidate_rows = read_jsonl(args.candidates)
    selection = read_json(args.selected)
    candidate_lookup = candidate_by_id(candidate_rows)

    # Index the selected entries by their stringified candidate id.
    selection_lookup = {}
    for entry in selection.get("selected_viewpoints", []):
        selection_lookup[str(entry["candidate_id"])] = entry

    # One (oracle, score, marginal) triple per usable selected viewpoint.
    triples: list[tuple[float, float, float]] = []
    for cid in selected_ids(selection):
        cand = candidate_lookup.get(cid)
        picked = selection_lookup.get(cid)
        if not (cand and picked and "oracle_gain" in cand):
            continue
        oracle_value = float(cand["oracle_gain"])
        score_value = float(picked.get("score", 0.0))
        marginal_value = float(picked.get("marginal_gain", 0.0))
        triples.append((oracle_value, score_value, marginal_value))

    oracle = [triple[0] for triple in triples]
    scores = [triple[1] for triple in triples]
    marginal = [triple[2] for triple in triples]
    gaps = [abs(triple[0] - triple[2]) for triple in triples]

    scene_id = selection.get("scene_id", "unknown")
    score_corr = pearson(scores, oracle)
    marginal_corr = pearson(marginal, oracle)

    # With no usable pairs the gap statistics are undefined, so report NaN.
    if gaps:
        mean_gap = sum(gaps) / len(gaps)
        max_gap = max(gaps)
    else:
        mean_gap = math.nan
        max_gap = math.nan

    summary = {
        "scene_id": scene_id,
        "num_pairs": len(oracle),
        "score_oracle_pearson": score_corr,
        "marginal_oracle_pearson": marginal_corr,
        "mean_abs_oracle_gap": mean_gap,
        "max_abs_oracle_gap": max_gap,
    }
    write_csv(args.output, [summary])
    print(f"Wrote {args.output}")


# Run the comparison only when this file is executed as a script, so that
# importing the module does not trigger argument parsing or file I/O.
if __name__ == "__main__":
    main()