fix: train_hyper_loop grad_accum=1 (DataLoader already batches), better tok/s logging
chimera/training/loops.py  CHANGED  (+43 −11)
@@ -150,10 +150,10 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     model, optimizer, scheduler = chimera_turbo.apply(
         model,
         max_steps=args.max_steps,
-        lr=
+        lr=args.lr,
         weight_decay=0.05,
         warmup_steps=min(500, args.max_steps // 10),
-        use_compile=
+        use_compile=False,  # ← disabled: 84 graph breaks from STE
         use_ipex=True,
     )
     model.train()
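Note on the use_compile change: a straight-through estimator is usually written either as a detach trick or as a custom torch.autograd.Function, and depending on the PyTorch version torch.compile may have to graph-break around such code, which is consistent with the 84 breaks mentioned in the comment. A minimal, purely illustrative sketch of the pattern (RoundSTE and quantize are not names from this repo):

import torch

class RoundSTE(torch.autograd.Function):
    # Forward rounds; backward pretends rounding was the identity.
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def quantize(x: torch.Tensor) -> torch.Tensor:
    # Each call like this is a spot where torch.compile may fall back to eager,
    # so compiling a model full of them can cost more than it saves.
    return RoundSTE.apply(x)

With dozens of such fallbacks per forward pass, eager mode plus IPEX (which the commit keeps enabled) is the simpler win.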
@@ -169,7 +169,9 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
     t0 = time.time()
     cur_seq = initial_seq
     eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
-    loader = torch.utils.data.DataLoader(
+    loader = torch.utils.data.DataLoader(
+        dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True
+    )
     data_iter = iter(loader)

     print(f"\n{'=' * 65}")
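The eff_batch formula keeps the token count per optimizer step roughly constant while the sequence length ramps up: shorter sequences get proportionally larger batches. A quick check with made-up numbers (batch_size=8, target seq_len=2048):

def effective_batch(batch_size: int, target_seq_len: int, cur_seq: int) -> int:
    # Same formula as the loop above.
    return batch_size * max(1, target_seq_len // max(1, cur_seq))

for cur in (256, 512, 1024, 2048):
    eb = effective_batch(8, 2048, cur)
    print(cur, eb, cur * eb)  # tokens per step stay at 16384 for every stage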
@@ -183,9 +185,11 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
             cur_seq = ns
             dataset.set_seq_len(cur_seq)
             eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
-            loader = torch.utils.data.DataLoader(
+            loader = torch.utils.data.DataLoader(
+                dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True
+            )
             data_iter = iter(loader)
+            print(f" [P1] seq → {cur_seq} batch → {eff_batch}")
         if unfreezer:
             unfreezer.update(step)
         try:
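The loader has to be rebuilt here because batch_size is fixed when the DataLoader is constructed, and set_seq_len changes both the chunking and the desired eff_batch. The dataset class itself is not part of this diff; a rough stand-in, assuming a packed token stream (PackedLMDataset is illustrative, not the real chimera dataset):

import torch
from torch.utils.data import Dataset

class PackedLMDataset(Dataset):
    def __init__(self, tokens: torch.Tensor, seq_len: int):
        self.tokens = tokens
        self.seq_len = seq_len

    def set_seq_len(self, seq_len: int) -> None:
        # Re-chunk the stream; any existing DataLoader (and its batch_size)
        # is stale after this, hence the rebuild in the loop above.
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.tokens.numel() // self.seq_len

    def __getitem__(self, i: int) -> dict:
        chunk = self.tokens[i * self.seq_len:(i + 1) * self.seq_len]
        return {"input_ids": chunk, "labels": chunk.clone()}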
@@ -193,32 +197,60 @@ def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
         except StopIteration:
             data_iter = iter(loader)
             batch = next(data_iter)
-
+
+        # grad_accum_steps=1: DataLoader already provides eff_batch items.
+        # The effective batch IS eff_batch. No need to accumulate further.
         loss_val = chimera_turbo.training_step(
-            model,
+            model,
+            batch,
+            optimizer,
+            scheduler,
+            grad_accum_steps=1,
+            step=step,
+            autocast_dtype=torch.bfloat16 if use_bf16 else None,
         )
+
         cur_lr = optimizer.param_groups[0]["lr"]
         total_loss += loss_val
         toks += batch["input_ids"].numel()
         step += 1
+
         if step % args.log_every == 0:
             dt = time.time() - t0
             avg = total_loss / args.log_every
             ppl = math.exp(min(avg, 20))
             tps = toks / dt if dt > 0 else 0
             eta = (args.max_steps - step) / (step / dt) / 3600 if dt > 0 else 0
-            log_f.write(
+            log_f.write(
+                json.dumps({
+                    "step": step,
+                    "loss": round(avg, 4),
+                    "ppl": round(ppl, 2),
+                    "lr": round(cur_lr, 6),
+                    "tok/s": round(tps),
+                    "seq_len": cur_seq,
+                    "eff_batch": eff_batch,
+                }) + "\n"
+            )
             log_f.flush()
-            print(
+            print(
+                f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} "
+                f"| lr {cur_lr:.2e} | {tps:,.0f} tok/s | seq {cur_seq} | ETA {eta:.1f}h"
+            )
             best_loss = min(best_loss, avg)
             total_loss = 0.0
             toks = 0
             t0 = time.time()
+
         if step % args.save_every == 0:
-            d = save_training_checkpoint(
+            d = save_training_checkpoint(
+                model, config, step, os.path.join(args.output_dir, f"ckpt-{step}")
+            )
             print(f" [SAVE] {d}")

-    d = save_final_checkpoint(
+    d = save_final_checkpoint(
+        model, config, step, best_loss, os.path.join(args.output_dir, "final")
+    )
     log_f.close()
     print(f"\nDONE — best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}")
     return d
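Why grad_accum_steps=1 is the fix: the loader already yields eff_batch sequences per iteration, so accumulating over N micro-batches on top of that would make the true batch N × eff_batch and stretch each optimizer update over N iterations. The sketch below shows the usual accumulation pattern so the interaction is visible; it is a generic illustration, not chimera_turbo.training_step's actual implementation (the model(**batch).loss convention and the cpu autocast device, implied by use_ipex, are assumptions):

import contextlib
import torch

def training_step(model, batch, optimizer, scheduler, grad_accum_steps, step,
                  autocast_dtype=None):
    # Generic grad-accumulation step: scale the loss, only update every N calls.
    amp = (torch.autocast(device_type="cpu", dtype=autocast_dtype)
           if autocast_dtype is not None else contextlib.nullcontext())
    with amp:
        loss = model(**batch).loss / grad_accum_steps
    loss.backward()
    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)
    return loss.item() * grad_accum_steps

With grad_accum_steps=1 every call both backprops and steps, which matches the one-batch-per-step accounting that the tok/s logging assumes.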
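Since each log line is a standalone JSON object, the training log can be pulled back in for plotting or a quick throughput check with a few lines (the file name here is an assumption; use whatever path log_f was opened on):

import json

records = []
with open("train_log.jsonl") as f:          # assumed path
    for line in f:
        records.append(json.loads(line))

tok_s = [r["tok/s"] for r in records]
print(f"median tok/s: {sorted(tok_s)[len(tok_s) // 2]:,}")
print(f"last loss:    {records[-1]['loss']}")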