InosLihka committed on
Commit
73c7ea0
·
1 Parent(s): 7340206

feat: HF Jobs training script + plot generator

Browse files

Adds:
- scripts/train_on_hf.py: PEP 723 self-contained job script.
Submit with: hf jobs uv run --flavor a10g-large --secrets HF_TOKEN scripts/train_on_hf.py
Clones repo, trains, evals, generates plots, uploads to HF Hub.
- scripts/plot_from_log.py: regenerates all 5 plots from saved log_history.json.
Usable both inside the HF Job AND locally to re-plot without re-training.
- training/train.py: now saves trainer.state.log_history to log_history.json
so plots can be generated offline.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

scripts/plot_from_log.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate training plots from a saved trainer.state.log_history JSON.
3
+
4
+ Used by HF Jobs (where there's no notebook to call trainer.state directly)
5
+ and locally to regenerate plots after training without re-running.
6
+
7
+ Usage:
8
+ python scripts/plot_from_log.py --log outputs/.../log_history.json --out plots/
9
+ """
10
+
11
+ import argparse
12
+ import json
13
+ import os
14
+
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+
18
+
19
def series(log, *keys):
    """Return (steps, values, key) for the first of *keys* found in any log entry.

    Keys are tried in the order given; the first one that matches at least one
    entry wins. Entries without an explicit "step" field fall back to their
    match index so the x-axis still lines up.
    """
    for key in keys:
        matched = [entry for entry in log if key in entry]
        if matched:
            steps = [entry.get("step", i) for i, entry in enumerate(matched)]
            values = [entry[key] for entry in matched]
            return steps, values, key
    # Nothing matched any key: empty series, no key name.
    return [], [], None
30
+
31
+
32
def main():
    """CLI entry point: load a log_history JSON and write every available plot.

    Each plot is skipped independently when its keys are absent from the log,
    so this works on partial logs from interrupted runs.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--log", required=True)
    p.add_argument("--out", default="plots")
    args = p.parse_args()

    with open(args.log) as f:
        log = json.load(f)

    os.makedirs(args.out, exist_ok=True)

    # Discover keys to help debug
    all_keys = set()
    for entry in log:
        all_keys.update(entry.keys())
    print(f"Available log keys: {sorted(all_keys)}")

    _plot_training_loss(log, args.out)
    _plot_mean_reward(log, args.out)
    _plot_reward_components(log, args.out)
    _plot_belief_accuracy(log, args.out)
    _plot_eval_comparison(args.out)


def _plot_training_loss(log, out):
    """Plot 1: raw GRPO training loss per step (skipped when no loss is logged)."""
    steps, losses, _ = series(log, "loss", "train/loss")
    if not losses:
        return
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, losses, color="#2563eb", linewidth=1.5, alpha=0.8)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Loss")
    ax.set_title("GRPO Training Loss - RhythmEnv Meta-RL")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{out}/training_loss.png", dpi=150)
    plt.close()
    print(f"Saved: {out}/training_loss.png ({len(losses)} points)")


def _plot_mean_reward(log, out):
    """Plot 2: mean total reward, with a +/-1 std band when a std series exists."""
    rs, rv, rk = series(log, "reward", "rewards/mean", "rewards/total/mean")
    _, sv, _ = series(log, "reward_std", "rewards/std", "rewards/total/std")
    if not rv:
        return
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(rs, rv, color="#16a34a", linewidth=1.5, label=f"Mean Reward ({rk})")
    # Only draw the band when the std series aligns 1:1 with the mean series.
    if sv and len(sv) == len(rv):
        r, s = np.array(rv), np.array(sv)
        ax.fill_between(rs, r - s, r + s, color="#16a34a", alpha=0.15, label="+/-1 std")
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean Total Reward")
    ax.set_title("GRPO Mean Reward over Training - RhythmEnv Meta-RL")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{out}/reward_curve.png", dpi=150)
    plt.close()
    print(f"Saved: {out}/reward_curve.png ({len(rv)} points)")


