# Source: rhythm_env/scripts/analyze_iter.py (uploaded by InosLihka)
# Commit d6d9e31: "tooling: scripts/analyze_iter.py + docs/results.md template"
"""
Generic iteration analyzer — pulls a trained model repo from HF Hub
and summarizes results.
Usage:
python scripts/analyze_iter.py iter3
python scripts/analyze_iter.py iter3 --model-suffix iter3
"""
import argparse
import json
import os
import statistics
import subprocess
import sys
from collections import Counter
from pathlib import Path
# Files fetched from the model repo and analyzed below.
_ARTIFACT_FILES = ("eval_results.json", "log_history.json", "training_config.json")

# Keys of training_config.json that are echoed in the report.
_CONFIG_KEYS = (
    "max_steps", "num_episodes", "max_samples", "num_generations",
    "lora_rank", "beta", "learning_rate", "hint_fraction",
)

# (row label, candidate log keys). The first candidate key that occurs in the
# log history is the one that gets reported.
_TRAJECTORY_METRICS = (
    ("loss", ("loss", "train/loss")),
    ("reward (mean)", ("reward", "rewards/mean")),
    ("reward_std", ("reward_std", "rewards/std")),
    ("frac_zero_std", ("frac_reward_zero_std",)),
    ("format_valid mean", ("rewards/format_valid/mean",)),
    ("action_legal mean", ("rewards/action_legal/mean",)),
    ("env_reward mean", ("rewards/env_reward/mean",)),
    ("belief_accuracy mean", ("rewards/belief_accuracy/mean",)),
    ("kl", ("kl",)),
    ("completion_length", ("completions/mean_length", "completion_length")),
)

# (condition, strategy) -> score the model must exceed (heuristic baseline).
_SCORE_TARGETS = {
    ("continuous-in-distribution", "model"): 0.587,       # heuristic in-dist
    ("continuous-OOD (generalization)", "model"): 0.580,  # heuristic OOD
}

_NUM_ACTIONS = 10            # presumably the size of the action space -- TODO confirm
_MIN_UNIQUE_ACTIONS = 5      # action-diversity pass threshold
_MAX_BELIEF_MAE = 0.30       # belief-MAE pass threshold (lower is better)
_RULE_WIDTH = 80


def _parse_args(argv=None):
    """Parse the command line (``argv=None`` means ``sys.argv[1:]``)."""
    p = argparse.ArgumentParser()
    p.add_argument("iter_name", help="Iteration name, e.g. 'iter3'")
    p.add_argument("--model-suffix", default=None,
                   help="Model repo suffix (defaults to iter_name)")
    p.add_argument("--owner", default="InosLihka")
    p.add_argument("--no-download", action="store_true",
                   help="Skip download, use local files only")
    return p.parse_args(argv)


