TinmanLabSL commited on
Commit
2cc0940
·
verified ·
1 Parent(s): 68ac838

Fix trackio auth and buffer issues

Browse files
Files changed (1) hide show
  1. train.py +25 -14
train.py CHANGED
@@ -43,6 +43,13 @@ from smolomni.svd_init import initialize_mla_from_pretrained
43
 
44
  import trackio
45
 
 
 
 
 
 
 
 
46
 
47
  # ===== Stage 1: KL Distillation Dataset =====
48
  class TextDistillationDataset(IterableDataset):
@@ -163,12 +170,14 @@ def train_stage1(args, config):
163
  )
164
 
165
  if accelerator.is_main_process:
166
- trackio.init(
167
- project="SmolOmni-MLA",
168
- name="Stage1-KD",
169
- space_id="TinmanLabSL/SmolOmni-MLA-Tracking",
170
- config=vars(args),
171
- )
 
 
172
 
173
  set_seed(args.seed)
174
 
@@ -305,7 +314,7 @@ def train_stage1(args, config):
305
  print(f"Step {global_step}/{args.max_steps} | Loss: {avg_loss:.4f} | "
306
  f"LR: {scheduler.get_last_lr()[0]:.2e} | "
307
  f"Speed: {steps_per_sec:.1f} steps/s")
308
- trackio.log(metrics)
309
 
310
  total_loss = 0.0
311
 
@@ -345,12 +354,14 @@ def train_stage2(args, config):
345
  )
346
 
347
  if accelerator.is_main_process:
348
- trackio.init(
349
- project="SmolOmni-MLA",
350
- name="Stage2-Joint",
351
- space_id="TinmanLabSL/SmolOmni-MLA-Tracking",
352
- config=vars(args),
353
- )
 
 
354
 
355
  set_seed(args.seed)
356
 
@@ -469,7 +480,7 @@ def train_stage2(args, config):
469
  f"Loss: {total_loss/n:.4f} | "
470
  f"AR: {total_ar_loss/n:.4f} | "
471
  f"Flow: {total_flow_loss/n:.4f}")
472
- trackio.log(metrics)
473
  total_loss = total_ar_loss = total_flow_loss = 0.0
474
 
475
  if global_step % args.save_every == 0 and accelerator.is_main_process:
 
43
 
44
  import trackio
45
 
46
+ # Safe trackio wrapper
47
+ def safe_trackio_log(metrics):
48
+ try:
49
+ trackio.log(metrics)
50
+ except Exception:
51
+ pass
52
+
53
 
54
  # ===== Stage 1: KL Distillation Dataset =====
55
  class TextDistillationDataset(IterableDataset):
 
170
  )
171
 
172
  if accelerator.is_main_process:
173
+ try:
174
+ trackio.init(
175
+ project="SmolOmni-MLA",
176
+ name="Stage1-KD",
177
+ config=vars(args),
178
+ )
179
+ except Exception as e:
180
+ print(f"[WARN] Trackio init failed: {e}. Continuing without remote tracking.")
181
 
182
  set_seed(args.seed)
183
 
 
314
  print(f"Step {global_step}/{args.max_steps} | Loss: {avg_loss:.4f} | "
315
  f"LR: {scheduler.get_last_lr()[0]:.2e} | "
316
  f"Speed: {steps_per_sec:.1f} steps/s")
317
+ safe_trackio_log(metrics)
318
 
319
  total_loss = 0.0
320
 
 
354
  )
355
 
356
  if accelerator.is_main_process:
357
+ try:
358
+ trackio.init(
359
+ project="SmolOmni-MLA",
360
+ name="Stage2-Joint",
361
+ config=vars(args),
362
+ )
363
+ except Exception as e:
364
+ print(f"[WARN] Trackio init failed: {e}. Continuing without remote tracking.")
365
 
366
  set_seed(args.seed)
367
 
 
480
  f"Loss: {total_loss/n:.4f} | "
481
  f"AR: {total_ar_loss/n:.4f} | "
482
  f"Flow: {total_flow_loss/n:.4f}")
483
+ safe_trackio_log(metrics)
484
  total_loss = total_ar_loss = total_flow_loss = 0.0
485
 
486
  if global_step % args.save_every == 0 and accelerator.is_main_process: