tune: GRPO hyperparameter fixes from ML reviewer
Root cause of likely under-training in v1: 4-layer reward stack with
format/legal saturating at +1 each was drowning the env_reward and
belief_accuracy meta-RL signals. Plus several TRL/DeepSeek best-practice
deviations.
Changes:
- beta: 0.1 -> 0.04 (TRL/DeepSeek default; 0.1 was anchoring the policy too hard)
- num_generations: 4 -> 8 (halves advantage-estimate variance in the
  continuous-profile setting; critical for the GRPO signal)
- max_completion_length: 20 -> 32 (was silently truncating belief digits
for actions like FAMILY_TIME and BINGE_WATCH after BPE tokenization)
- reward_weights=[0.3, 0.3, 1.0, 1.0] in GRPOConfig: scales format_valid +
action_legal down so env_reward and belief_accuracy dominate the gradient.
Wrapped in try/except for TRL versions that don't support reward_weights.
FAST_MODE preset also bumped NUM_GENERATIONS from 2 -> 4. Two completions
per group is too low to estimate group advantage at all.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- scripts/train_on_hf.py +3 -2
- training/train.py +20 -8
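
Rough sketch (not code from this repo) of why the reward_weights change matters:
TRL's GRPOTrainer combines the per-function rewards as a weighted sum for each
completion and then normalizes within each group of num_generations completions,
so down-weighting format_valid and action_legal keeps a single format/legality
slip from swamping the env_reward and belief_accuracy differences the policy
should actually learn from. The numbers below are made up.

# Illustrative only: weighted reward sum + group-normalized advantage, roughly
# mirroring how GRPOTrainer aggregates multiple reward functions.
import numpy as np

def group_advantages(rewards_per_func, weights):
    # rewards_per_func: (num_generations, num_reward_funcs) for one prompt's group
    total = rewards_per_func @ np.asarray(weights, dtype=float)  # weighted sum per completion
    return (total - total.mean()) / (total.std() + 1e-4)         # normalize within the group

# Hypothetical group of 4 completions; columns ordered as in the diff below:
# format_valid, action_legal, env_reward, belief_accuracy
r = np.array([
    [1.0, 1.0, 0.15, 0.3],
    [0.0, 1.0, 0.40, 0.8],   # one format slip, but the best env/belief outcome
    [1.0, 1.0, 0.30, 0.6],
    [1.0, 1.0, 0.10, 0.2],
])
print(group_advantages(r, [1.0, 1.0, 1.0, 1.0]))   # uniform (v1-style): the format slip dominates
print(group_advantages(r, [0.3, 0.3, 1.0, 1.0]))   # v2: env_reward/belief_accuracy drive the ranking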
scripts/train_on_hf.py
@@ -57,12 +57,13 @@ PLOTS_DIR = "/tmp/rhythm_env/plots"
 FAST_MODE = os.environ.get("FAST_MODE", "0") == "1"
 
 if FAST_MODE:
+    # Smoke-train preset: enough signal in 200 steps to decide go/no-go
     DEFAULTS = dict(MAX_STEPS=200, NUM_EPISODES=80, MAX_SAMPLES=800,
-                    NUM_GENERATIONS=2, LORA_RANK=8, BETA=0.1,
+                    NUM_GENERATIONS=4, LORA_RANK=8, BETA=0.04,
                     LEARNING_RATE=5e-5, EVAL_EPISODES=2)
 else:
     DEFAULTS = dict(MAX_STEPS=1500, NUM_EPISODES=300, MAX_SAMPLES=3000,
-                    NUM_GENERATIONS=4, LORA_RANK=8, BETA=0.1,
+                    NUM_GENERATIONS=8, LORA_RANK=8, BETA=0.04,
                     LEARNING_RATE=5e-5, EVAL_EPISODES=5)
 
 MAX_STEPS = int(os.environ.get("MAX_STEPS", str(DEFAULTS["MAX_STEPS"])))
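
Quick sanity check on the NUM_GENERATIONS bump above (illustrative only, not
project code): with i.i.d. reward noise, the variance of the group-mean baseline
scales as 1/G, so going from 4 to 8 completions per prompt roughly halves it,
and 2 barely gives a baseline at all. This ignores the std-normalization GRPO
also applies, but the scaling is the point.

# Monte Carlo illustration of baseline variance vs. group size G (made-up numbers).
import numpy as np

rng = np.random.default_rng(0)
sigma = 1.0  # assumed per-completion reward noise

for G in (2, 4, 8):
    # Each group's baseline is the mean of G sampled rewards; simulate many groups.
    baselines = rng.normal(0.0, sigma, size=(100_000, G)).mean(axis=1)
    print(f"G={G}: var(baseline) ~= {baselines.var():.3f} (theory: {sigma**2 / G:.3f})")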
training/train.py
@@ -34,11 +34,11 @@ def main():
                         help="Number of episodes for dataset generation (more diversity = better meta-RL)")
     parser.add_argument("--max_samples", type=int, default=3000,
                         help="Maximum training samples")
-    parser.add_argument("--num_generations", type=int, default=4,
-                        help="Completions per prompt for GRPO (
+    parser.add_argument("--num_generations", type=int, default=8,
+                        help="Completions per prompt for GRPO (8 default, lower variance for continuous-profile meta-RL)")
     parser.add_argument("--learning_rate", type=float, default=5e-5)
-    parser.add_argument("--beta", type=float, default=0.1,
-                        help="KL penalty (raise to 0.
+    parser.add_argument("--beta", type=float, default=0.04,
+                        help="KL penalty (TRL/DeepSeek default; raise to 0.1+ if KL diverges)")
     parser.add_argument("--lora_rank", type=int, default=8,
                         help="LoRA rank (8 = more capacity than original 4 for meta-RL)")
     parser.add_argument("--hint_fraction", type=float, default=0.15,
@@ -141,10 +141,15 @@ def main():
 
     from trl import GRPOConfig, GRPOTrainer
 
-    max_prompt_length = 600 #
-    max_completion_length = 20
+    max_prompt_length = 600  # history + hint room
+    max_completion_length = 32  # bumped from 20 to prevent silent truncation of belief digits
 
-    training_args = GRPOConfig(
+    # reward_weights: scale per-layer to prevent format/legal (saturated near +1) from
+    # drowning out env_reward and belief_accuracy (the actual learning signals).
+    # Order MUST match reward_funcs in main(): format_valid, action_legal, env_reward, belief_accuracy
+    reward_weights = [0.3, 0.3, 1.0, 1.0]
+
+    training_args_kwargs = dict(
         temperature=1.0,
         learning_rate=args.learning_rate,
         beta=args.beta,
@@ -160,10 +165,17 @@ def main():
         max_prompt_length=max_prompt_length,
         max_completion_length=max_completion_length,
         max_steps=args.max_steps,
-        save_steps=250,
+        save_steps=250,
         report_to=args.report_to,
         output_dir=args.output_dir,
     )
+    # reward_weights was added in TRL 0.13+; pass only if supported
+    try:
+        training_args = GRPOConfig(**training_args_kwargs, reward_weights=reward_weights)
+        print(f"Using GRPOConfig with reward_weights={reward_weights}")
+    except TypeError:
+        training_args = GRPOConfig(**training_args_kwargs)
+        print("WARN: TRL version does not support reward_weights; using uniform weighting")
 
     print(f"max_steps={args.max_steps}, num_generations={args.num_generations}, "
           f"lr={args.learning_rate}, beta={args.beta}")