InosLihka and Claude Opus 4.7 (1M context) committed
Commit dc0186f · Parent: 1a865f8

tune: GRPO hyperparameter fixes from ML reviewer


Likely root cause of the under-training seen in v1: the 4-layer reward stack, with
format/legal saturating at +1 each, was drowning out the env_reward and
belief_accuracy meta-RL signals. On top of that, several settings deviated from
TRL/DeepSeek best practice.
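
To make "drowning" concrete, here is a toy group-advantage calculation under the old
(uniform) and new reward weighting. The per-layer reward values are invented for
illustration, and it assumes TRL's behaviour of summing the reward functions per
completion (weighted by reward_weights) before group-normalising:

    import statistics

    # One hypothetical GRPO group of 4 completions, per-layer rewards in the order
    # (format_valid, action_legal, env_reward, belief_accuracy). Values are made up;
    # the single format failure (0.0) is the point of the example.
    group = [
        [1.0, 1.0, 0.15, 0.60],
        [1.0, 1.0, 0.05, 0.40],
        [1.0, 1.0, 0.10, 0.50],
        [0.0, 1.0, 0.12, 0.50],   # malformed output
    ]

    def advantages(weights):
        totals = [sum(w * r for w, r in zip(weights, c)) for c in group]
        mu, sd = statistics.mean(totals), statistics.pstdev(totals)
        return [round((t - mu) / sd, 2) for t in totals]

    print("uniform:", advantages([1.0, 1.0, 1.0, 1.0]))   # ≈ [0.90, 0.22, 0.56, -1.68]
    print("tuned:  ", advantages([0.3, 0.3, 1.0, 1.0]))   # ≈ [1.37, -0.50, 0.43, -1.30]

With uniform weights the group std is dominated by the one format flip, so the three
well-formed completions are squeezed into a ~0.7-wide advantage band; the tuned weights
spread the same env_reward/belief_accuracy differences over ~1.9 while still penalising
the malformed completion.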

Changes:
- beta: 0.1 -> 0.04 (TRL/DeepSeek default; was anchoring policy too hard)
- num_generations: 4 -> 8 (halves advantage-estimate variance in the
continuous-profile setting; critical for a usable GRPO signal; see the group-size
sketch after the FAST_MODE note below)
- max_completion_length: 20 -> 32 (the old cap was silently truncating belief digits
for actions like FAMILY_TIME and BINGE_WATCH after BPE tokenization; see the
token-count check after this list)
- reward_weights=[0.3, 0.3, 1.0, 1.0] in GRPOConfig: scales format_valid +
action_legal down so env_reward and belief_accuracy dominate the gradient.
Wrapped in try/except for TRL versions that don't support reward_weights.
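
The truncation point is easy to check offline. A minimal sketch, assuming a Hugging
Face tokenizer is available locally (the model id below is a placeholder, not
necessarily the base model this repo actually trains):

    from transformers import AutoTokenizer

    # Placeholder tokenizer; substitute the real base model's tokenizer.
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

    MAX_COMPLETION_LENGTH = 32  # the new cap from this commit

    for completion in ["FAMILY_TIME 7 4 2", "BINGE_WATCH 3 8 5"]:
        n_tokens = len(tok(completion, add_special_tokens=False)["input_ids"])
        headroom = MAX_COMPLETION_LENGTH - n_tokens
        print(f"{n_tokens:2d} tokens, headroom {headroom:2d} | {completion}")

The raw "ACTION_NAME D D D" string is only part of the budget: multi-word action names
split into several BPE pieces, and any wrapper text the model emits (plus EOS) counts
against the same cap, which is how 20 turned out tight enough to cut off the trailing
belief digits.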

The FAST_MODE preset also bumps NUM_GENERATIONS from 2 -> 4; two completions per
group are too few to estimate a group advantage at all.
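
The variance claim follows from the usual 1/G scaling of the group-mean baseline; a
synthetic check (the reward distribution below is made up, only the trend matters):

    import random, statistics

    random.seed(0)

    def baseline_std(group_size, trials=20_000):
        # std of the group-mean reward, i.e. the baseline each completion is compared against
        means = [statistics.mean(random.gauss(0.0, 1.0) for _ in range(group_size))
                 for _ in range(trials)]
        return statistics.pstdev(means)

    for g in (2, 4, 8):
        print(f"num_generations={g}: baseline std ≈ {baseline_std(g):.3f}")
    # Expect roughly 1/sqrt(G): ≈0.71, ≈0.50, ≈0.35. Going from 4 to 8 halves the
    # baseline's variance (0.25 -> 0.125); with 2, the "advantage" is just half the
    # gap between two samples, hence the FAST_MODE bump to 4.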

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (2)
  1. scripts/train_on_hf.py +3 -2
  2. training/train.py +20 -8
scripts/train_on_hf.py CHANGED
@@ -57,12 +57,13 @@ PLOTS_DIR = "/tmp/rhythm_env/plots"
 FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"
 
 if FAST_MODE:
+    # Smoke-train preset: enough signal in 200 steps to decide go/no-go
     DEFAULTS = dict(MAX_STEPS=200, NUM_EPISODES=80, MAX_SAMPLES=800,
-                    NUM_GENERATIONS=2, LORA_RANK=8, BETA=0.1,
+                    NUM_GENERATIONS=4, LORA_RANK=8, BETA=0.04,
                     LEARNING_RATE=5e-5, EVAL_EPISODES=2)
 else:
     DEFAULTS = dict(MAX_STEPS=1500, NUM_EPISODES=300, MAX_SAMPLES=3000,
-                    NUM_GENERATIONS=4, LORA_RANK=8, BETA=0.1,
+                    NUM_GENERATIONS=8, LORA_RANK=8, BETA=0.04,
                     LEARNING_RATE=5e-5, EVAL_EPISODES=5)
 
 MAX_STEPS = int(os.environ.get("MAX_STEPS", str(DEFAULTS["MAX_STEPS"])))
training/train.py CHANGED
@@ -34,11 +34,11 @@ def main():
                         help="Number of episodes for dataset generation (more diversity = better meta-RL)")
     parser.add_argument("--max_samples", type=int, default=3000,
                         help="Maximum training samples")
-    parser.add_argument("--num_generations", type=int, default=4,
-                        help="Completions per prompt for GRPO (higher = lower variance, more compute)")
+    parser.add_argument("--num_generations", type=int, default=8,
+                        help="Completions per prompt for GRPO (8 default, lower variance for continuous-profile meta-RL)")
     parser.add_argument("--learning_rate", type=float, default=5e-5)
-    parser.add_argument("--beta", type=float, default=0.1,
-                        help="KL penalty (raise to 0.2 if training is unstable)")
+    parser.add_argument("--beta", type=float, default=0.04,
+                        help="KL penalty (TRL/DeepSeek default; raise to 0.1+ if KL diverges)")
     parser.add_argument("--lora_rank", type=int, default=8,
                         help="LoRA rank (8 = more capacity than original 4 for meta-RL)")
     parser.add_argument("--hint_fraction", type=float, default=0.15,
@@ -141,10 +141,15 @@ def main():
 
     from trl import GRPOConfig, GRPOTrainer
 
-    max_prompt_length = 600  # bumped from 400 for longer prompts (history + hint)
-    max_completion_length = 20  # bumped from 16 for "ACTION_NAME D D D" format
+    max_prompt_length = 600  # history + hint room
+    max_completion_length = 32  # bumped from 20 to prevent silent truncation of belief digits
 
-    training_args = GRPOConfig(
+    # reward_weights: scale per-layer to prevent format/legal (saturated near +1) from
+    # drowning out env_reward and belief_accuracy (the actual learning signals).
+    # Order MUST match reward_funcs in main(): format_valid, action_legal, env_reward, belief_accuracy
+    reward_weights = [0.3, 0.3, 1.0, 1.0]
+
+    training_args_kwargs = dict(
         temperature=1.0,
         learning_rate=args.learning_rate,
         beta=args.beta,
@@ -160,10 +165,17 @@
         max_prompt_length=max_prompt_length,
         max_completion_length=max_completion_length,
         max_steps=args.max_steps,
-        save_steps=250,  # checkpoint every 250 (was 100)
+        save_steps=250,
         report_to=args.report_to,
         output_dir=args.output_dir,
     )
+    # reward_weights was added in TRL 0.13+; pass only if supported
+    try:
+        training_args = GRPOConfig(**training_args_kwargs, reward_weights=reward_weights)
+        print(f"Using GRPOConfig with reward_weights={reward_weights}")
+    except TypeError:
+        training_args = GRPOConfig(**training_args_kwargs)
+        print("WARN: TRL version does not support reward_weights; using uniform weighting")
 
     print(f"max_steps={args.max_steps}, num_generations={args.num_generations}, "
           f"lr={args.learning_rate}, beta={args.beta}")