"""
Generic iteration analyzer: downloads a trained model repo's result files
from the Hugging Face Hub and prints a summary of them.

Usage:
    python scripts/analyze_iter.py iter3
    python scripts/analyze_iter.py iter3 --model-suffix iter3

Pass --no-download to analyze files already present in ./<iter_name>_results.
"""
import argparse
import json
import os
import statistics
import subprocess
import sys
from collections import Counter
from pathlib import Path
# Result files published in each trained-model repo.
_RESULT_FILES = ("eval_results.json", "log_history.json", "training_config.json")

# Hyperparameters echoed in the "Training config" section.
_CONFIG_KEYS = ("max_steps", "num_episodes", "max_samples", "num_generations",
                "lora_rank", "beta", "learning_rate", "hint_fraction")

# (row label, candidate log keys...). The first key found in the log history
# is used, so one row can cover a metric that was logged under different names.
_TRAJECTORY_METRICS = (
    ("loss", "loss", "train/loss"),
    ("reward (mean)", "reward", "rewards/mean"),
    ("reward_std", "reward_std", "rewards/std"),
    ("frac_zero_std", "frac_reward_zero_std"),
    ("format_valid mean", "rewards/format_valid/mean"),
    ("action_legal mean", "rewards/action_legal/mean"),
    ("env_reward mean", "rewards/env_reward/mean"),
    ("belief_accuracy mean", "rewards/belief_accuracy/mean"),
    ("kl", "kl"),
    ("completion_length", "completions/mean_length", "completion_length"),
)

# Mean score the trained model must beat per eval condition; each value is
# the heuristic baseline's score on that condition.
_SCORE_TARGETS = {
    "continuous-in-distribution": 0.587,  # heuristic in-dist
    "continuous-OOD (generalization)": 0.580,  # heuristic OOD
}

_MIN_UNIQUE_ACTIONS = 5  # action-diversity pass threshold (inclusive)
_MAX_BELIEF_MAE = 0.30  # in-distribution belief MAE pass threshold (exclusive)

# Trajectory table layout: label column, then one fixed-width cell per snapshot.
_LABEL_WIDTH = 32
_CELL_WIDTH = 10


