Don Rishabh Claude Opus 4.7 (1M context) committed on
Commit
7ca042f
·
1 Parent(s): 89ed87f

multistep: gradient checkpointing + tighter memory defaults

Browse files

OOM in multi-turn rollouts because:
- Prompts grow each turn (turn 3 can hit 3-5k tokens with
prior_attempts folded into the chat user message)
- All B×G×turn_limit StepRecords sit in memory until the gradient
pass (24 records on default knobs)
- Three models co-resident on L40S (agent+LoRA, frozen target,
8-bit judge) leave only ~30 GB for activations + gradients

Fixes:
- --gradient-checkpointing default ON: ~80% activation memory
saved at ~30% extra compute. Critical for multi-step.
- --update-micro-batch 4 -> 2: half the activation memory per
backward
- --max-prompt-tokens 4096 -> 2048: drops the longest prior
turn first when chat prompt overflows
- --max-new-tokens 768 -> 384: half the per-turn generation cap
(bump back if thinking-mode answers truncate)
- hf_job_train_multistep.sh launcher exposes the new knobs as
env vars (MAX_PROMPT_TOKENS, MAX_NEW_TOKENS, UPDATE_MICRO_BATCH)
and passes them through

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

training/hf_job_train_multistep.sh CHANGED
@@ -25,6 +25,11 @@ SFT_ADAPTER="${SFT_ADAPTER:-}" # optional warmstart from a single-step adapter
25
  MAX_STEPS="${MAX_STEPS:-200}"
26
  NUM_GENS="${NUM_GENS:-4}"
27
  BATCH_SIZE="${BATCH_SIZE:-2}"
 
 
 
 
 
28
  LR="${LR:-3e-6}"
29
  BETA="${BETA:-0.04}"
30
  TURN_LIMIT="${TURN_LIMIT:-3}"
@@ -88,6 +93,9 @@ python -u training/train_grpo_multistep.py \\
88
  --max-steps ${MAX_STEPS} \\
89
  --num-gens ${NUM_GENS} \\
90
  --batch-size ${BATCH_SIZE} \\
 
 
 
91
  --lr ${LR} \\
92
  --beta ${BETA} \\
93
  --output-dir /app/outputs/grpo_multistep \\
 
25
  MAX_STEPS="${MAX_STEPS:-200}"
26
  NUM_GENS="${NUM_GENS:-4}"
27
  BATCH_SIZE="${BATCH_SIZE:-2}"
28
+ # Memory-aware defaults — multi-turn prompts grow with prior_attempts
29
+ # (turn 3 can hit 3-5k tokens), so tighter caps + grad checkpointing.
30
+ MAX_PROMPT_TOKENS="${MAX_PROMPT_TOKENS:-2048}"
31
+ MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-384}"
32
+ UPDATE_MICRO_BATCH="${UPDATE_MICRO_BATCH:-2}"
33
  LR="${LR:-3e-6}"
34
  BETA="${BETA:-0.04}"
35
  TURN_LIMIT="${TURN_LIMIT:-3}"
 
93
  --max-steps ${MAX_STEPS} \\
94
  --num-gens ${NUM_GENS} \\
95
  --batch-size ${BATCH_SIZE} \\
96
+ --max-prompt-tokens ${MAX_PROMPT_TOKENS} \\
97
+ --max-new-tokens ${MAX_NEW_TOKENS} \\
98
+ --update-micro-batch ${UPDATE_MICRO_BATCH} \\
99
  --lr ${LR} \\
100
  --beta ${BETA} \\
101
  --output-dir /app/outputs/grpo_multistep \\
training/train_grpo_multistep.py CHANGED
@@ -235,11 +235,27 @@ def parse_args() -> argparse.Namespace:
235
  p.add_argument("--beta", type=float, default=0.04,
236
  help="KL penalty vs frozen LoRA snapshot.")
237
  p.add_argument("--temperature", type=float, default=0.9)
