"""Final verification of full 835-ep eval metrics — cross-check from multiple sources.""" import gzip import json import re import glob import os from collections import Counter EVALS = "/imotion-vepfs/imotion_zxnet/workspace_wei/ry-dynamics-vln-ryworld/checkpoints/ryworld_stage1_discrete_20260510_025317_cw07_iw25/evals" # === SOURCE 1: per-shard meta.json (vlnverse_emr's own aggregate per shard) === print("=" * 90) print("SOURCE 1: per-shard meta.json (vlnverse_emr 自带 aggregate)") print("=" * 90) meta_files = sorted(glob.glob(f"{EVALS}/*full835*/meta.json")) shard_aggs = [] for mf in meta_files: m = json.load(open(mf)) agg = m.get("results_aggregate", {}) if agg.get("n_total"): name = mf.split("/")[-2].split("full835_thr095_")[1] if "full835_thr095_" in mf else "?" shard_aggs.append((name, agg, m.get("stop_head_stats", {}))) print(f" {name:<26} n={agg['n_total']:>3} SR={agg['SR']*100:>5.1f}% SPL={agg['SPL']*100:>5.1f}% OS={agg['OS']*100:>5.1f}% NE={agg['NE']:.2f}") # Compute weighted aggregate across shards total_n = sum(a[1]["n_total"] for a in shard_aggs) def weighted(field): return sum(a[1][field] * a[1]["n_total"] for a in shard_aggs) / total_n print("-" * 90) print(f" WEIGHTED ACROSS SHARDS n={total_n} " f"SR={weighted('SR')*100:.2f}% SPL={weighted('SPL')*100:.2f}% " f"OS={weighted('OS')*100:.2f}% NE={weighted('NE'):.3f}") print() # === SOURCE 2: dedupe-by-trajectory_id across all eval.log files === print("=" * 90) print("SOURCE 2: dedupe-by-trajectory_id 从所有 eval.log (含 .gz / _resume)") print("=" * 90) pat = re.compile( r'"shortest_path_length":\s*([\d.]+).*?"NE":\s*([\d.]+).*?"success":\s*([\d.]+).*?' r'"osr":\s*([\d.]+).*?"TL":\s*([\d.]+).*?"spl":\s*([\d.]+).*?"ndtw":\s*([\d.]+).*?' r'"trajectory_id":\s*"([^"]+)"', re.S, ) all_eps = {} log_paths = sorted(glob.glob(f"{EVALS}/*full835*/eval*.log") + glob.glob(f"{EVALS}/*full835*/eval*.log.gz")) for p in log_paths: if p.endswith(".gz"): content = gzip.open(p, "rt").read() else: content = open(p).read() for m in pat.finditer(content): spg, ne, succ, osr, tl, spl, ndtw, tid = m.groups() all_eps[tid] = { "spg": float(spg), "NE": float(ne), "success": int(float(succ) == 1.0), "osr": int(float(osr) == 1.0), "TL": float(tl), "SPL": float(spl), "nDTW": float(ndtw), } n = len(all_eps) eps = list(all_eps.values()) def avg(f): return sum(e[f] for e in eps) / n print(f" unique trajectory_id n={n}/835 " f"SR={100*avg('success'):.2f}% SPL={100*avg('SPL'):.2f}% " f"OSR={100*avg('osr'):.2f}% NE={avg('NE'):.3f} TL={avg('TL'):.3f} nDTW={avg('nDTW'):.4f}") print() # === SOURCE 3: vlnverse_emr's eval_result.log (official format) per shard === print("=" * 90) print("SOURCE 3: vlnverse_emr eval_result.log (官方格式)") print("=" * 90) result_logs = sorted(glob.glob(f"{EVALS}/*full835*/eval_result.log")) for rl in result_logs: name = rl.split("/")[-2].split("full835_thr095_")[1] if "full835_thr095_" in rl else "?" content = open(rl).read() # Parse SR/SPL/NE sr_m = re.search(r"SR = ([\d.]+) / (\d+) = ([\d.]+)%", content) sp_m = re.search(r"SPL = ([\d.]+) / (\d+) = ([\d.]+)%", content) if sr_m and sp_m: print(f" {name:<26} SR={sr_m.group(1):>5}/{sr_m.group(2)} = {sr_m.group(3)}% " f"SPL={sp_m.group(3)}%") print() # === SOURCE 4: stop_head_stats across shards === print("=" * 90) print("SOURCE 4: stop_head_stats(模型 stop head 行为)") print("=" * 90) all_sh = [] for name, agg, sh in shard_aggs: if sh.get("n_samples"): all_sh.append(sh) print(f" {name:<26} n_samples={sh['n_samples']:>5} " f"sp_p50={sh['stop_prob_p50']:.3f} sp_p90={sh['stop_prob_p90']:.3f} " f"pA={sh['fire_rate_pathA']*100:.2f}% pB={sh['fire_rate_pathB']*100:.2f}% " f"no_stop={sh['no_stop_rate']*100:.2f}%") # Weighted aggregate if all_sh: tot = sum(s["n_samples"] for s in all_sh) def wsh(f): return sum(s[f] * s["n_samples"] for s in all_sh) / tot print("-" * 90) print(f" WEIGHTED (n_samples={tot:>5}) sp_p50={wsh('stop_prob_p50'):.3f} " f"sp_p90={wsh('stop_prob_p90'):.3f} pA={wsh('fire_rate_pathA')*100:.2f}% " f"pB={wsh('fire_rate_pathB')*100:.2f}% no_stop={wsh('no_stop_rate')*100:.2f}%") print() # === FINAL — bucket SR === print("=" * 90) print("FINAL — bucket SR by shortest_path") print("=" * 90) for lo, hi in [(0, 5), (5, 8), (8, 12), (12, 18), (18, 30)]: bk = [e for e in eps if lo <= e["spg"] < hi] if bk: sr_b = sum(e["success"] for e in bk) / len(bk) ne_b = sum(e["NE"] for e in bk) / len(bk) print(f" short [{lo:>2},{hi:>2})m n={len(bk):>4} SR={100*sr_b:>5.1f}% NE={ne_b:.2f}")