File size: 3,925 Bytes
b2c2640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
"""Merge sharded eval_results.shard*.jsonl files and recompute metrics.

Usage:
    python merge_shards.py --bench videomme \
        --label-dir /home/ubuntu/eval_results/videomme/vmme_minicpmo_4_5

The script finds all `eval_results.shard*.jsonl` under `--label-dir`,
concatenates them into `eval_results.jsonl` (deduping by a bench-specific
primary key), then re-runs the bench's `compute_metrics` + `print_summary`.
Final outputs: `eval_results.jsonl`, `metrics.json`, `summary.txt`.
"""
from __future__ import annotations

import _common  # noqa: F401

import argparse
import contextlib
import io
import json
import sys
from pathlib import Path


# Primary key per bench (must match the field written by each eval script).
# Used to dedupe rows when concatenating shard files: the first row seen
# with a given key wins, later ones are dropped.
PK = {
    "videomme": "question_id",
    "lvbench": "uid",
    "worldsense": "question_id",
    "daily_omni": "question_id",
    "dpo_sync": "video",
    "vggsoundsync": "uid",
}

# Extra label used when printing the summary.
# NOTE(review): not referenced anywhere in this script's code — presumably
# imported by other modules; confirm before removing.
LABEL_HINT = {
    "videomme": "Video-MME",
    "lvbench": "LVBench",
    "worldsense": "WorldSense",
    "daily_omni": "Daily-Omni",
    "dpo_sync": "Sync",
    "vggsoundsync": "VGGSoundSync",
}


def _merge_shard_files(shard_files: list[Path], pk: str,
                       merged_path: Path) -> tuple[list[dict], int]:
    """Concatenate shard JSONL files into *merged_path*, deduping by *pk*.

    Returns ``(unique_results, n_duplicates_skipped)``. Rows that lack the
    primary key entirely are always kept — a missing key must not cause
    unrelated records to collapse into one.
    """
    results: list[dict] = []
    seen: set = set()
    n_dup = 0
    with open(merged_path, "w", encoding="utf-8") as out:
        for sf in shard_files:
            with open(sf, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    obj = json.loads(line)
                    key = obj.get(pk)
                    if key is not None:
                        if key in seen:
                            n_dup += 1
                            continue
                        seen.add(key)
                    out.write(line + "\n")
                    results.append(obj)
    return results, n_dup


def _find_eval_config(shard_files: list[Path]):
    """Return the ``eval_config`` from the first shard whose first line has
    one, or ``None``. Unreadable or malformed shards are skipped (best-effort,
    matching the original intent — but only I/O and JSON errors are ignored).
    """
    for sf in shard_files:
        try:
            with open(sf, encoding="utf-8") as f:
                first = f.readline().strip()
        except OSError:
            continue
        if not first:
            continue
        try:
            obj = json.loads(first)
        except json.JSONDecodeError:
            continue
        if "eval_config" in obj:
            return obj["eval_config"]
    return None


def main() -> int:
    """Merge shard result files for one benchmark and recompute metrics.

    Reads ``eval_results.shard*.jsonl`` under ``--label-dir``, writes the
    merged ``eval_results.jsonl``, ``metrics.json`` and ``summary.txt``.

    Returns:
        Process exit code — 0 on success, 1 when no shard files are found.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--bench", required=True,
                   choices=list(PK.keys()),
                   help="Which benchmark this label-dir belongs to.")
    p.add_argument("--label-dir", type=Path, required=True,
                   help="Eval output dir containing eval_results.shard*.jsonl.")
    args = p.parse_args()

    ch = _common.ch(args.bench)
    pk = PK[args.bench]

    shard_files = sorted(args.label_dir.glob("eval_results.shard*.jsonl"))
    if not shard_files:
        print(f"[merge] ERROR: no eval_results.shard*.jsonl in {args.label_dir}",
              file=sys.stderr)
        return 1

    print(f"[merge] Found {len(shard_files)} shard file(s):")
    for sf in shard_files:
        print(f"         - {sf.name}")

    merged_path = args.label_dir / "eval_results.jsonl"
    all_results, n_dup = _merge_shard_files(shard_files, pk, merged_path)

    print(f"[merge] Merged {len(all_results)} unique results "
          f"({n_dup} duplicates skipped) -> {merged_path}")

    metrics = ch.compute_metrics(all_results)
    # Preserve eval_config from any shard if present.
    eval_config = _find_eval_config(shard_files)
    if eval_config is not None:
        metrics["eval_config"] = eval_config

    metrics_json = args.label_dir / "metrics.json"
    summary_txt = args.label_dir / "summary.txt"

    with open(metrics_json, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2, ensure_ascii=False)

    label = args.label_dir.name
    # Render the summary exactly once, then both echo it and persist it
    # (the original called print_summary twice — once live, once redirected).
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        ch.print_summary(metrics, label)
    summary = buf.getvalue()
    sys.stdout.write(summary)
    with open(summary_txt, "w", encoding="utf-8") as f:
        f.write(summary)

    print("\n[merge] Done.")
    print(f"  Results:  {merged_path}")
    print(f"  Metrics:  {metrics_json}")
    print(f"  Summary:  {summary_txt}")
    return 0


# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())