238
- p.add_argument("--max-new-tokens", type=int, default=768)
239
- p.add_argument("--max-prompt-tokens", type=int, default=4096)
 
 
 
 
 
 
240
  p.add_argument("--max-grad-norm", type=float, default=0.5)
241
- p.add_argument("--update-micro-batch", type=int, default=4,
242
- help="Records per batched forward pass.")
 
 
 
 
 
 
 
 
 
 
243
  p.add_argument("--save-every", type=int, default=50)
244
 
245
  # LoRA (used when --sft-adapter is not given — fresh LoRA init)
@@ -331,6 +347,26 @@ def main() -> None:
331
  n_tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
332
  print(f" trainable params: {n_tr:,}", flush=True)
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  # ---- Snapshot trainable weights as the KL reference ----
335
  print("Snapshotting trainable weights as KL reference...", flush=True)
336
  ref_state: Dict[str, torch.Tensor] = {
 
235
  p.add_argument("--beta", type=float, default=0.04,
236
  help="KL penalty vs frozen LoRA snapshot.")
237
  p.add_argument("--temperature", type=float, default=0.9)
238
+ p.add_argument("--max-new-tokens", type=int, default=384,
239
+ help="Per-turn agent generation cap. Trim from 768 to "
240
+ "halve forward+backward memory. Bump back if "
241
+ "thinking-mode answers get truncated.")
242
+ p.add_argument("--max-prompt-tokens", type=int, default=2048,
243
+ help="Trim from 4096 — turn-3 prompts with "
244
+ "prior_attempts can hit 3-5k tokens; truncating "
245
+ "to 2k drops the longest prior turn first.")
246
  p.add_argument("--max-grad-norm", type=float, default=0.5)
247
+ p.add_argument("--update-micro-batch", type=int, default=2,
248
+ help="Records per batched forward pass. 2 halves "
249
+ "activation memory vs the default 4.")
250
+ p.add_argument("--gradient-checkpointing", action="store_true",
251
+ default=True,
252
+ help="Recompute forward activations during backward "
253
+ "instead of caching. ~80%% activation memory "
254
+ "saving at ~30%% extra compute. Default ON for "
255
+ "multi-step because trajectory rollouts blow up "
256
+ "activation memory.")
257
+ p.add_argument("--no-gradient-checkpointing",
258
+ dest="gradient_checkpointing", action="store_false")
259
  p.add_argument("--save-every", type=int, default=50)
260
 
261
  # LoRA (used when --sft-adapter is not given — fresh LoRA init)
 
347
  n_tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
348
  print(f" trainable params: {n_tr:,}", flush=True)
349
 
350
+ # ---- Gradient checkpointing (default ON for multi-step) ----
351
+ # Saves ~80% activation memory at ~30% extra compute. Critical for
352
+ # multi-step because trajectory rollouts (B × G × turn_limit records)
353
+ # blow up activation memory during the backward pass.
354
+ if args.gradient_checkpointing:
355
+ # PEFT models need use_reentrant=False on modern PyTorch
356
+ try:
357
+ model.gradient_checkpointing_enable(
358
+ gradient_checkpointing_kwargs={"use_reentrant": False}
359
+ )
360
+ except TypeError:
361
+ # Older transformers/peft don't take the kwarg
362
+ model.gradient_checkpointing_enable()
363
+ # PEFT requires inputs to require grad when checkpointing the base
364
+ if hasattr(model, "enable_input_require_grads"):
365
+ model.enable_input_require_grads()
366
+ print(" gradient_checkpointing: ENABLED", flush=True)
367
+ else:
368
+ print(" gradient_checkpointing: disabled", flush=True)
369
+
370
  # ---- Snapshot trainable weights as the KL reference ----
371
  print("Snapshotting trainable weights as KL reference...", flush=True)
372
  ref_state: Dict[str, torch.Tensor] = {