Lgr54HFi committed · Commit 31d69ba · verified · 1 Parent(s): 20ad65d

fix: train_hyper_loop grad_accum=1 (DataLoader already batches), better tok/s logging

Files changed (1): chimera/training/loops.py (+43, -11)
chimera/training/loops.py CHANGED

@@ -150,10 +150,10 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     model, optimizer, scheduler = chimera_turbo.apply(
         model,
         max_steps=args.max_steps,
-        lr=1e-3,
+        lr=args.lr,
         weight_decay=0.05,
         warmup_steps=min(500, args.max_steps // 10),
-        use_compile=True,
+        use_compile=False,  # ← disabled: 84 graph breaks from STE
         use_ipex=True,
     )
     model.train()
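
Note on use_compile: the "84 graph breaks from STE" comment points at the straight-through estimator. As a rough illustration only (this is not Chimera's actual STE code), an STE is typically a custom torch.autograd.Function, which torch.compile has historically treated as a graph break on older PyTorch releases:

    import torch

    class RoundSTE(torch.autograd.Function):
        """Round in the forward pass; pass the gradient straight through."""
        @staticmethod
        def forward(ctx, x):
            return torch.round(x)

        @staticmethod
        def backward(ctx, grad_out):
            # Identity gradient: treat round() as a no-op for backprop.
            return grad_out

    x = torch.randn(4, requires_grad=True)
    y = RoundSTE.apply(x)   # call sites like this can each split the FX graph
    y.sum().backward()
    print(x.grad)           # all ones: the straight-through gradient

With many such call sites in the model, torch.compile falls back to eager around each one, which is consistent with the "84 graph breaks" note; disabling compilation avoids paying compile and recompilation overhead for little fused-graph benefit.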
@@ -169,7 +169,9 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     t0 = time.time()
     cur_seq = initial_seq
     eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
-    loader = torch.utils.data.DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
+    loader = torch.utils.data.DataLoader(
+        dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True
+    )
     data_iter = iter(loader)

     print(f"\n{'=' * 65}")
@@ -183,9 +185,11 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
             cur_seq = ns
             dataset.set_seq_len(cur_seq)
             eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
-            loader = torch.utils.data.DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
+            loader = torch.utils.data.DataLoader(
+                dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True
+            )
             data_iter = iter(loader)
-            print(f" [P1] seq -> {cur_seq} batch -> {eff_batch}")
+            print(f" [P1] seq {cur_seq} batch {eff_batch}")
         if unfreezer:
             unfreezer.update(step)
         try:
@@ -193,32 +197,60 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
         except StopIteration:
             data_iter = iter(loader)
             batch = next(data_iter)
-        grad_accum_steps = max(1, eff_batch // max(1, args.batch_size))
+
+        # grad_accum_steps=1: DataLoader already provides eff_batch items.
+        # The effective batch IS eff_batch. No need to accumulate further.
         loss_val = chimera_turbo.training_step(
-            model, batch, optimizer, scheduler, grad_accum_steps=grad_accum_steps, step=step, autocast_dtype=torch.bfloat16 if use_bf16 else None
+            model,
+            batch,
+            optimizer,
+            scheduler,
+            grad_accum_steps=1,
+            step=step,
+            autocast_dtype=torch.bfloat16 if use_bf16 else None,
         )
+
         cur_lr = optimizer.param_groups[0]["lr"]
         total_loss += loss_val
         toks += batch["input_ids"].numel()
         step += 1
+
         if step % args.log_every == 0:
             dt = time.time() - t0
             avg = total_loss / args.log_every
             ppl = math.exp(min(avg, 20))
             tps = toks / dt if dt > 0 else 0
             eta = (args.max_steps - step) / (step / dt) / 3600 if dt > 0 else 0
-            log_f.write(json.dumps({"step": step, "loss": round(avg, 4), "ppl": round(ppl, 2), "lr": cur_lr, "tok/s": round(tps), "seq_len": cur_seq, "eff_batch": eff_batch}) + "\n")
+            log_f.write(
+                json.dumps({
+                    "step": step,
+                    "loss": round(avg, 4),
+                    "ppl": round(ppl, 2),
+                    "lr": round(cur_lr, 6),
+                    "tok/s": round(tps),
+                    "seq_len": cur_seq,
+                    "eff_batch": eff_batch,
+                }) + "\n"
+            )
             log_f.flush()
-            print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} | {tps:,.0f} tok/s | seq {cur_seq} | ETA {eta:.1f}h")
+            print(
+                f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} "
+                f"| lr {cur_lr:.2e} | {tps:,.0f} tok/s | seq {cur_seq} | ETA {eta:.1f}h"
+            )
             best_loss = min(best_loss, avg)
             total_loss = 0.0
             toks = 0
             t0 = time.time()
+
         if step % args.save_every == 0:
-            d = save_training_checkpoint(model, config, step, os.path.join(args.output_dir, f"ckpt-{step}"))
+            d = save_training_checkpoint(
+                model, config, step, os.path.join(args.output_dir, f"ckpt-{step}")
+            )
             print(f" [SAVE] {d}")

-    d = save_final_checkpoint(model, config, step, best_loss, os.path.join(args.output_dir, "final"))
+    d = save_final_checkpoint(
+        model, config, step, best_loss, os.path.join(args.output_dir, "final")
+    )
     log_f.close()
     print(f"\nDONE — best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}")
     return d
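
Note on grad_accum_steps: assuming training_step has conventional accumulation semantics (scale the loss by 1/grad_accum_steps and call optimizer.step() only every grad_accum_steps micro-batches; chimera_turbo.training_step's real implementation may differ), a minimal sketch:

    def training_step(model, batch, optimizer, scheduler, grad_accum_steps, step):
        # Sketch of conventional accumulation semantics, not the chimera_turbo code.
        loss = model(**batch).loss / grad_accum_steps
        loss.backward()
        if (step + 1) % grad_accum_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        return loss.item() * grad_accum_steps

Under those semantics, the old grad_accum_steps = eff_batch // args.batch_size stepped the optimizer only once every N already-full batches, so the true effective batch was N * eff_batch and the schedule advanced N times slower than intended. Since the DataLoader already yields eff_batch samples per batch, 1 is the right value.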
 
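
Note on the tok/s logging: toks and t0 reset every logging interval, so tps = toks / dt is a per-window throughput, which is the right thing to report after a sequence-length change. One thing worth double-checking in a follow-up: the ETA term divides the cumulative step by the per-window dt, which overstates the step rate once past the first window. A per-window variant (a sketch; the names mirror the loop's locals):

    import time

    def window_stats(toks, t0, step, max_steps, log_every):
        """Per-window throughput (tokens/s) and ETA in hours."""
        dt = time.time() - t0
        if dt <= 0:
            return 0.0, 0.0
        tps = toks / dt                # tokens/s over this window (as committed)
        steps_per_s = log_every / dt   # window step rate; the commit uses step / dt
        eta_h = (max_steps - step) / steps_per_s / 3600
        return tps, eta_h

For example, at step 5000 of 50000 with log_every=100 and a 20 s window, step / dt gives 250 steps/s while the window rate is 5 steps/s, so the printed ETA comes out 50x too small.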