fix: OOM at batch=256 — cap batch by logits memory, enable grad ckpt
Root cause: batch=256 × seq=16 × vocab=200073 × 4 B = 3.28 GB for the logits
tensor alone. Add backward gradients of the same size plus activations across
28 layers, and the total exceeds 32 GB RAM; the system swaps and appears frozen.
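
A quick, illustrative check of that arithmetic (not part of the commit):

    # fp32 logits: batch * seq * vocab * 4 bytes
    batch, seq, vocab = 256, 16, 200073
    logits_gb = batch * seq * vocab * 4 / 1e9
    print(f"logits alone:    {logits_gb:.2f} GB")      # 3.28 GB
    print(f"+ backward grad: {2 * logits_gb:.2f} GB")  # 6.56 GB before activations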
Fix in loops.py:
- Add _safe_batch(), which caps eff_batch so the logits tensor stays under
  max_logits_gb (default 2 GB). With vocab=200073 (see the sketch after this
  list):
  batch=256, seq=16:  logits=3.28 GB → capped to batch=156: 2.00 GB
  batch=128, seq=32:  logits=3.28 GB → capped to batch=78:  2.00 GB
  batch=64,  seq=64:  logits=3.28 GB → capped to batch=39:  2.00 GB
  batch=32,  seq=128: logits=3.28 GB → capped to batch=19:  1.95 GB
- Enable gradient checkpointing on the model (recompute activations during
  the backward pass; saves ~60% activation memory at the cost of ~30% more
  compute).
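
A standalone sketch of the cap arithmetic behind the table above; cap_batch is
a hypothetical stand-in for the committed _safe_batch(), and it assumes the
2 GB cap is decimal GB, which is what reproduces the listed batch sizes:

    VOCAB = 200073

    def cap_batch(batch: int, seq: int, max_gb: float = 2.0) -> int:
        per_sample = seq * VOCAB * 4  # bytes of fp32 logits per sample
        return min(batch, max(1, int(max_gb * 1e9) // per_sample))

    for batch, seq in [(256, 16), (128, 32), (64, 64), (32, 128)]:
        print(f"{batch} x {seq} -> {cap_batch(batch, seq)}")  # 156, 78, 39, 19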
Fix in train_hyper.py:
- batch_size default 32 → 4 (base batch; GrowLength scales it up under the cap)
- GrowLength stages use fixed safe batch sizes directly
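
For context, a minimal sketch of what an enable_gradient_checkpointing() hook
of the kind loops.py now calls could look like, assuming a
torch.utils.checkpoint-based trunk; CheckpointedTrunk and its layers attribute
are illustrative, not the actual Chimera model:

    import torch
    import torch.utils.checkpoint as ckpt

    class CheckpointedTrunk(torch.nn.Module):
        def __init__(self, layers):
            super().__init__()
            self.layers = torch.nn.ModuleList(layers)
            self.gradient_checkpointing = False

        def enable_gradient_checkpointing(self):
            self.gradient_checkpointing = True

        def forward(self, x):
            for layer in self.layers:
                if self.gradient_checkpointing and self.training:
                    # Discard this layer's activations after forward and
                    # recompute them during backward: ~60% less activation
                    # memory for ~30% more compute.
                    x = ckpt.checkpoint(layer, x, use_reentrant=False)
                else:
                    x = layer(x)
            return x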
chimera/training/loops.py  +41 -20

@@ -13,6 +13,24 @@ from .common import save_final_checkpoint, save_training_checkpoint
 from .hyper import ProgressiveLoopScheduler
 
 
+def _safe_batch(desired_batch: int, seq_len: int, vocab_size: int,
+                max_logits_gb: float = 2.0) -> int:
+    """Cap batch size so the logits tensor fits in memory.
+
+    Logits shape: [batch, seq, vocab] at fp32 = batch * seq * vocab * 4 bytes.
+    With vocab=200073, batch=256, seq=16: 3.28 GB just for logits.
+    Backward doubles this. Must stay well under 32 GB total.
+    """
+    bytes_per_sample = seq_len * vocab_size * 4  # fp32 logits
+    max_bytes = int(max_logits_gb * 1e9)  # decimal GB, matching the sizes printed below
+    max_batch = max(1, max_bytes // bytes_per_sample)
+    capped = min(desired_batch, max_batch)
+    if capped < desired_batch:
+        print(f"  [MEM] Batch {desired_batch} → {capped} (logits would be "
+              f"{desired_batch * seq_len * vocab_size * 4 / 1e9:.1f} GB, cap={max_logits_gb} GB)")
+    return capped
+
+
 def train_fast_loop(args, model, config, loader, compute_loss) -> str:
     optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98))
     os.makedirs(args.output_dir, exist_ok=True)

@@ -52,41 +70,40 @@ def train_standard_loop(args, model, config, loader, compute_loss, optimizer, us
 
 def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
     use_compile = getattr(args, "compile", False)
+    vocab_size = int(config.get("vocab_size", 200073))
 
     # ── Muon LR for ternary BitLinear ──
-    # v12.1: Raised from 0.008 to 0.012. The clamp-aware STE in BitLinear
-    # gates gradients to zero for weights outside [-1, 1], so the effective
-    # learning signal is self-limiting. 0.012 is the highest rate before
-    # NS-orthogonalized momentum causes oscillation at the STE boundary.
-    # At 300 steps, every step counts — 0.008 converges too slowly.
    muon_lr = 0.012
-    muon_warmup = 30
+    muon_warmup = 30
 
     model, optimizer, scheduler, extras = chimera_turbo.apply(
         model,
         max_steps=args.max_steps,
         lr=muon_lr,
-        weight_decay=0.02,
+        weight_decay=0.02,
         warmup_steps=muon_warmup,
         use_compile=use_compile,
-        mtp_heads=0,
-        llrd_decay=0.90,
-        grokfast_alpha=0.95,
-        grokfast_lambda=1.5,
+        mtp_heads=0,
+        llrd_decay=0.90,
+        grokfast_alpha=0.95,
+        grokfast_lambda=1.5,
     )
     model.train()
 
+    # ── Gradient checkpointing: saves ~60% activation memory ──
+    # Critical with vocab=200K: without it, activations across 28 layers
+    # at batch=32 can consume several GB.
+    raw_model = getattr(model, "_orig_mod", model)
+    if hasattr(raw_model, "enable_gradient_checkpointing"):
+        raw_model.enable_gradient_checkpointing()
+        print(f"[OPT] Gradient checkpointing: ON")
+
     # ── Looping: force loops=1 for all 300 steps ──
-    # Progressive 1→2→3 doubles/triples forward cost. At 300 steps,
-    # throughput (more tokens seen) beats iterative refinement (same
-    # tokens processed multiple times). Each extra loop adds ~18 layers
-    # of compute through the loop trunk for diminishing convergence gain.
     cur_loops = 1
-    raw_model = getattr(model, "_orig_mod", model)
     if hasattr(raw_model, "loop_controller"):
         raw_model.loop_controller.loop_default = 1
         raw_model.loop_controller.loop_min = 1
-        raw_model.loop_controller.loop_max = 1
+        raw_model.loop_controller.loop_max = 1
 
     use_bf16 = bool(args.bf16)
 

@@ -95,7 +112,11 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     step, total_loss, valid_count, best_loss, toks = 0, 0.0, 0, float("inf"), 0
     t0 = time.time()
     cur_seq = initial_seq
-
+
+    # ── Compute memory-safe batch size ──
+    desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
+    eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
+
     loader = torch.utils.data.DataLoader(
         dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
     data_iter = iter(loader)

@@ -110,13 +131,13 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
         if ns != cur_seq:
             cur_seq = ns
             dataset.set_seq_len(cur_seq)
-
+            desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
+            eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
             loader = torch.utils.data.DataLoader(
                 dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
             data_iter = iter(loader)
             print(f"  [P1] seq -> {cur_seq} batch -> {eff_batch}")
 
-        # Loops locked to 1 — no progressive schedule
         if unfreezer:
             unfreezer.update(step)
 