# Commit 4a61963 by ttll0928:
# "Initial release: RyWorld VLN stage1 discrete step15000 ckpt + eval pkg"
"""Final verification of full 835-ep eval metrics — cross-check from multiple sources."""
import gzip
import json
import re
import glob
import os
from collections import Counter
# Root directory holding one sub-directory per eval shard. It can be overridden
# with RYWORLD_EVALS_DIR so the script can verify another checkpoint's evals
# without editing the source.
EVALS = os.environ.get(
    "RYWORLD_EVALS_DIR",
    "/imotion-vepfs/imotion_zxnet/workspace_wei/ry-dynamics-vln-ryworld/checkpoints/ryworld_stage1_discrete_20260510_025317_cw07_iw25/evals",
)
# === SOURCE 1: per-shard meta.json (vlnverse_emr's own aggregate per shard) ===
print("=" * 90)
print("SOURCE 1: per-shard meta.json (vlnverse_emr 自带 aggregate)")
print("=" * 90)
meta_files = sorted(glob.glob(f"{EVALS}/*full835*/meta.json"))
# One (shard_name, results_aggregate, stop_head_stats) tuple per shard whose
# aggregate reports a non-zero n_total.
shard_aggs = []
for mf in meta_files:
    with open(mf, encoding="utf-8") as fh:
        m = json.load(fh)
    agg = m.get("results_aggregate", {})
    if not agg.get("n_total"):
        continue  # no aggregate (or n_total == 0) for this shard: skip it
    # The shard label is the part of the run directory name that follows the
    # "full835_thr095_" prefix; it falls back to "?" when the prefix is absent.
    run_dir = os.path.basename(os.path.dirname(mf))
    _, sep, name = run_dir.partition("full835_thr095_")
    if not sep:
        name = "?"
    shard_aggs.append((name, agg, m.get("stop_head_stats", {})))
    print(f" {name:<26} n={agg['n_total']:>3} SR={agg['SR']*100:>5.1f}% SPL={agg['SPL']*100:>5.1f}% OS={agg['OS']*100:>5.1f}% NE={agg['NE']:.2f}")
# Compute weighted aggregate across shards
total_n = sum(a[1]["n_total"] for a in shard_aggs)


def weighted(field):
    """Return the n_total-weighted mean of `field` over `shard_aggs`.

    The denominator is recomputed from `shard_aggs` on every call. Returns
    NaN when no shard contributed episodes, instead of raising
    ZeroDivisionError.
    """
    denom = sum(a[1]["n_total"] for a in shard_aggs)
    if not denom:
        return float("nan")
    return sum(a[1][field] * a[1]["n_total"] for a in shard_aggs) / denom


print("-" * 90)
if total_n:
    print(f" WEIGHTED ACROSS SHARDS n={total_n} "
          f"SR={weighted('SR')*100:.2f}% SPL={weighted('SPL')*100:.2f}% "
          f"OS={weighted('OS')*100:.2f}% NE={weighted('NE'):.3f}")
else:
    print(" WEIGHTED ACROSS SHARDS: no meta.json with a non-empty results_aggregate found")
print()
# === SOURCE 2: dedupe-by-trajectory_id across all eval.log files ===
print("=" * 90)
print("SOURCE 2: dedupe-by-trajectory_id 从所有 eval.log (含 .gz / _resume)")
print("=" * 90)
# A full JSON-style number: optional sign, fraction and exponent. The exponent
# must be accepted: tiny floats (presumably emitted via Python's float repr,
# e.g. an nDTW of 3.5e-07) would otherwise be cut at the "e" and read as 3.5.
_NUM = r"(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)"
pat = re.compile(
    rf'"shortest_path_length":\s*{_NUM}.*?"NE":\s*{_NUM}.*?"success":\s*{_NUM}.*?'
    rf'"osr":\s*{_NUM}.*?"TL":\s*{_NUM}.*?"spl":\s*{_NUM}.*?"ndtw":\s*{_NUM}.*?'
    r'"trajectory_id":\s*"([^"]+)"',
    re.S,
)
# trajectory_id -> metrics. An episode that appears in several logs (for
# example a *_resume log) is counted once; the record parsed last, in sorted
# path order, wins.
all_eps = {}
log_paths = sorted(glob.glob(f"{EVALS}/*full835*/eval*.log") + glob.glob(f"{EVALS}/*full835*/eval*.log.gz"))
for p in log_paths:
    opener = gzip.open if p.endswith(".gz") else open
    with opener(p, "rt", encoding="utf-8", errors="replace") as fh:
        content = fh.read()
    for mo in pat.finditer(content):
        spg, ne, succ, osr, tl, spl, ndtw, tid = mo.groups()
        all_eps[tid] = {
            "spg": float(spg), "NE": float(ne),
            "success": int(float(succ) == 1.0),
            "osr": int(float(osr) == 1.0),
            "TL": float(tl), "SPL": float(spl), "nDTW": float(ndtw),
        }
