fix: print every step + first-step timing to diagnose slow forward
With 227M params on CPU, each step can take 30-60s. With log_every=10,
the first output wouldn't appear for 5-10 minutes, which looks like a hang.

Changes:
- Print step-1 timing immediately after the first forward+backward
- Log every step for the first 5 steps, then every log_every steps
- Call sys.stdout.flush() after every print so output appears immediately
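
The cadence change in isolation, as a minimal standalone sketch (the 5-step warm-up window, log_every, and the flush-after-print rule mirror this commit; run_step is a hypothetical stand-in for the real forward+backward, not code from the repo):

    import sys
    import time

    def run_step(step: int) -> float:
        """Hypothetical stand-in for one forward+backward pass."""
        time.sleep(0.01)
        return 10.0 / (step + 1)

    def train(max_steps: int = 100, log_every: int = 10) -> None:
        for step in range(1, max_steps + 1):
            t0 = time.time()
            loss = run_step(step)
            dt = time.time() - t0
            if step == 1:
                # Immediate proof the loop is alive rather than hung.
                print(f"step 1 done in {dt:.1f}s (loss={loss:.4f})")
            if step <= 5 or step % log_every == 0:
                print(f"step {step} | loss {loss:.4f} | {dt:.1f}s/step")
            sys.stdout.flush()  # matters when stdout is piped or block-buffered

    train()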
- chimera/training/loops.py +41 -12
--- a/chimera/training/loops.py
+++ b/chimera/training/loops.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import json
 import math
 import os
+import sys
 import time
 
 import torch
@@ -28,6 +29,7 @@ def _safe_batch(desired_batch: int, seq_len: int, vocab_size: int,
     if capped < desired_batch:
         print(f"  [MEM] Batch {desired_batch} → {capped} (logits would be "
               f"{desired_batch * seq_len * vocab_size * 4 / 1e9:.1f} GB, cap={max_logits_gb} GB)")
+        sys.stdout.flush()
     return capped
 
 
@@ -91,14 +93,12 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     model.train()
 
     # ── Gradient checkpointing: saves ~60% activation memory ──
-    # Critical with vocab=200K: without it, activations across 28 layers
-    # at batch=32 can consume several GB.
     raw_model = getattr(model, "_orig_mod", model)
     if hasattr(raw_model, "enable_gradient_checkpointing"):
         raw_model.enable_gradient_checkpointing()
         print(f"[OPT] Gradient checkpointing: ON")
 
-    # ── Looping: force loops=1
+    # ── Looping: force loops=1 ──
     cur_loops = 1
     if hasattr(raw_model, "loop_controller"):
         raw_model.loop_controller.loop_default = 1
@@ -111,9 +111,10 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     log_f = open(os.path.join(args.output_dir, "log_hyper.jsonl"), "w")
     step, total_loss, valid_count, best_loss, toks = 0, 0.0, 0, float("inf"), 0
     t0 = time.time()
+    t_start = t0
     cur_seq = initial_seq
 
-    # ── …
+    # ── Memory-safe batch size ──
    desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
     eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
 
@@ -123,7 +124,9 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
 
     print(f"\n{'=' * 65}")
     print(f"Training batch={eff_batch} seq={cur_seq} loops={cur_loops}")
-    print(f"…
+    print(f"Starting first step (may take 30-60s on CPU with 227M params)...")
+    print(f"{'=' * 65}")
+    sys.stdout.flush()
 
     while step < args.max_steps:
         if grow:
@@ -137,6 +140,7 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
                 dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
             data_iter = iter(loader)
             print(f"  [P1] seq -> {cur_seq} batch -> {eff_batch}")
+            sys.stdout.flush()
 
         if unfreezer:
             unfreezer.update(step)
@@ -147,20 +151,34 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
             data_iter = iter(loader)
             batch = next(data_iter)
 
+        step_t0 = time.time()
+
         loss_val = chimera_turbo.training_step(
             model, batch, optimizer, scheduler,
             extras=extras, grad_accum_steps=1, step=step,
             autocast_dtype=torch.bfloat16 if use_bf16 else None,
         )
 
+        step_dt = time.time() - step_t0
+
         cur_lr = optimizer.param_groups[0]["lr"] * optimizer.param_groups[0].get("lr_scale", 1.0)
+        step_toks = batch["input_ids"].numel()
         if math.isfinite(loss_val):
             total_loss += loss_val
             valid_count += 1
-            …
+            toks += step_toks
         step += 1
 
-        if step % args.log_every == 0:
+        # Print every step for the first 5 steps, then every log_every
+        should_log = (step <= 5) or (step % args.log_every == 0)
+
+        if step == 1:
+            step_tps = step_toks / step_dt if step_dt > 0 else 0
+            print(f"  ✓ Step 1 completed in {step_dt:.1f}s "
+                  f"({step_tps:.0f} tok/s, loss={loss_val:.4f})")
+            sys.stdout.flush()
+
+        if should_log:
             dt = time.time() - t0
             if valid_count > 0:
                 avg = total_loss / valid_count
@@ -169,28 +187,39 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
                 avg = float("nan")
                 ppl = float("nan")
             tps = toks / dt if dt > 0 else 0
-            …
+            elapsed = time.time() - t_start
+            eta_s = (args.max_steps - step) * (elapsed / max(1, step))
             log_f.write(json.dumps({
                 "step": step, "loss": round(avg, 4) if math.isfinite(avg) else None,
                 "ppl": round(ppl, 2) if math.isfinite(ppl) else None,
                 "lr": round(cur_lr, 6), "tok/s": round(tps),
                 "seq": cur_seq, "loops": cur_loops,
+                "step_time": round(step_dt, 2),
             }) + "\n")
             log_f.flush()
             print(
                 f"  step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} "
-                f"| …
+                f"| {tps:,.0f} tok/s | {step_dt:.1f}s/step | seq {cur_seq} "
+                f"| ETA {eta_s / 60:.0f}m"
             )
-            best_loss = min(best_loss, avg) if math.isfinite(avg) else best_loss
-            total_loss, valid_count, toks, t0 = 0.0, 0, 0, time.time()
+            sys.stdout.flush()
+
+            if step > 5:
+                # Reset counters for clean averages
+                best_loss = min(best_loss, avg) if math.isfinite(avg) else best_loss
+                total_loss, valid_count, toks, t0 = 0.0, 0, 0, time.time()
 
         if step % args.save_every == 0:
             d = save_training_checkpoint(model, config, step,
                                          os.path.join(args.output_dir, f"ckpt-{step}"))
             print(f"  [SAVE] {d}")
+            sys.stdout.flush()
 
     d = save_final_checkpoint(model, config, step, best_loss,
                               os.path.join(args.output_dir, "final"))
     log_f.close()
-    …
+    total_time = time.time() - t_start
+    print(f"\nDONE -- best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}"
+          f" total time {total_time / 60:.1f}m")
+    sys.stdout.flush()
     return d
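
A footnote on the _safe_batch context lines above: the cap follows from the printed formula, where one sample's fp32 logits take seq_len × vocab_size × 4 bytes. A minimal sketch of that arithmetic (the full function body isn't shown in the diff, so the floor rounding and the 4 GB default here are assumptions, not the repo's exact code):

    def safe_batch(desired_batch: int, seq_len: int, vocab_size: int,
                   max_logits_gb: float = 4.0) -> int:
        """Cap the batch so the fp32 logits tensor stays under max_logits_gb."""
        bytes_per_sample = seq_len * vocab_size * 4   # fp32 logits for one sample
        cap = max(1, int(max_logits_gb * 1e9 // bytes_per_sample))
        return min(desired_batch, cap)

    # With vocab=200K (per the removed comment) and seq_len=512, one sample's
    # logits are ~0.41 GB, so a 4 GB cap allows at most batch=9:
    print(safe_batch(32, 512, 200_000))  # -> 9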