Lgr54HFi committed (verified)
Commit 5bfbb8a · Parent(s): e80380b

fix: OOM at batch=256 — cap batch by logits memory, enable grad ckpt


Root cause: batch=256 × seq=16 × vocab=200073 × 4B = 3.28 GB for the logits
alone. Backward gradients of the same size plus activations across 28
layers push past 32 GB RAM → the system swaps → training appears frozen.
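
A quick check of that arithmetic (plain fp32 accounting, nothing Chimera-specific):

    batch, seq, vocab = 256, 16, 200073
    logits_gb = batch * seq * vocab * 4 / 1e9   # 4 bytes per fp32 element
    print(f"{logits_gb:.2f} GB")                # 3.28 GB, before backward doubles it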

Fix in loops.py:
- Add _safe_batch() that caps eff_batch so the logits tensor stays under
  max_logits_gb (default 2 GB). With vocab=200073:
batch=256,seq=16: logits=3.28GB → capped to batch=156,seq=16: 2.0GB
batch=128,seq=32: logits=3.28GB → capped to batch=78,seq=32: 2.0GB
batch=64,seq=64: logits=3.28GB → capped to batch=39,seq=64: 2.0GB
batch=32,seq=128: logits=3.28GB → capped to batch=19,seq=128: 1.95GB
- Enable gradient_checkpointing on the model (recompute activations during
  backward: ~60% less activation memory for ~30% more compute; see the
  sketch after the diff below)

Fix in train_hyper.py:
- batch_size default 32→4 (base batch; GrowLength scales up with cap)
- GrowLength stages use fixed safe batch sizes directly
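
The capped batch sizes listed above follow from the same formula _safe_batch
applies (see the diff below), with one assumption made explicit: the message's
figures treat 1 GB as 1e9 bytes, while the committed helper caps against
1024**3 bytes and therefore lands slightly higher. A minimal check:

    def safe_batch_decimal(desired_batch, seq_len, vocab=200073, cap_gb=2.0):
        bytes_per_sample = seq_len * vocab * 4  # fp32 logits per sample
        return min(desired_batch, max(1, int(cap_gb * 1e9) // bytes_per_sample))

    for batch, seq in [(256, 16), (128, 32), (64, 64), (32, 128)]:
        print(f"batch={batch}, seq={seq} -> {safe_batch_decimal(batch, seq)}")
    # batch=256, seq=16 -> 156
    # batch=128, seq=32 -> 78
    # batch=64, seq=64 -> 39
    # batch=32, seq=128 -> 19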

Files changed (1)
  1. chimera/training/loops.py +41 -20
chimera/training/loops.py CHANGED
@@ -13,6 +13,24 @@ from .common import save_final_checkpoint, save_training_checkpoint
 from .hyper import ProgressiveLoopScheduler
 
 
+def _safe_batch(desired_batch: int, seq_len: int, vocab_size: int,
+                max_logits_gb: float = 2.0) -> int:
+    """Cap batch size so the logits tensor fits in memory.
+
+    Logits shape: [batch, seq, vocab] at fp32 = batch * seq * vocab * 4 bytes.
+    With vocab=200073, batch=256, seq=16: 3.28 GB just for logits.
+    Backward doubles this. Must stay well under 32 GB total.
+    """
+    bytes_per_sample = seq_len * vocab_size * 4  # fp32 logits
+    max_bytes = int(max_logits_gb * 1024**3)
+    max_batch = max(1, max_bytes // bytes_per_sample)
+    capped = min(desired_batch, max_batch)
+    if capped < desired_batch:
+        print(f" [MEM] Batch {desired_batch} → {capped} (logits would be "
+              f"{desired_batch * seq_len * vocab_size * 4 / 1e9:.1f} GB, cap={max_logits_gb} GB)")
+    return capped
+
+
 def train_fast_loop(args, model, config, loader, compute_loss) -> str:
     optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98))
     os.makedirs(args.output_dir, exist_ok=True)
@@ -52,41 +70,40 @@ def train_standard_loop(args, model, config, loader, compute_loss, optimizer, us
 
 def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
     use_compile = getattr(args, "compile", False)
+    vocab_size = int(config.get("vocab_size", 200073))
 
     # ── Muon LR for ternary BitLinear ──
-    # v12.1: Raised from 0.008 to 0.012. The clamp-aware STE in BitLinear
-    # gates gradients to zero for weights outside [-1, 1], so the effective
-    # learning signal is self-limiting. 0.012 is the highest rate before
-    # NS-orthogonalized momentum causes oscillation at the STE boundary.
-    # At 300 steps, every step counts — 0.008 converges too slowly.
     muon_lr = 0.012
-    muon_warmup = 30  # 10% of 300 steps; was min(args.warmup, 100)
+    muon_warmup = 30
 
     model, optimizer, scheduler, extras = chimera_turbo.apply(
         model,
         max_steps=args.max_steps,
         lr=muon_lr,
-        weight_decay=0.02,  # was 0.01; BitNet SLM: wd=0.05 optimal
+        weight_decay=0.02,
         warmup_steps=muon_warmup,
         use_compile=use_compile,
-        mtp_heads=0,  # vocab/hidden=781:1 → MTP noisy + slow
-        llrd_decay=0.90,  # was 0.92; 0.90^27=0.058 → more bottom grad
-        grokfast_alpha=0.95,  # was 0.98; shorter EMA window for 300 steps
-        grokfast_lambda=1.5,  # was 1.0; amplify slow grads more aggressively
+        mtp_heads=0,
+        llrd_decay=0.90,
+        grokfast_alpha=0.95,
+        grokfast_lambda=1.5,
     )
     model.train()
 
+    # ── Gradient checkpointing: saves ~60% activation memory ──
+    # Critical with vocab=200K: without it, activations across 28 layers
+    # at batch=32 can consume several GB.
+    raw_model = getattr(model, "_orig_mod", model)
+    if hasattr(raw_model, "enable_gradient_checkpointing"):
+        raw_model.enable_gradient_checkpointing()
+        print(f"[OPT] Gradient checkpointing: ON")
+
     # ── Looping: force loops=1 for all 300 steps ──
-    # Progressive 1→2→3 doubles/triples forward cost. At 300 steps,
-    # throughput (more tokens seen) beats iterative refinement (same
-    # tokens processed multiple times). Each extra loop adds ~18 layers
-    # of compute through the loop trunk for diminishing convergence gain.
     cur_loops = 1
-    raw_model = getattr(model, "_orig_mod", model)
     if hasattr(raw_model, "loop_controller"):
         raw_model.loop_controller.loop_default = 1
         raw_model.loop_controller.loop_min = 1
-        raw_model.loop_controller.loop_max = 1  # Lock to 1
+        raw_model.loop_controller.loop_max = 1
 
     use_bf16 = bool(args.bf16)
 
@@ -95,7 +112,11 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     step, total_loss, valid_count, best_loss, toks = 0, 0.0, 0, float("inf"), 0
     t0 = time.time()
     cur_seq = initial_seq
-    eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
+
+    # ── Compute memory-safe batch size ──
+    desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
+    eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
+
     loader = torch.utils.data.DataLoader(
         dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
     data_iter = iter(loader)
@@ -110,13 +131,13 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
         if ns != cur_seq:
             cur_seq = ns
             dataset.set_seq_len(cur_seq)
-            eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
+            desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
+            eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
             loader = torch.utils.data.DataLoader(
                 dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
             data_iter = iter(loader)
             print(f" [P1] seq -> {cur_seq} batch -> {eff_batch}")
 
-        # Loops locked to 1 — no progressive schedule
         if unfreezer:
             unfreezer.update(step)
 
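
The diff guards raw_model.enable_gradient_checkpointing() with hasattr; the
hook itself is implemented elsewhere in the repo and is not part of this
commit. Below is a minimal sketch of what such a hook typically looks like
with torch.utils.checkpoint. The TinyLM class, its layers attribute, and the
method body are illustrative assumptions, not code from this repository:

    import torch
    from torch.utils.checkpoint import checkpoint

    class TinyLM(torch.nn.Module):
        # Hypothetical stand-in for the Chimera model; only the
        # checkpointing mechanics matter here.
        def __init__(self, d=64, n_layers=4, vocab=1000):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab, d)
            self.layers = torch.nn.ModuleList(
                torch.nn.TransformerEncoderLayer(d, nhead=4, batch_first=True)
                for _ in range(n_layers))
            self.head = torch.nn.Linear(d, vocab)
            self._ckpt = False

        def enable_gradient_checkpointing(self):
            self._ckpt = True

        def forward(self, ids):
            x = self.embed(ids)
            for layer in self.layers:
                if self._ckpt and self.training:
                    # Recompute this block's activations during backward
                    # instead of storing them; use_reentrant=False is the
                    # recommended non-reentrant mode in current PyTorch.
                    x = checkpoint(layer, x, use_reentrant=False)
                else:
                    x = layer(x)
            return self.head(x)

This is the trade the commit message cites: each checkpointed block runs its
forward twice (once normally, once during backward), buying roughly 60% less
activation memory for roughly 30% more compute.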