def _plot_reward_components(log, out):
    """Plot 3: each of the 4 reward-stack components that appears in the log."""
    components = [
        ("format_valid", ["rewards/format_valid/mean", "rewards/format_valid", "format_valid_reward"]),
        ("action_legal", ["rewards/action_legal/mean", "rewards/action_legal", "action_legal_reward"]),
        ("env_reward", ["rewards/env_reward/mean", "rewards/env_reward", "env_reward_reward"]),
        ("belief_accuracy", ["rewards/belief_accuracy/mean", "rewards/belief_accuracy", "belief_accuracy_reward"]),
    ]
    found = []
    for name, keys in components:
        s, v, k = series(log, *keys)
        if v:
            found.append((name, s, v))
            print(f" {name}: matched key '{k}'")
        else:
            print(f" {name}: NOT FOUND")

    if not found:
        return
    _, ax = plt.subplots(figsize=(12, 6))
    colors = {"format_valid": "#94a3b8", "action_legal": "#60a5fa", "env_reward": "#22c55e", "belief_accuracy": "#a855f7"}
    for name, s, v in found:
        ax.plot(s, v, color=colors.get(name, "#000"), linewidth=1.5, alpha=0.85, label=name)
    ax.axhline(0, color="k", linewidth=0.4)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean Reward Component")
    ax.set_title("4-Layer Reward Stack over Training (RhythmEnv Meta-RL)")
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{out}/reward_components.png", dpi=150)
    plt.close()
    print(f"Saved: {out}/reward_components.png ({len(found)} components)")