def _load_json(path):
    """Return the parsed content of the JSON file at *path*."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _download_artifacts(repo, local_dir):
    """Fetch each artifact file of *repo* into *local_dir* with the ``hf`` CLI.

    Downloads are best effort: a failure is reported and the run continues,
    because the later stages already cope with missing files.
    """
    for fname in _ARTIFACT_FILES:
        # An argument list (no shell) keeps user-supplied names from being
        # interpreted as shell syntax.
        cmd = ["hf", "download", repo, fname,
               "--repo-type=model", "--local-dir", str(local_dir)]
        print(f">>> {' '.join(cmd)}")
        try:
            result = subprocess.run(cmd, check=False)
        except FileNotFoundError:
            print("WARNING: 'hf' CLI not found on PATH; skipping downloads")
            return
        if result.returncode != 0:
            print(f"WARNING: download of {fname} failed "
                  f"(exit code {result.returncode})")


def _print_training_config(cfg):
    """Print the selected hyper-parameters found in *cfg* (missing -> None)."""
    print()
    print("=== Training config ===")
    for k in _CONFIG_KEYS:
        print(f" {k}: {cfg.get(k)}")


def _print_eval_summary(data):
    """Print per-(condition, strategy) means and return them.

    Returns a dict mapping ``(condition, strategy)`` to
    ``(mean final_score, mean adaptation, list of non-null belief_mae)``.
    """
    print()
    print("=" * _RULE_WIDTH)
    print("EVAL SUMMARY")
    print("=" * _RULE_WIDTH)
    print(f"{'Condition':<40} {'Strategy':<10} {'score':>7} {'adapt':>7} {'belief_mae':>11}")
    print("-" * _RULE_WIDTH)

    # Single pass grouping; sorting the tuple keys yields the same order as
    # iterating sorted conditions and, within each, sorted strategies.
    groups = {}
    for r in data:
        groups.setdefault((r["condition"], r["strategy"]), []).append(r)

    avg_table = {}
    for c, s in sorted(groups):
        rs = groups[(c, s)]
        score = statistics.mean(r["final_score"] for r in rs)
        adapt = statistics.mean(r["adaptation"] for r in rs)
        belief = [r["belief_mae"] for r in rs if r.get("belief_mae") is not None]
        belief_str = f"{statistics.mean(belief):.3f}" if belief else "-"
        print(f"{c:<40} {s:<10} {score:>7.3f} {adapt:>+7.3f} {belief_str:>11}")
        avg_table[(c, s)] = (score, adapt, belief)
    return avg_table


def _print_model_behaviour(data):
    """Print action distribution and belief diversity of the "model" runs.

    Returns the flat list of all actions taken by the "model" strategy.
    """
    print()
    model_runs = [r for r in data if r["strategy"] == "model"]
    # Tolerate records without an "actions" entry instead of crashing.
    all_actions = [a for r in model_runs for a in (r.get("actions") or [])]
    if all_actions:
        print(f"=== Trained-model action distribution ({len(all_actions)} actions over {len(model_runs)} episodes) ===")
        counts = Counter(all_actions)
        total = len(all_actions)
        for action, cnt in counts.most_common():
            print(f" {action:15s} {cnt:4d} ({100*cnt/total:5.1f}%)")
        print(f"Unique actions used: {len(counts)} of {_NUM_ACTIONS}")

    # ---- belief diversity ----
    beliefs = [tuple(r["final_belief"]) for r in model_runs if r.get("final_belief")]
    if beliefs:
        print(f"Unique final beliefs: {len(set(beliefs))} of {len(beliefs)}")
    return all_actions


def _metric_series(log, keys):
    """Return ``(steps, values)`` for the first of *keys* present in *log*.

    Entries lacking a "step" field get their position within the series as
    step. Returns two empty lists when none of the keys occurs.
    """
    for k in keys:
        steps, vals = [], []
        for e in log:
            if k in e:
                steps.append(e.get("step", len(steps)))
                vals.append(e[k])
        if vals:
            return steps, vals
    return [], []


def _print_training_trajectory(log):
    """Print each tracked metric at ~0/25/50/75/100 % of the training steps."""
    print()
    print("=== Training trajectory ===")
    max_step = max((e.get("step", 0) for e in log), default=0)
    snapshots = sorted({1, max_step // 4, max_step // 2, 3 * max_step // 4, max_step})
    print(f"\n{'metric':<32} " + " ".join(f"step{s:>5}" for s in snapshots))
    print("-" * (32 + len(snapshots) * 11))
    for label, keys in _TRAJECTORY_METRICS:
        steps, vals = _metric_series(log, keys)
        if not vals:
            continue
        row = []
        for tgt in snapshots:
            # Value logged at the step closest to the snapshot (first on ties).
            _, v = min(zip(steps, vals), key=lambda sv: abs(sv[0] - tgt))
            row.append(f"{v:+.3f}" if isinstance(v, (int, float)) else str(v)[:8])
        print(f"{label:<32} " + " ".join(f"{r:>10}" for r in row))


def _print_verdict(avg_table, all_actions):
    """Print PASS/FAIL lines for score targets, action diversity, belief MAE.

    A criterion whose inputs are absent from the eval data is not printed.
    """
    print()
    print("=" * _RULE_WIDTH)
    print("VERDICT")
    print("=" * _RULE_WIDTH)
    for (cond, strat), target in _SCORE_TARGETS.items():
        if (cond, strat) in avg_table:
            score = avg_table[(cond, strat)][0]
            symbol = "PASS" if score > target else "FAIL"
            print(f" [{symbol}] {cond} score: {score:.3f} (need > {target} = heuristic)")

    if all_actions:
        n_unique = len(set(all_actions))
        symbol = "PASS" if n_unique >= _MIN_UNIQUE_ACTIONS else "FAIL"
        print(f" [{symbol}] Action diversity: {n_unique} unique actions (need >= {_MIN_UNIQUE_ACTIONS})")

    # belief MAE check
    in_dist = avg_table.get(("continuous-in-distribution", "model"))
    if in_dist and in_dist[2]:
        avg_mae = statistics.mean(in_dist[2])
        symbol = "PASS" if avg_mae < _MAX_BELIEF_MAE else "FAIL"
        print(f" [{symbol}] Belief MAE in-dist: {avg_mae:.3f} (need < 0.30, lower is better)")


def main(argv=None):
    """Download (unless ``--no-download``) and summarize one training iteration.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: with status 1 when ``eval_results.json`` is not present
            in the results directory; also raised by argparse on bad usage.
    """
    args = _parse_args(argv)
    suffix = args.model_suffix or args.iter_name
    repo = f"{args.owner}/rhythm-env-meta-trained-{suffix}"
    local_dir = Path(f"./{args.iter_name}_results")
    local_dir.mkdir(parents=True, exist_ok=True)

    if not args.no_download:
        _download_artifacts(repo, local_dir)

    # ---- training config (optional) ----
    cfg_path = local_dir / "training_config.json"
    if cfg_path.exists():
        _print_training_config(_load_json(cfg_path))

    # ---- eval results (required) ----
    eval_path = local_dir / "eval_results.json"
    if not eval_path.exists():
        print(f"FATAL: {eval_path} not found")
        sys.exit(1)
    data = _load_json(eval_path)
    avg_table = _print_eval_summary(data)

    # ---- model action distribution / belief diversity ----
    all_actions = _print_model_behaviour(data)

    # ---- training trajectory (optional) ----
    log_path = local_dir / "log_history.json"
    if log_path.exists():
        _print_training_trajectory(_load_json(log_path))

    # ---- success criteria ----
    _print_verdict(avg_table, all_actions)


if __name__ == "__main__":
    main()