Spaces:
Sleeping
multistep: gradient checkpointing + tighter memory defaults
OOM in multi-turn rollouts because:
- Prompts grow each turn (turn 3 can hit 3-5k tokens with
prior_attempts folded into the chat user message)
- All B×G×turn_limit StepRecords sit in memory until the gradient
pass (24 records on default knobs)
- Three models co-resident on L40S (agent+LoRA, frozen target,
8-bit judge) leave only ~30 GB for activations + gradients
Fixes:
- --gradient-checkpointing default ON: ~80% activation memory
saved at ~30% extra compute. Critical for multi-step.
- --update-micro-batch 4 -> 2: half the activation memory per
backward
- --max-prompt-tokens 4096 -> 2048: drops the longest prior
turn first when chat prompt overflows
- --max-new-tokens 768 -> 384: half the per-turn generation cap
(bump back if thinking-mode answers truncate)
- hf_job_train_multistep.sh launcher exposes the new knobs as
env vars (MAX_PROMPT_TOKENS, MAX_NEW_TOKENS, UPDATE_MICRO_BATCH)
and passes them through
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
|
@@ -25,6 +25,11 @@ SFT_ADAPTER="${SFT_ADAPTER:-}" # optional warmstart from a single-step adapter
|
|
| 25 |
MAX_STEPS="${MAX_STEPS:-200}"
|
| 26 |
NUM_GENS="${NUM_GENS:-4}"
|
| 27 |
BATCH_SIZE="${BATCH_SIZE:-2}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
LR="${LR:-3e-6}"
|
| 29 |
BETA="${BETA:-0.04}"
|
| 30 |
TURN_LIMIT="${TURN_LIMIT:-3}"
|
|
@@ -88,6 +93,9 @@ python -u training/train_grpo_multistep.py \\
|
|
| 88 |
--max-steps ${MAX_STEPS} \\
|
| 89 |
--num-gens ${NUM_GENS} \\
|
| 90 |
--batch-size ${BATCH_SIZE} \\
|
|
|
|
|
|
|
|
|
|
| 91 |
--lr ${LR} \\
|
| 92 |
--beta ${BETA} \\
|
| 93 |
--output-dir /app/outputs/grpo_multistep \\
|
|
|
|
| 25 |
MAX_STEPS="${MAX_STEPS:-200}"
|
| 26 |
NUM_GENS="${NUM_GENS:-4}"
|
| 27 |
BATCH_SIZE="${BATCH_SIZE:-2}"
|
| 28 |
+
# Memory-aware defaults — multi-turn prompts grow with prior_attempts
|
| 29 |
+
# (turn 3 can hit 3-5k tokens), so tighter caps + grad checkpointing.
|
| 30 |
+
MAX_PROMPT_TOKENS="${MAX_PROMPT_TOKENS:-2048}"
|
| 31 |
+
MAX_NEW_TOKENS="${MAX_NEW_TOKENS:-384}"
|
| 32 |
+
UPDATE_MICRO_BATCH="${UPDATE_MICRO_BATCH:-2}"
|
| 33 |
LR="${LR:-3e-6}"
|
| 34 |
BETA="${BETA:-0.04}"
|
| 35 |
TURN_LIMIT="${TURN_LIMIT:-3}"
|
|
|
|
| 93 |
--max-steps ${MAX_STEPS} \\
|
| 94 |
--num-gens ${NUM_GENS} \\
|
| 95 |
--batch-size ${BATCH_SIZE} \\
|
| 96 |
+
--max-prompt-tokens ${MAX_PROMPT_TOKENS} \\
|
| 97 |
+
--max-new-tokens ${MAX_NEW_TOKENS} \\
|
| 98 |
+
--update-micro-batch ${UPDATE_MICRO_BATCH} \\
|
| 99 |
--lr ${LR} \\
|
| 100 |
--beta ${BETA} \\
|
| 101 |
--output-dir /app/outputs/grpo_multistep \\
|
|
@@ -235,11 +235,27 @@ def parse_args() -> argparse.Namespace:
|
|
| 235 |
p.add_argument("--beta", type=float, default=0.04,
|
| 236 |
help="KL penalty vs frozen LoRA snapshot.")
|
| 237 |
p.add_argument("--temperature", type=float, default=0.9)
|
| 238 |
-
p.add_argument("--max-new-tokens", type=int, default=768)
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
p.add_argument("--max-grad-norm", type=float, default=0.5)
|
| 241 |
-
p.add_argument("--update-micro-batch", type=int, default=4,
|
| 242 |
-
help="Records per batched forward pass.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
p.add_argument("--save-every", type=int, default=50)
|
| 244 |
|
| 245 |
# LoRA (used when --sft-adapter is not given — fresh LoRA init)
|
|
@@ -331,6 +347,26 @@ def main() -> None:
|
|
| 331 |
n_tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 332 |
print(f" trainable params: {n_tr:,}", flush=True)
|
| 333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
# ---- Snapshot trainable weights as the KL reference ----
|
| 335 |
print("Snapshotting trainable weights as KL reference...", flush=True)
|
| 336 |
ref_state: Dict[str, torch.Tensor] = {
|
|
|
|
| 235 |
p.add_argument("--beta", type=float, default=0.04,
|
| 236 |
help="KL penalty vs frozen LoRA snapshot.")
|
| 237 |
p.add_argument("--temperature", type=float, default=0.9)
|
| 238 |
+
p.add_argument("--max-new-tokens", type=int, default=384,
|
| 239 |
+
help="Per-turn agent generation cap. Trim from 768 to "
|
| 240 |
+
"halve forward+backward memory. Bump back if "
|
| 241 |
+
"thinking-mode answers get truncated.")
|
| 242 |
+
p.add_argument("--max-prompt-tokens", type=int, default=2048,
|
| 243 |
+
help="Trim from 4096 — turn-3 prompts with "
|
| 244 |
+
"prior_attempts can hit 3-5k tokens; truncating "
|
| 245 |
+
"to 2k drops the longest prior turn first.")
|
| 246 |
p.add_argument("--max-grad-norm", type=float, default=0.5)
|
| 247 |
+
p.add_argument("--update-micro-batch", type=int, default=2,
|
| 248 |
+
help="Records per batched forward pass. 2 halves "
|
| 249 |
+
"activation memory vs the default 4.")
|
| 250 |
+
p.add_argument("--gradient-checkpointing", action="store_true",
|
| 251 |
+
default=True,
|
| 252 |
+
help="Recompute forward activations during backward "
|
| 253 |
+
"instead of caching. ~80%% activation memory "
|
| 254 |
+
"saving at ~30%% extra compute. Default ON for "
|
| 255 |
+
"multi-step because trajectory rollouts blow up "
|
| 256 |
+
"activation memory.")
|
| 257 |
+
p.add_argument("--no-gradient-checkpointing",
|
| 258 |
+
dest="gradient_checkpointing", action="store_false")
|
| 259 |
p.add_argument("--save-every", type=int, default=50)
|
| 260 |
|
| 261 |
# LoRA (used when --sft-adapter is not given — fresh LoRA init)
|
|
|
|
| 347 |
n_tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 348 |
print(f" trainable params: {n_tr:,}", flush=True)
|
| 349 |
|
| 350 |
+
# ---- Gradient checkpointing (default ON for multi-step) ----
|
| 351 |
+
# Saves ~80% activation memory at ~30% extra compute. Critical for
|
| 352 |
+
# multi-step because trajectory rollouts (B × G × turn_limit records)
|
| 353 |
+
# blow up activation memory during the backward pass.
|
| 354 |
+
if args.gradient_checkpointing:
|
| 355 |
+
# PEFT models need use_reentrant=False on modern PyTorch
|
| 356 |
+
try:
|
| 357 |
+
model.gradient_checkpointing_enable(
|
| 358 |
+
gradient_checkpointing_kwargs={"use_reentrant": False}
|
| 359 |
+
)
|
| 360 |
+
except TypeError:
|
| 361 |
+
# Older transformers/peft don't take the kwarg
|
| 362 |
+
model.gradient_checkpointing_enable()
|
| 363 |
+
# PEFT requires inputs to require grad when checkpointing the base
|
| 364 |
+
if hasattr(model, "enable_input_require_grads"):
|
| 365 |
+
model.enable_input_require_grads()
|
| 366 |
+
print(" gradient_checkpointing: ENABLED", flush=True)
|
| 367 |
+
else:
|
| 368 |
+
print(" gradient_checkpointing: disabled", flush=True)
|
| 369 |
+
|
| 370 |
# ---- Snapshot trainable weights as the KL reference ----
|
| 371 |
print("Snapshotting trainable weights as KL reference...", flush=True)
|
| 372 |
ref_state: Dict[str, torch.Tensor] = {
|