def _load_json(path):
    """Parse and return the JSON document stored at *path* (read as UTF-8)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _download(repo, local_dir):
    """Fetch the result files of model *repo* into *local_dir* via the ``hf`` CLI.

    Best effort: failures are reported and skipped, so files already present
    in *local_dir* can still be analyzed.
    """
    for fname in _RESULT_FILES:
        # Argument list (no shell), so paths with spaces or shell
        # metacharacters in CLI args cannot break or alter the command.
        cmd = ["hf", "download", repo, fname, "--repo-type=model",
               "--local-dir", str(local_dir)]
        print(f">>> {' '.join(cmd)}", flush=True)
        try:
            result = subprocess.run(cmd, check=False)
        except FileNotFoundError:
            print("WARNING: 'hf' CLI not found; skipping downloads")
            return
        if result.returncode != 0:
            print(f"WARNING: download of {fname} failed (exit {result.returncode})")


def _print_training_config(local_dir):
    """Print selected keys of training_config.json, if that file exists."""
    cfg_path = local_dir / "training_config.json"
    if not cfg_path.exists():
        return
    cfg = _load_json(cfg_path)
    print()
    print("=== Training config ===")
    for key in _CONFIG_KEYS:
        print(f" {key}: {cfg.get(key)}")


def _print_eval_summary(data):
    """Print per-(condition, strategy) averages of the eval records.

    Returns:
        dict mapping (condition, strategy) to
        (mean final_score, mean adaptation, list of non-None belief_mae values).
    """
    # Group in one pass; sorting the (condition, strategy) keys reproduces the
    # condition-major, strategy-minor row order.
    groups = {}
    for rec in data:
        groups.setdefault((rec["condition"], rec["strategy"]), []).append(rec)

    print()
    print("=" * 80)
    print("EVAL SUMMARY")
    print("=" * 80)
    print(f"{'Condition':<40} {'Strategy':<10} {'score':>7} {'adapt':>7} {'belief_mae':>11}")
    print("-" * 80)
    avg_table = {}
    for (cond, strat), rs in sorted(groups.items()):
        score = statistics.mean(r["final_score"] for r in rs)
        adapt = statistics.mean(r["adaptation"] for r in rs)
        belief = [r["belief_mae"] for r in rs if r.get("belief_mae") is not None]
        belief_str = f"{statistics.mean(belief):.3f}" if belief else "-"
        print(f"{cond:<40} {strat:<10} {score:>7.3f} {adapt:>+7.3f} {belief_str:>11}")
        avg_table[(cond, strat)] = (score, adapt, belief)
    return avg_table


def _print_model_behavior(data):
    """Print the trained model's action distribution and final-belief diversity.

    Returns:
        Flat list of every action taken across all "model" episodes.
    """
    print()
    model_runs = [r for r in data if r["strategy"] == "model"]
    all_actions = [a for r in model_runs for a in r["actions"]]
    if all_actions:
        total = len(all_actions)
        counts = Counter(all_actions)
        print(f"=== Trained-model action distribution ({total} actions over {len(model_runs)} episodes) ===")
        for action, cnt in counts.most_common():
            print(f" {action:15s} {cnt:4d} ({100*cnt/total:5.1f}%)")
        print(f"Unique actions used: {len(counts)} of 10")
    # Tuples so identical belief vectors collapse in the set below.
    beliefs = [tuple(r["final_belief"]) for r in model_runs if r.get("final_belief")]
    if beliefs:
        print(f"Unique final beliefs: {len(set(beliefs))} of {len(beliefs)}")
    return all_actions


def _series(log, keys):
    """Return (steps, values) for the first key in *keys* present in *log*.

    Entries without a "step" field are numbered by their position in the
    series. Returns ([], []) when none of the keys appear.
    """
    for key in keys:
        steps, vals = [], []
        for entry in log:
            if key in entry:
                steps.append(entry.get("step", len(steps)))
                vals.append(entry[key])
        if vals:
            return steps, vals
    return [], []


def _format_cell(value):
    """Render a metric value for the trajectory table."""
    if isinstance(value, (int, float)):
        return f"{value:+.3f}"
    return str(value)[:8]


def _print_trajectory(local_dir):
    """Print key training metrics at ~0%, 25%, 50%, 75%, 100% of training."""
    log_path = local_dir / "log_history.json"
    if not log_path.exists():
        return
    log = _load_json(log_path)
    print()
    print("=== Training trajectory ===")
    max_step = max((e.get("step", 0) for e in log), default=0)
    snapshots = sorted({1, max_step // 4, max_step // 2, 3 * max_step // 4, max_step})
    # Header cells are padded to _CELL_WIDTH so they line up with the data
    # cells and the dashed rule below (previously they were one char short).
    header_cells = " ".join(f"step{s:>{_CELL_WIDTH - 4}}" for s in snapshots)
    print(f"\n{'metric':<{_LABEL_WIDTH}} " + header_cells)
    print("-" * (_LABEL_WIDTH + len(snapshots) * (_CELL_WIDTH + 1)))
    for label, *keys in _TRAJECTORY_METRICS:
        steps, vals = _series(log, keys)
        if not vals:
            continue
        points = list(zip(steps, vals))
        row = []
        for target in snapshots:
            # Logged point whose step is closest to the snapshot (first on ties).
            _, value = min(points, key=lambda p, t=target: abs(p[0] - t))
            row.append(_format_cell(value))
        print(f"{label:<{_LABEL_WIDTH}} " + " ".join(f"{c:>{_CELL_WIDTH}}" for c in row))


def _print_verdict(avg_table, all_actions):
    """Print PASS/FAIL lines for score, action-diversity and belief-MAE criteria."""
    print()
    print("=" * 80)
    print("VERDICT")
    print("=" * 80)
    for cond, target in _SCORE_TARGETS.items():
        entry = avg_table.get((cond, "model"))
        if entry is None:
            continue
        score = entry[0]
        symbol = "PASS" if score > target else "FAIL"
        print(f" [{symbol}] {cond} score: {score:.3f} (need > {target} = heuristic)")
    if all_actions:
        n_unique = len(set(all_actions))
        symbol = "PASS" if n_unique >= _MIN_UNIQUE_ACTIONS else "FAIL"
        print(f" [{symbol}] Action diversity: {n_unique} unique actions (need >= {_MIN_UNIQUE_ACTIONS})")
    in_dist = avg_table.get(("continuous-in-distribution", "model"))
    if in_dist and in_dist[2]:
        avg_mae = statistics.mean(in_dist[2])
        symbol = "PASS" if avg_mae < _MAX_BELIEF_MAE else "FAIL"
        print(f" [{symbol}] Belief MAE in-dist: {avg_mae:.3f} (need < {_MAX_BELIEF_MAE:.2f}, lower is better)")


def main(argv=None):
    """Download (optionally) and summarize one training iteration's results.

    Args:
        argv: command-line arguments; ``None`` means ``sys.argv[1:]``.

    Raises:
        SystemExit: with code 1 when ``eval_results.json`` is missing.
    """
    p = argparse.ArgumentParser()
    p.add_argument("iter_name", help="Iteration name, e.g. 'iter3'")
    p.add_argument("--model-suffix", default=None,
                   help="Model repo suffix (defaults to iter_name)")
    p.add_argument("--owner", default="InosLihka")
    p.add_argument("--no-download", action="store_true",
                   help="Skip download, use local files only")
    args = p.parse_args(argv)

    suffix = args.model_suffix or args.iter_name
    repo = f"{args.owner}/rhythm-env-meta-trained-{suffix}"
    local_dir = Path(f"./{args.iter_name}_results")
    local_dir.mkdir(exist_ok=True)
    if not args.no_download:
        _download(repo, local_dir)

    _print_training_config(local_dir)

    eval_path = local_dir / "eval_results.json"
    if not eval_path.exists():
        print(f"FATAL: {eval_path} not found")
        sys.exit(1)
    data = _load_json(eval_path)

    avg_table = _print_eval_summary(data)
    all_actions = _print_model_behavior(data)
    _print_trajectory(local_dir)
    _print_verdict(avg_table, all_actions)
if __name__ == "__main__":
    # Script entry point only; importing this module runs nothing.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
|