n = len(all_eps)
eps = list(all_eps.values())


def avg(f):
    """Mean of field `f` over the deduplicated episodes (NaN if none were parsed)."""
    return sum(e[f] for e in eps) / n if n else float("nan")


if n:
    print(f" unique trajectory_id n={n}/835 "
          f"SR={100*avg('success'):.2f}% SPL={100*avg('SPL'):.2f}% "
          f"OSR={100*avg('osr'):.2f}% NE={avg('NE'):.3f} TL={avg('TL'):.3f} nDTW={avg('nDTW'):.4f}")
else:
    print(" unique trajectory_id n=0/835 (no episode records matched in eval logs)")
print()
# === SOURCE 3: vlnverse_emr's eval_result.log (official format) per shard ===
print("=" * 90)
print("SOURCE 3: vlnverse_emr eval_result.log (官方格式)")
print("=" * 90)
result_logs = sorted(glob.glob(f"{EVALS}/*full835*/eval_result.log"))
for rl in result_logs:
    # The shard label is the part of the run directory name that follows the
    # "full835_thr095_" prefix; it falls back to "?" when the prefix is absent.
    run_dir = os.path.basename(os.path.dirname(rl))
    _, sep, name = run_dir.partition("full835_thr095_")
    if not sep:
        name = "?"
    with open(rl, encoding="utf-8", errors="replace") as fh:
        content = fh.read()
    # Summary lines are expected in the form "SR = <sum> / <count> = <pct>%".
    # The leading \b keeps an "OSR = ..." line (if the log has one) from being
    # matched as SR.
    sr_m = re.search(r"\bSR = ([\d.]+) / (\d+) = ([\d.]+)%", content)
    sp_m = re.search(r"\bSPL = ([\d.]+) / (\d+) = ([\d.]+)%", content)
    if sr_m and sp_m:
        print(f" {name:<26} SR={sr_m.group(1):>5}/{sr_m.group(2)} = {sr_m.group(3)}% "
              f"SPL={sp_m.group(3)}%")
    else:
        # Report the failure instead of silently dropping the shard from the cross-check.
        print(f" {name:<26} (no SR/SPL summary line found)")
print()
# === SOURCE 4: stop_head_stats across shards ===
print("=" * 90)
print("SOURCE 4: stop_head_stats(模型 stop head 行为)")
print("=" * 90)
# Per-shard stop-head statistics, limited to shards that recorded a non-zero n_samples.
all_sh = []
for name, agg, sh in shard_aggs:
    if not sh.get("n_samples"):
        continue
    all_sh.append(sh)
    n_s = sh['n_samples']
    p50 = sh['stop_prob_p50']
    p90 = sh['stop_prob_p90']
    fire_a = sh['fire_rate_pathA'] * 100
    fire_b = sh['fire_rate_pathB'] * 100
    no_stop = sh['no_stop_rate'] * 100
    print(f" {name:<26} n_samples={n_s:>5} "
          f"sp_p50={p50:.3f} sp_p90={p90:.3f} "
          f"pA={fire_a:.2f}% pB={fire_b:.2f}% "
          f"no_stop={no_stop:.2f}%")
# Weighted aggregate. NOTE: averaging per-shard percentiles (p50/p90) by
# n_samples only approximates the pooled percentile; it is not exact.
if all_sh:
    tot = sum(s["n_samples"] for s in all_sh)

    def wsh(f):
        """Return the n_samples-weighted mean of stop-head stat `f` over `all_sh`."""
        acc = 0
        for s in all_sh:
            acc += s[f] * s["n_samples"]
        return acc / tot

    print("-" * 90)
    print(f" WEIGHTED (n_samples={tot:>5}) sp_p50={wsh('stop_prob_p50'):.3f} "
          f"sp_p90={wsh('stop_prob_p90'):.3f} pA={wsh('fire_rate_pathA')*100:.2f}% "
          f"pB={wsh('fire_rate_pathB')*100:.2f}% no_stop={wsh('no_stop_rate')*100:.2f}%")
print()
# === FINAL — bucket SR ===
print("=" * 90)
print("FINAL — bucket SR by shortest_path")
print("=" * 90)
# Consecutive bucket edges in metres, giving half-open buckets [lo, hi). The
# trailing open-ended bucket catches episodes with shortest_path >= 30, which
# were previously dropped from the breakdown without notice. Empty buckets are
# not printed, so the output changes only when such episodes exist.
_edges = [0, 5, 8, 12, 18, 30, float("inf")]
for lo, hi in zip(_edges, _edges[1:]):
    bk = [e for e in eps if lo <= e["spg"] < hi]
    if bk:
        sr_b = sum(e["success"] for e in bk) / len(bk)
        ne_b = sum(e["NE"] for e in bk) / len(bk)
        print(f" short [{lo:>2},{hi:>2})m n={len(bk):>4} SR={100*sr_b:>5.1f}% NE={ne_b:.2f}")