def _plot_belief_accuracy(log, out):
    """Plot 4: belief-accuracy reward, plus a rolling mean once >20 points exist."""
    bs, bv, _ = series(log, "rewards/belief_accuracy/mean", "rewards/belief_accuracy", "belief_accuracy_reward")
    if not bv:
        return
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(bs, bv, color="#a855f7", linewidth=2.0, alpha=0.9, label="Belief reward")
    if len(bv) > 20:
        # Window scales with series length; mode="valid" trims the warm-up edge,
        # so the smoothed curve starts at step bs[win - 1].
        win = max(10, len(bv) // 30)
        kernel = np.ones(win) / win
        smooth = np.convolve(bv, kernel, mode="valid")
        ax.plot(bs[win - 1:], smooth, color="#7e22ce", linewidth=2.5, label=f"Rolling mean ({win}-step)")
    ax.axhline(0.0, color="k", linewidth=0.5, linestyle="--", alpha=0.5, label="neutral baseline")
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean belief_accuracy reward (-0.5 to +0.5)")
    ax.set_title("Belief-Accuracy Reward over Training (proof agent learned to model user)")
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{out}/belief_accuracy.png", dpi=150)
    plt.close()
    print(f"Saved: {out}/belief_accuracy.png ({len(bv)} points)")


def _plot_eval_comparison(out):
    """Plot 5: baseline-vs-trained bar charts; only runs when ./eval_results.json exists.

    NOTE(review): the path is cwd-relative — presumably the caller (train_on_hf.py)
    has already chdir'd into the repo root; verify against caller.
    """
    eval_path = "eval_results.json"
    if not os.path.exists(eval_path):
        return
    with open(eval_path) as f:
        results = json.load(f)
    conditions = ["discrete-3-profiles (legacy)", "continuous-in-distribution", "continuous-OOD (generalization)"]

    def avg(cond, strat, key="final_score"):
        # Mean of `key` over episodes matching (condition, strategy); 0.0 when absent.
        rs = [r[key] for r in results if r["condition"] == cond and r["strategy"] == strat]
        return float(np.mean(rs)) if rs else 0.0

    x = np.arange(len(conditions))
    width = 0.27
    _, axes = plt.subplots(1, 2, figsize=(14, 5))
    rand = [avg(c, "random") for c in conditions]
    heur = [avg(c, "heuristic") for c in conditions]
    trnd = [avg(c, "model") for c in conditions]
    axes[0].bar(x - width, rand, width, label="Random", color="#94a3b8")
    axes[0].bar(x, heur, width, label="Heuristic", color="#60a5fa")
    axes[0].bar(x + width, trnd, width, label="Trained Qwen", color="#22c55e")
    axes[0].set_ylabel("Final score (0-1)")
    axes[0].set_title("Final score by condition")
    axes[0].set_xticks(x)
    axes[0].set_xticklabels([c.split(" ")[0] for c in conditions], fontsize=10)
    axes[0].legend()
    axes[0].grid(axis="y", alpha=0.3)

    rand_a = [avg(c, "random", "adaptation") for c in conditions]
    heur_a = [avg(c, "heuristic", "adaptation") for c in conditions]
    trnd_a = [avg(c, "model", "adaptation") for c in conditions]
    axes[1].bar(x - width, rand_a, width, label="Random", color="#94a3b8")
    axes[1].bar(x, heur_a, width, label="Heuristic", color="#60a5fa")
    axes[1].bar(x + width, trnd_a, width, label="Trained Qwen", color="#22c55e")
    axes[1].set_ylabel("Adaptation (late-half - early-half mean reward)")
    axes[1].set_title("Adaptation: did agent get better mid-episode?")
    axes[1].set_xticks(x)
    axes[1].set_xticklabels([c.split(" ")[0] for c in conditions], fontsize=10)
    axes[1].axhline(0, color="k", linewidth=0.5)
    axes[1].legend()
    axes[1].grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{out}/baseline_vs_trained.png", dpi=150)
    plt.close()
    print(f"Saved: {out}/baseline_vs_trained.png")

    # The summary table stays inside this function (i.e. inside the existence
    # guard) so a missing eval_results.json can never raise a NameError on
    # conditions/rand/heur/trnd.
    print()
    print(f"{'Condition':<40} {'Random':>10} {'Heuristic':>10} {'Trained':>10} {'vs Heuristic':>14}")
    print("-" * 90)
    for c, r, h, t in zip(conditions, rand, heur, trnd):
        print(f"{c:<40} {r:>10.3f} {h:>10.3f} {t:>10.3f} {(t - h):>+14.3f}")


if __name__ == "__main__":
    main()
scripts/train_on_hf.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch",
5
+ # "transformers==4.56.2",
6
+ # "trl==0.22.2",
7
+ # "datasets",
8
+ # "peft",
9
+ # "accelerate",
10
+ # "bitsandbytes",
11
+ # "unsloth",
12
+ # "openenv-core",
13
+ # "fastapi",
14
+ # "uvicorn",
15
+ # "pydantic",
16
+ # "matplotlib",
17
+ # "huggingface_hub",
18
+ # ]
19
+ # ///
20
+ """
21
+ End-to-end training job for HF Jobs.
22
+
23
+ Submit from local machine with:
24
+ hf jobs uv run --flavor a10g-large --secrets HF_TOKEN scripts/train_on_hf.py
25
+
26
+ What it does (no babysitting required):
27
+ 1. Clone rhythm_env from HF Space (gets latest meta-RL code from main)
28
+ 2. Generate dataset (continuous profiles, hint_fraction=0.15)
29
+ 3. Train Qwen 2.5-3B + LoRA rank 8 via GRPO (1500 steps)
30
+ 4. Run eval on all 3 conditions (discrete, in-dist, OOD)
31
+ 5. Generate all 5 plots from log_history
32
+ 6. Upload trained model + plots + eval JSON to a new HF Hub model repo
33
+
34
+ Override defaults via env vars:
35
+ MAX_STEPS, NUM_EPISODES, LORA_RANK, BETA, MODEL_REPO
36
+
37
+ Estimated cost on a10g-large at $1.50/hr: ~$3 for 1500 steps (~2h).
38
+ """
39
+
40
+ import json
41
+ import os
42
+ import shutil
43
+ import subprocess
44
+ import sys
45
+ from pathlib import Path
46
+
47
# ---------------------------------------------------------------------------
# Config (overridable via env vars)
# ---------------------------------------------------------------------------
# Where to get the code and where to put artifacts on the job machine.
REPO_URL = os.environ.get("REPO_URL", "https://huggingface.co/spaces/InosLihka/rhythm_env")
WORK_DIR = "/tmp/rhythm_env"
OUTPUT_DIR = "/tmp/rhythm_env/outputs/rhythmenv_meta_trained"
PLOTS_DIR = "/tmp/rhythm_env/plots"

# Training hyperparameters; each env var overrides the documented default.
MAX_STEPS = int(os.environ.get("MAX_STEPS", "1500"))
NUM_EPISODES = int(os.environ.get("NUM_EPISODES", "300"))
LORA_RANK = int(os.environ.get("LORA_RANK", "8"))
BETA = float(os.environ.get("BETA", "0.1"))

# Destination Hub repo for the trained model + eval JSON + plots.
MODEL_REPO = os.environ.get("MODEL_REPO", "InosLihka/rhythm-env-meta-trained")
61
+
62
+
63
def run(cmd: list[str], **kw):
    """Run a subprocess, echoing the command first; raises on non-zero exit."""
    shown = " ".join(cmd) if isinstance(cmd, list) else cmd
    print(f"\n>>> {shown}", flush=True)
    subprocess.run(cmd, check=True, **kw)
67
+
68
+
69
def main():
    """End-to-end HF Jobs pipeline: clone -> train -> eval -> plot -> upload."""
    _clone_and_verify()
    _train()
    _evaluate()
    _generate_plots()
    _upload_artifacts()


def _clone_and_verify():
    """Step 1: clone rhythm_env, chdir into it, and confirm meta-RL code exists."""
    if Path(WORK_DIR).exists():
        shutil.rmtree(WORK_DIR)
    run(["git", "clone", REPO_URL, WORK_DIR])
    os.chdir(WORK_DIR)
    sys.path.insert(0, WORK_DIR)
    sys.path.insert(0, os.path.join(WORK_DIR, "training"))

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # and this guard must never be skipped on a paid GPU job.
    dataset_py = Path("training/dataset.py").read_text()
    if "profile_mode" not in dataset_py:
        raise RuntimeError("Cloned repo doesn't have meta-RL code")
    print("OK: meta-RL code present in cloned repo")


def _train():
    """Step 2: GRPO training via training/train.py (env vars override defaults)."""
    run([
        "python", "training/train.py",
        "--max_steps", str(MAX_STEPS),
        "--num_episodes", str(NUM_EPISODES),
        "--lora_rank", str(LORA_RANK),
        "--beta", str(BETA),
        "--output_dir", OUTPUT_DIR,
    ])


def _evaluate():
    """Step 3: eval on all 3 conditions (discrete-3 / in-dist / OOD)."""
    run([
        "python", "training/inference_eval.py",
        "--model_path", OUTPUT_DIR,
        "--num_episodes", "5",
        "--output_file", "eval_results.json",
    ])


def _generate_plots():
    """Step 4: regenerate plots from the saved log_history.json, if present."""
    # parents=True so a missing intermediate directory can't fail the job.
    Path(PLOTS_DIR).mkdir(parents=True, exist_ok=True)
    log_path = os.path.join(OUTPUT_DIR, "log_history.json")
    if Path(log_path).exists():
        run(["python", "scripts/plot_from_log.py", "--log", log_path, "--out", PLOTS_DIR])
    else:
        print(f"WARNING: log_history.json not found at {log_path}")


def _upload_artifacts():
    """Step 5: upload model, eval JSON, and plots to the Hub.

    Skipped entirely (with a warning) when HF_TOKEN is unset, so local dry
    runs still finish cleanly.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("WARNING: HF_TOKEN not set, skipping upload")
        print(f"Outputs in: {OUTPUT_DIR}")
        return

    from huggingface_hub import HfApi, login
    login(token=token)
    api = HfApi()
    api.create_repo(MODEL_REPO, exist_ok=True, repo_type="model")

    # Upload trained model + config + log_history
    api.upload_folder(
        folder_path=OUTPUT_DIR,
        repo_id=MODEL_REPO,
        repo_type="model",
        commit_message=f"Trained {MAX_STEPS}-step GRPO meta-RL agent",
    )

    # Upload eval JSON. Guarded so a missing file degrades to a warning
    # instead of losing the already-uploaded model to an exception.
    if Path("eval_results.json").exists():
        api.upload_file(
            path_or_fileobj="eval_results.json",
            path_in_repo="eval_results.json",
            repo_id=MODEL_REPO,
            repo_type="model",
        )
    else:
        print("WARNING: eval_results.json not found, skipping its upload")

    # Upload plots if generated
    if Path(PLOTS_DIR).exists() and any(Path(PLOTS_DIR).iterdir()):
        api.upload_folder(
            folder_path=PLOTS_DIR,
            path_in_repo="plots",
            repo_id=MODEL_REPO,
            repo_type="model",
        )

    print()
    print("=" * 60)
    print("DONE")
    print(f" Trained model: https://huggingface.co/{MODEL_REPO}")
    print(f" Eval JSON: https://huggingface.co/{MODEL_REPO}/blob/main/eval_results.json")
    print(f" Plots: https://huggingface.co/{MODEL_REPO}/tree/main/plots")
    print("=" * 60)


if __name__ == "__main__":
    main()
training/train.py CHANGED
@@ -204,8 +204,14 @@ def main():
204
  with open(config_path, "w") as f:
205
  json.dump(vars(args), f, indent=2)
206
 
 
 
 
 
 
207
  print(f"Model saved to: {args.output_dir}")
208
  print(f"Training config saved to: {config_path}")
 
209
  print("\nNext: run inference_eval.py to compare baseline vs trained")
210
  print(" python training/inference_eval.py --model_path " + args.output_dir)
211
 
 
204
  with open(config_path, "w") as f:
205
  json.dump(vars(args), f, indent=2)
206
 
207
+ # Save log_history for offline plotting (job runs don't have a notebook to inspect trainer.state)
208
+ log_path = os.path.join(args.output_dir, "log_history.json")
209
+ with open(log_path, "w") as f:
210
+ json.dump(trainer.state.log_history, f, indent=2)
211
+
212
  print(f"Model saved to: {args.output_dir}")
213
  print(f"Training config saved to: {config_path}")
214
+ print(f"Log history saved to: {log_path}")
215
  print("\nNext: run inference_eval.py to compare baseline vs trained")
216
  print(" python training/inference_eval.py --model_path " + args.output_dir)
217