from __future__ import annotations
import json
import math
import os
import sys
import time
import torch
import chimera_turbo
from .common import save_final_checkpoint, save_training_checkpoint
from .hyper import ProgressiveLoopScheduler
def _safe_batch(desired_batch: int, seq_len: int, vocab_size: int,
max_logits_gb: float = 2.0) -> int:
"""Cap batch size so the logits tensor fits in memory.
Logits shape: [batch, seq, vocab] at fp32 = batch * seq * vocab * 4 bytes.
With vocab=200073, batch=256, seq=16: 3.28 GB just for logits.
Backward doubles this. Must stay well under 32 GB total.
"""
bytes_per_sample = seq_len * vocab_size * 4 # fp32 logits
max_bytes = int(max_logits_gb * 1024**3)
max_batch = max(1, max_bytes // bytes_per_sample)
capped = min(desired_batch, max_batch)
if capped < desired_batch:
print(f" [MEM] Batch {desired_batch} β {capped} (logits would be "
f"{desired_batch * seq_len * vocab_size * 4 / 1e9:.1f} GB, cap={max_logits_gb} GB)")
sys.stdout.flush()
return capped
def train_fast_loop(args, model, config, loader, compute_loss) -> str:
    """Run a minimal AdamW training loop and write a final checkpoint.

    Batches are drawn from ``loader``, which is restarted whenever it is
    exhausted, until ``args.max_steps`` optimizer steps have been taken.
    Every ``args.log_every`` steps the mean loss, perplexity and token
    throughput of the window just finished are printed and the window
    counters are reset.

    Args:
        args: Namespace read for ``lr``, ``output_dir``, ``max_steps`` and
            ``log_every``.
        model: Module whose parameters are optimized.
        config: Passed through unchanged to ``save_final_checkpoint``.
        loader: Re-iterable source of batches; each batch is indexed with
            ``"input_ids"`` to count tokens.
        compute_loss: Callable taking a batch and returning a loss that
            supports ``.backward()`` and ``.item()``.

    Returns:
        The path ``<args.output_dir>/final`` handed to
        ``save_final_checkpoint``.
    """
    final_dir = os.path.join(args.output_dir, "final")
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98))
    os.makedirs(args.output_dir, exist_ok=True)
    model.train()

    step = 0
    best_loss = float("inf")
    window_loss = 0.0
    window_tokens = 0
    window_start = time.time()

    exhausted = object()  # sentinel marking the end of one pass over loader
    batches = iter(loader)

    while step < args.max_steps:
        batch = next(batches, exhausted)
        if batch is exhausted:
            # Start another pass over the data.
            batches = iter(loader)
            batch = next(batches)

        loss = compute_loss(batch)
        loss.backward()
        window_loss += float(loss.item())
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        window_tokens += batch["input_ids"].numel()
        step += 1

        if step % args.log_every != 0:
            continue

        elapsed = time.time() - window_start
        avg = window_loss / args.log_every
        ppl = math.exp(min(avg, 20))  # clamp the exponent to avoid overflow
        tps = window_tokens / elapsed if elapsed > 0 else 0
        print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} | {tps:.0f} tok/s")
        best_loss = min(best_loss, avg)
        window_loss, window_tokens, window_start = 0.0, 0, time.time()

    save_final_checkpoint(model, config, step, best_loss, final_dir)
    return final_dir
def train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo):
    """Placeholder for the standard training loop; currently a no-op.

    Every argument is accepted and ignored, and the function returns
    ``None``.
    """
    return None
def _build_hyper_loader(args, dataset, cur_seq, vocab_size):
    """Build a shuffled DataLoader with a memory-capped batch size for ``cur_seq``.

    The desired batch is ``args.batch_size`` scaled by how many times
    ``cur_seq`` fits into ``args.seq_len``, then capped by ``_safe_batch``.

    Returns:
        A ``(loader, eff_batch)`` tuple.
    """
    desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
    eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
    return loader, eff_batch


