Lgr54HFi committed
Commit 5b5a08d · verified · 1 Parent(s): 995be31

fix: print every step + first-step timing to diagnose slow forward


With 227M params on CPU, each step can take 30-60s. With log_every=10,
the first output wouldn't appear for 5-10 minutes, which looks like a hang.

Changes:
- Print step 1 timing immediately after first forward+backward
- Log every step for the first 5 steps, then every log_every
- Flush stdout after every print to ensure immediate display
- Add sys.stdout.flush() calls (see the buffering sketch below)
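For context on why the explicit flushes matter, here is a minimal sketch of the buffering behavior being worked around (the sleep is a stand-in for one slow CPU forward+backward, not code from this repo): CPython line-buffers stdout only when it is attached to a terminal; when output is piped or redirected it is block-buffered, so an un-flushed print can sit invisible in the buffer while a slow step runs.

    import sys
    import time

    print("step 1: starting forward pass")  # may linger in the block buffer when piped
    sys.stdout.flush()                      # force it out before the slow step begins
    time.sleep(45)                          # stand-in for a 30-60s forward+backward
    print("step 1 done", flush=True)        # flush=True is equivalent to the explicit call

Running the script with python -u (or PYTHONUNBUFFERED=1) would achieve the same effect without sprinkling explicit flushes through the loop.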

Files changed (1)
chimera/training/loops.py  +41 -12
chimera/training/loops.py CHANGED

@@ -3,6 +3,7 @@ from __future__ import annotations
 import json
 import math
 import os
+import sys
 import time
 
 import torch
@@ -28,6 +29,7 @@ def _safe_batch(desired_batch: int, seq_len: int, vocab_size: int,
     if capped < desired_batch:
         print(f"  [MEM] Batch {desired_batch} → {capped} (logits would be "
               f"{desired_batch * seq_len * vocab_size * 4 / 1e9:.1f} GB, cap={max_logits_gb} GB)")
+        sys.stdout.flush()
     return capped
 
 
@@ -91,14 +93,12 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     model.train()
 
     # ── Gradient checkpointing: saves ~60% activation memory ──
-    # Critical with vocab=200K: without it, activations across 28 layers
-    # at batch=32 can consume several GB.
     raw_model = getattr(model, "_orig_mod", model)
     if hasattr(raw_model, "enable_gradient_checkpointing"):
         raw_model.enable_gradient_checkpointing()
         print(f"[OPT] Gradient checkpointing: ON")
 
-    # ── Looping: force loops=1 for all 300 steps ──
+    # ── Looping: force loops=1 ──
     cur_loops = 1
     if hasattr(raw_model, "loop_controller"):
         raw_model.loop_controller.loop_default = 1
@@ -111,9 +111,10 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     log_f = open(os.path.join(args.output_dir, "log_hyper.jsonl"), "w")
     step, total_loss, valid_count, best_loss, toks = 0, 0.0, 0, float("inf"), 0
     t0 = time.time()
+    t_start = t0
     cur_seq = initial_seq
 
-    # ── Compute memory-safe batch size ──
+    # ── Memory-safe batch size ──
     desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
     eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
 
@@ -123,7 +124,9 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
 
     print(f"\n{'=' * 65}")
     print(f"Training batch={eff_batch} seq={cur_seq} loops={cur_loops}")
-    print(f"{'=' * 65}\n")
+    print(f"Starting first step (may take 30-60s on CPU with 227M params)...")
+    print(f"{'=' * 65}")
+    sys.stdout.flush()
 
     while step < args.max_steps:
         if grow:
@@ -137,6 +140,7 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
                 dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
             data_iter = iter(loader)
             print(f"  [P1] seq -> {cur_seq} batch -> {eff_batch}")
+            sys.stdout.flush()
 
         if unfreezer:
             unfreezer.update(step)
@@ -147,20 +151,34 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
             data_iter = iter(loader)
             batch = next(data_iter)
 
+        step_t0 = time.time()
+
         loss_val = chimera_turbo.training_step(
             model, batch, optimizer, scheduler,
             extras=extras, grad_accum_steps=1, step=step,
            autocast_dtype=torch.bfloat16 if use_bf16 else None,
        )
 
+        step_dt = time.time() - step_t0
+
         cur_lr = optimizer.param_groups[0]["lr"] * optimizer.param_groups[0].get("lr_scale", 1.0)
         if math.isfinite(loss_val):
             total_loss += loss_val
             valid_count += 1
-        toks += batch["input_ids"].numel()
+        step_toks = batch["input_ids"].numel()
+        toks += step_toks
         step += 1
 
-        if step % args.log_every == 0:
+        # Print every step for the first 5 steps, then every log_every
+        should_log = (step <= 5) or (step % args.log_every == 0)
+
+        if step == 1:
+            step_tps = step_toks / step_dt if step_dt > 0 else 0
+            print(f"  ✓ Step 1 completed in {step_dt:.1f}s "
+                  f"({step_tps:.0f} tok/s, loss={loss_val:.4f})")
+            sys.stdout.flush()
+
+        if should_log:
             dt = time.time() - t0
             if valid_count > 0:
                 avg = total_loss / valid_count
@@ -169,28 +187,39 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
                 avg = float("nan")
                 ppl = float("nan")
             tps = toks / dt if dt > 0 else 0
-            eta = (args.max_steps - step) / max(1, step) * (time.time() - t0) / 3600 if step > 0 else 0
+            elapsed = time.time() - t_start
+            eta_s = (args.max_steps - step) * (elapsed / max(1, step))
             log_f.write(json.dumps({
                 "step": step, "loss": round(avg, 4) if math.isfinite(avg) else None,
                 "ppl": round(ppl, 2) if math.isfinite(ppl) else None,
                 "lr": round(cur_lr, 6), "tok/s": round(tps),
                 "seq": cur_seq, "loops": cur_loops,
+                "step_time": round(step_dt, 2),
             }) + "\n")
             log_f.flush()
             print(
                 f"  step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} "
-                f"| lr {cur_lr:.2e} | {tps:,.0f} tok/s | seq {cur_seq} | L{cur_loops} | ETA {eta:.1f}h"
+                f"| {tps:,.0f} tok/s | {step_dt:.1f}s/step | seq {cur_seq} "
+                f"| ETA {eta_s / 60:.0f}m"
             )
-            best_loss = min(best_loss, avg) if math.isfinite(avg) else best_loss
-            total_loss, valid_count, toks, t0 = 0.0, 0, 0, time.time()
+            sys.stdout.flush()
+
+            if step > 5:
+                # Reset counters for clean averages
+                best_loss = min(best_loss, avg) if math.isfinite(avg) else best_loss
+                total_loss, valid_count, toks, t0 = 0.0, 0, 0, time.time()
 
         if step % args.save_every == 0:
             d = save_training_checkpoint(model, config, step,
                                          os.path.join(args.output_dir, f"ckpt-{step}"))
             print(f"  [SAVE] {d}")
+            sys.stdout.flush()
 
     d = save_final_checkpoint(model, config, step, best_loss,
                               os.path.join(args.output_dir, "final"))
     log_f.close()
-    print(f"\nDONE -- best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}")
+    total_time = time.time() - t_start
+    print(f"\nDONE -- best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}"
+          f" total time {total_time / 60:.1f}m")
+    sys.stdout.flush()
     return d
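As a rough sanity check on the new ETA line (numbers assumed, not measured: the 30-60s/step estimate from the commit message and the 300-step run mentioned in the removed comment):

    # Hypothetical back-of-envelope check of the new ETA formula
    max_steps, step, elapsed = 300, 1, 45.0          # ~45s for the first step
    eta_s = (max_steps - step) * (elapsed / max(1, step))
    print(f"ETA {eta_s / 60:.0f}m")                  # -> "ETA 224m" (about 3.7h)

So even at step 1 the log now prints a usable estimate, instead of showing nothing until step 10 as the old log_every-gated ETA did.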