InosLihka Claude Opus 4.7 (1M context) committed on
Commit
1a865f8
·
1 Parent(s): 73c7ea0

feat: FAST_MODE preset for 10-15 min iteration cycles

Browse files

Adds knobs for fast hyperparameter sweeps:
- FAST_MODE=1 -> 200 steps, 80 episodes, 800 samples, 2 generations, 2 eval episodes
- All knobs individually overridable via env vars
- MODEL_REPO_SUFFIX lets each run upload to a unique repo for comparison

On A100-large: FAST_MODE finishes in ~10-12 min for ~$0.70 per iteration.
Use to debug training stability and tune beta/lr before committing to a
full 1500-step run.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. scripts/train_on_hf.py +38 -7
scripts/train_on_hf.py CHANGED
@@ -52,12 +52,40 @@ WORK_DIR = "/tmp/rhythm_env"
52
  OUTPUT_DIR = "/tmp/rhythm_env/outputs/rhythmenv_meta_trained"
53
  PLOTS_DIR = "/tmp/rhythm_env/plots"
54
 
55
- MAX_STEPS = int(os.environ.get("MAX_STEPS", "1500"))
56
- NUM_EPISODES = int(os.environ.get("NUM_EPISODES", "300"))
57
- LORA_RANK = int(os.environ.get("LORA_RANK", "8"))
58
- BETA = float(os.environ.get("BETA", "0.1"))
59
-
60
- MODEL_REPO = os.environ.get("MODEL_REPO", "InosLihka/rhythm-env-meta-trained")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
 
63
  def run(cmd: list[str], **kw):
@@ -89,8 +117,11 @@ def main():
89
  "python", "training/train.py",
90
  "--max_steps", str(MAX_STEPS),
91
  "--num_episodes", str(NUM_EPISODES),
 
 
92
  "--lora_rank", str(LORA_RANK),
93
  "--beta", str(BETA),
 
94
  "--output_dir", OUTPUT_DIR,
95
  ]
96
  run(train_args)
@@ -101,7 +132,7 @@ def main():
101
  eval_args = [
102
  "python", "training/inference_eval.py",
103
  "--model_path", OUTPUT_DIR,
104
- "--num_episodes", "5",
105
  "--output_file", "eval_results.json",
106
  ]
107
  run(eval_args)
 
52
  OUTPUT_DIR = "/tmp/rhythm_env/outputs/rhythmenv_meta_trained"
53
  PLOTS_DIR = "/tmp/rhythm_env/plots"
54
 
55
+ # FAST_MODE preset: ~10-15 min iteration on A100 large.
56
+ # Use for hyperparameter sweeps and pipeline debugging.
57
+ FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"
58
+
59
+ if FAST_MODE:
60
+ DEFAULTS = dict(MAX_STEPS=200, NUM_EPISODES=80, MAX_SAMPLES=800,
61
+ NUM_GENERATIONS=2, LORA_RANK=8, BETA=0.1,
62
+ LEARNING_RATE=5e-5, EVAL_EPISODES=2)
63
+ else:
64
+ DEFAULTS = dict(MAX_STEPS=1500, NUM_EPISODES=300, MAX_SAMPLES=3000,
65
+ NUM_GENERATIONS=4, LORA_RANK=8, BETA=0.1,
66
+ LEARNING_RATE=5e-5, EVAL_EPISODES=5)
67
+
68
+ MAX_STEPS = int(os.environ.get("MAX_STEPS", str(DEFAULTS["MAX_STEPS"])))
69
+ NUM_EPISODES = int(os.environ.get("NUM_EPISODES", str(DEFAULTS["NUM_EPISODES"])))
70
+ MAX_SAMPLES = int(os.environ.get("MAX_SAMPLES", str(DEFAULTS["MAX_SAMPLES"])))
71
+ NUM_GENERATIONS = int(os.environ.get("NUM_GENERATIONS", str(DEFAULTS["NUM_GENERATIONS"])))
72
+ LORA_RANK = int(os.environ.get("LORA_RANK", str(DEFAULTS["LORA_RANK"])))
73
+ BETA = float(os.environ.get("BETA", str(DEFAULTS["BETA"])))
74
+ LEARNING_RATE = float(os.environ.get("LEARNING_RATE", str(DEFAULTS["LEARNING_RATE"])))
75
+ EVAL_EPISODES = int(os.environ.get("EVAL_EPISODES", str(DEFAULTS["EVAL_EPISODES"])))
76
+
77
+ # Each iteration uploads to a unique repo if MODEL_REPO_SUFFIX is set
78
+ SUFFIX = os.environ.get("MODEL_REPO_SUFFIX", "")
79
+ DEFAULT_REPO = "InosLihka/rhythm-env-meta-trained" + (f"-{SUFFIX}" if SUFFIX else "")
80
+ MODEL_REPO = os.environ.get("MODEL_REPO", DEFAULT_REPO)
81
+
82
+ print(f"=== Run config ===")
83
+ print(f" FAST_MODE: {FAST_MODE}")
84
+ print(f" MAX_STEPS={MAX_STEPS}, NUM_EPISODES={NUM_EPISODES}, MAX_SAMPLES={MAX_SAMPLES}")
85
+ print(f" NUM_GENERATIONS={NUM_GENERATIONS}, LORA_RANK={LORA_RANK}, BETA={BETA}")
86
+ print(f" LEARNING_RATE={LEARNING_RATE}, EVAL_EPISODES={EVAL_EPISODES}")
87
+ print(f" MODEL_REPO={MODEL_REPO}")
88
+ print()
89
 
90
 
91
  def run(cmd: list[str], **kw):
 
117
  "python", "training/train.py",
118
  "--max_steps", str(MAX_STEPS),
119
  "--num_episodes", str(NUM_EPISODES),
120
+ "--max_samples", str(MAX_SAMPLES),
121
+ "--num_generations", str(NUM_GENERATIONS),
122
  "--lora_rank", str(LORA_RANK),
123
  "--beta", str(BETA),
124
+ "--learning_rate", str(LEARNING_RATE),
125
  "--output_dir", OUTPUT_DIR,
126
  ]
127
  run(train_args)
 
132
  eval_args = [
133
  "python", "training/inference_eval.py",
134
  "--model_path", OUTPUT_DIR,
135
+ "--num_episodes", str(EVAL_EPISODES),
136
  "--output_file", "eval_results.json",
137
  ]
138
  run(eval_args)