def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
    """Train ``model`` through ``chimera_turbo`` and save checkpoints and logs.

    The model is wrapped by ``chimera_turbo.apply`` and then stepped with
    ``chimera_turbo.training_step`` until ``args.max_steps`` is reached.
    Metrics are appended to ``<args.output_dir>/log_hyper.jsonl`` and printed
    for each of the first 5 steps and then every ``args.log_every`` steps. A
    training checkpoint is written every ``args.save_every`` steps and a
    final checkpoint at the end.

    Args:
        args: Namespace read for ``max_steps``, ``bf16``, ``output_dir``,
            ``batch_size``, ``seq_len``, ``log_every``, ``save_every`` and
            the optional ``compile`` flag.
        model: Model handed to ``chimera_turbo.apply``.
        config: Mapping queried with ``.get("vocab_size", 200073)`` and
            passed to the checkpoint helpers.
        dataset: Dataset given to ``torch.utils.data.DataLoader``; its
            ``set_seq_len`` is called when ``grow`` changes the sequence
            length.
        initial_seq: Sequence length used to size the first loader.
        grow: Optional object with ``get_seq_len(step)``; falsy disables
            sequence-length changes.
        unfreezer: Optional object with ``update(step)``; falsy disables it.

    Returns:
        Whatever ``save_final_checkpoint`` returns for the final checkpoint.
    """
    use_compile = getattr(args, "compile", False)
    vocab_size = int(config.get("vocab_size", 200073))

    # -- Muon LR for ternary BitLinear --
    muon_lr = 0.012
    muon_warmup = 30
    model, optimizer, scheduler, extras = chimera_turbo.apply(
        model,
        max_steps=args.max_steps,
        lr=muon_lr,
        weight_decay=0.02,
        warmup_steps=muon_warmup,
        use_compile=use_compile,
        mtp_heads=0,
        llrd_decay=0.90,
        grokfast_alpha=0.95,
        grokfast_lambda=1.5,
    )
    model.train()

    # -- Gradient checkpointing: saves ~60% activation memory --
    # Presumably unwraps a compiled-module wrapper that keeps the original
    # module under `_orig_mod`; falls back to the model itself.
    raw_model = getattr(model, "_orig_mod", model)
    if hasattr(raw_model, "enable_gradient_checkpointing"):
        raw_model.enable_gradient_checkpointing()
        print("[OPT] Gradient checkpointing: ON")

    # -- Looping: force loops=1 --
    cur_loops = 1
    if hasattr(raw_model, "loop_controller"):
        raw_model.loop_controller.loop_default = 1
        raw_model.loop_controller.loop_min = 1
        raw_model.loop_controller.loop_max = 1

    use_bf16 = bool(args.bf16)
    os.makedirs(args.output_dir, exist_ok=True)
    log_path = os.path.join(args.output_dir, "log_hyper.jsonl")

    step, total_loss, valid_count, best_loss, toks = 0, 0.0, 0, float("inf"), 0
    t0 = time.time()
    t_start = t0
    cur_seq = initial_seq

    # -- Memory-safe batch size --
    loader, eff_batch = _build_hyper_loader(args, dataset, cur_seq, vocab_size)
    data_iter = iter(loader)

    print(f"\n{'=' * 65}")
    print(f"Training batch={eff_batch} seq={cur_seq} loops={cur_loops}")
    print("Starting first step (may take 30-60s on CPU with 227M params)...")
    print(f"{'=' * 65}", flush=True)

    # The context manager guarantees the log file is closed even when a
    # training step or a checkpoint save raises.
    with open(log_path, "w", encoding="utf-8") as log_f:
        while step < args.max_steps:
            if grow:
                ns = grow.get_seq_len(step)
                if ns != cur_seq:
                    # Sequence length changed: resize the batch and rebuild
                    # the loader so batches match the new length.
                    cur_seq = ns
                    dataset.set_seq_len(cur_seq)
                    loader, eff_batch = _build_hyper_loader(
                        args, dataset, cur_seq, vocab_size)
                    data_iter = iter(loader)
                    print(f" [P1] seq -> {cur_seq} batch -> {eff_batch}", flush=True)

            if unfreezer:
                unfreezer.update(step)

            try:
                batch = next(data_iter)
            except StopIteration:
                # Loader exhausted: start another pass over the dataset.
                data_iter = iter(loader)
                batch = next(data_iter)

            step_t0 = time.time()
            loss_val = chimera_turbo.training_step(
                model, batch, optimizer, scheduler,
                extras=extras, grad_accum_steps=1, step=step,
                autocast_dtype=torch.bfloat16 if use_bf16 else None,
            )
            step_dt = time.time() - step_t0

            group0 = optimizer.param_groups[0]
            cur_lr = group0["lr"] * group0.get("lr_scale", 1.0)

            # Non-finite losses are excluded from the running average.
            if math.isfinite(loss_val):
                total_loss += loss_val
                valid_count += 1
            step_toks = batch["input_ids"].numel()
            toks += step_toks
            step += 1

            # Print every step for the first 5 steps, then every log_every.
            should_log = (step <= 5) or (step % args.log_every == 0)

            if step == 1:
                step_tps = step_toks / step_dt if step_dt > 0 else 0
                # NOTE(review): the leading glyph in this message looks
                # mis-encoded; confirm the intended character before changing.
                print(f" β Step 1 completed in {step_dt:.1f}s "
                      f"({step_tps:.0f} tok/s, loss={loss_val:.4f})", flush=True)

            if should_log:
                dt = time.time() - t0
                if valid_count > 0:
                    avg = total_loss / valid_count
                    ppl = math.exp(min(avg, 20)) if math.isfinite(avg) else float("nan")
                else:
                    avg = float("nan")
                    ppl = float("nan")
                tps = toks / dt if dt > 0 else 0
                elapsed = time.time() - t_start
                eta_s = (args.max_steps - step) * (elapsed / max(1, step))

                log_f.write(json.dumps({
                    "step": step, "loss": round(avg, 4) if math.isfinite(avg) else None,
                    "ppl": round(ppl, 2) if math.isfinite(ppl) else None,
                    "lr": round(cur_lr, 6), "tok/s": round(tps),
                    "seq": cur_seq, "loops": cur_loops,
                    "step_time": round(step_dt, 2),
                }) + "\n")
                log_f.flush()
                print(
                    f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} "
                    f"| {tps:,.0f} tok/s | {step_dt:.1f}s/step | seq {cur_seq} "
                    f"| ETA {eta_s / 60:.0f}m",
                    flush=True,
                )

                if step > 5:
                    # Reset counters for clean averages. The first 5 steps
                    # log a cumulative average and do not update best_loss.
                    best_loss = min(best_loss, avg) if math.isfinite(avg) else best_loss
                    total_loss, valid_count, toks, t0 = 0.0, 0, 0, time.time()

            if step % args.save_every == 0:
                ckpt_dir = save_training_checkpoint(
                    model, config, step,
                    os.path.join(args.output_dir, f"ckpt-{step}"))
                print(f" [SAVE] {ckpt_dir}", flush=True)

    final_dir = save_final_checkpoint(model, config, step, best_loss,
                                      os.path.join(args.output_dir, "final"))
    total_time = time.time() - t_start
    print(f"\nDONE -- best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}"
          f" total time {total_time / 60:.1f}m", flush=True)
    return final_dir
|