""" Generic iteration analyzer — pulls a trained model repo from HF Hub and summarizes results. Usage: python scripts/analyze_iter.py iter3 python scripts/analyze_iter.py iter3 --model-suffix iter3 """ import argparse import json import os import statistics import sys from collections import Counter from pathlib import Path def main(): p = argparse.ArgumentParser() p.add_argument("iter_name", help="Iteration name, e.g. 'iter3'") p.add_argument("--model-suffix", default=None, help="Model repo suffix (defaults to iter_name)") p.add_argument("--owner", default="InosLihka") p.add_argument("--no-download", action="store_true", help="Skip download, use local files only") args = p.parse_args() suffix = args.model_suffix or args.iter_name repo = f"{args.owner}/rhythm-env-meta-trained-{suffix}" local_dir = Path(f"./{args.iter_name}_results") local_dir.mkdir(exist_ok=True) if not args.no_download: for fname in ["eval_results.json", "log_history.json", "training_config.json"]: cmd = f"hf download {repo} {fname} --repo-type=model --local-dir {local_dir}" print(f">>> {cmd}") os.system(cmd) # ---- training config ---- cfg_path = local_dir / "training_config.json" if cfg_path.exists(): with open(cfg_path) as f: cfg = json.load(f) print() print("=== Training config ===") for k in ("max_steps", "num_episodes", "max_samples", "num_generations", "lora_rank", "beta", "learning_rate", "hint_fraction"): print(f" {k}: {cfg.get(k)}") # ---- eval results ---- eval_path = local_dir / "eval_results.json" if not eval_path.exists(): print(f"FATAL: {eval_path} not found") sys.exit(1) with open(eval_path) as f: data = json.load(f) conditions = sorted(set(r["condition"] for r in data)) strategies = sorted(set(r["strategy"] for r in data)) print() print("=" * 80) print("EVAL SUMMARY") print("=" * 80) print(f"{'Condition':<40} {'Strategy':<10} {'score':>7} {'adapt':>7} {'belief_mae':>11}") print("-" * 80) avg_table = {} for c in conditions: for s in strategies: rs = [r for r in data if r["condition"] == c and r["strategy"] == s] if not rs: continue score = statistics.mean(r["final_score"] for r in rs) adapt = statistics.mean(r["adaptation"] for r in rs) belief = [r["belief_mae"] for r in rs if r.get("belief_mae") is not None] belief_str = f"{statistics.mean(belief):.3f}" if belief else "-" print(f"{c:<40} {s:<10} {score:>7.3f} {adapt:>+7.3f} {belief_str:>11}") avg_table[(c, s)] = (score, adapt, belief) # ---- model action distribution ---- print() model_runs = [r for r in data if r["strategy"] == "model"] all_actions = [] for r in model_runs: all_actions.extend(r["actions"]) if all_actions: print(f"=== Trained-model action distribution ({len(all_actions)} actions over {len(model_runs)} episodes) ===") counts = Counter(all_actions) total = len(all_actions) for action, cnt in counts.most_common(): print(f" {action:15s} {cnt:4d} ({100*cnt/total:5.1f}%)") print(f"Unique actions used: {len(counts)} of 10") # ---- belief diversity ---- beliefs = [tuple(r["final_belief"]) for r in model_runs if r.get("final_belief")] if beliefs: unique_beliefs = len(set(beliefs)) print(f"Unique final beliefs: {unique_beliefs} of {len(beliefs)}") # ---- training trajectory ---- log_path = local_dir / "log_history.json" if log_path.exists(): with open(log_path) as f: log = json.load(f) print() print("=== Training trajectory ===") def series(*keys): for k in keys: steps, vals = [], [] for e in log: if k in e: steps.append(e.get("step", len(steps))) vals.append(e[k]) if vals: return steps, vals return [], [] max_step = max((e.get("step", 0) for e in log), default=0) snapshots = sorted(set([1, max_step // 4, max_step // 2, 3 * max_step // 4, max_step])) print(f"\n{'metric':<32} " + " ".join(f"step{s:>5}" for s in snapshots)) print("-" * (32 + len(snapshots) * 11)) for label, *keys in [ ("loss", "loss", "train/loss"), ("reward (mean)", "reward", "rewards/mean"), ("reward_std", "reward_std", "rewards/std"), ("frac_zero_std", "frac_reward_zero_std"), ("format_valid mean", "rewards/format_valid/mean"), ("action_legal mean", "rewards/action_legal/mean"), ("env_reward mean", "rewards/env_reward/mean"), ("belief_accuracy mean", "rewards/belief_accuracy/mean"), ("kl", "kl"), ("completion_length", "completions/mean_length", "completion_length"), ]: steps, vals = series(*keys) if not vals: continue row = [] for tgt in snapshots: if not steps: row.append("-") continue ix = min(range(len(steps)), key=lambda i: abs(steps[i] - tgt)) v = vals[ix] row.append(f"{v:+.3f}" if isinstance(v, (int, float)) else str(v)[:8]) print(f"{label:<32} " + " ".join(f"{r:>10}" for r in row)) # ---- success criteria ---- print() print("=" * 80) print("VERDICT") print("=" * 80) targets = { ("continuous-in-distribution", "model", "score"): 0.587, # heuristic in-dist ("continuous-OOD (generalization)", "model", "score"): 0.580, # heuristic OOD } for (cond, strat, metric), target in targets.items(): if (cond, strat) in avg_table: score, _, _ = avg_table[(cond, strat)] symbol = "PASS" if score > target else "FAIL" print(f" [{symbol}] {cond} score: {score:.3f} (need > {target} = heuristic)") if all_actions: n_unique = len(Counter(all_actions)) symbol = "PASS" if n_unique >= 5 else "FAIL" print(f" [{symbol}] Action diversity: {n_unique} unique actions (need >= 5)") # belief MAE check in_dist_belief = avg_table.get(("continuous-in-distribution", "model")) if in_dist_belief and in_dist_belief[2]: avg_mae = statistics.mean(in_dist_belief[2]) symbol = "PASS" if avg_mae < 0.30 else "FAIL" print(f" [{symbol}] Belief MAE in-dist: {avg_mae:.3f} (need < 0.30, lower is better)") if __name__ == "__main__": main()