Fix trackio auth and buffer issues
Browse files
train.py
CHANGED
|
@@ -43,6 +43,13 @@ from smolomni.svd_init import initialize_mla_from_pretrained
|
|
| 43 |
|
| 44 |
import trackio
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
# ===== Stage 1: KL Distillation Dataset =====
|
| 48 |
class TextDistillationDataset(IterableDataset):
|
|
@@ -163,12 +170,14 @@ def train_stage1(args, config):
|
|
| 163 |
)
|
| 164 |
|
| 165 |
if accelerator.is_main_process:
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
|
|
|
|
|
|
| 172 |
|
| 173 |
set_seed(args.seed)
|
| 174 |
|
|
@@ -305,7 +314,7 @@ def train_stage1(args, config):
|
|
| 305 |
print(f"Step {global_step}/{args.max_steps} | Loss: {avg_loss:.4f} | "
|
| 306 |
f"LR: {scheduler.get_last_lr()[0]:.2e} | "
|
| 307 |
f"Speed: {steps_per_sec:.1f} steps/s")
|
| 308 |
-
|
| 309 |
|
| 310 |
total_loss = 0.0
|
| 311 |
|
|
@@ -345,12 +354,14 @@ def train_stage2(args, config):
|
|
| 345 |
)
|
| 346 |
|
| 347 |
if accelerator.is_main_process:
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
|
|
|
|
|
|
| 354 |
|
| 355 |
set_seed(args.seed)
|
| 356 |
|
|
@@ -469,7 +480,7 @@ def train_stage2(args, config):
|
|
| 469 |
f"Loss: {total_loss/n:.4f} | "
|
| 470 |
f"AR: {total_ar_loss/n:.4f} | "
|
| 471 |
f"Flow: {total_flow_loss/n:.4f}")
|
| 472 |
-
|
| 473 |
total_loss = total_ar_loss = total_flow_loss = 0.0
|
| 474 |
|
| 475 |
if global_step % args.save_every == 0 and accelerator.is_main_process:
|
|
|
|
| 43 |
|
| 44 |
import trackio
|
| 45 |
|
| 46 |
+
# Safe trackio wrapper
|
| 47 |
+
def safe_trackio_log(metrics):
    """Send ``metrics`` to trackio on a best-effort basis.

    Any exception raised while calling ``trackio.log`` is caught, so a
    tracking failure never propagates to the caller. The first failure is
    reported once on stdout with a ``[WARN]`` prefix; later failures are
    suppressed so a persistent outage does not flood the output.

    Args:
        metrics: Passed unchanged to ``trackio.log``.

    Returns:
        None.
    """
    try:
        trackio.log(metrics)
    except Exception as e:
        # Deliberately broad: tracking is optional and must not interrupt
        # the caller. Warn only once, using a flag stored on the function
        # itself, instead of swallowing the error silently.
        if not getattr(safe_trackio_log, "_warned", False):
            safe_trackio_log._warned = True
            print(
                f"[WARN] Trackio log failed: {e}. "
                "Further trackio log failures will not be reported."
            )
|
| 52 |
+
|
| 53 |
|
| 54 |
# ===== Stage 1: KL Distillation Dataset =====
|
| 55 |
class TextDistillationDataset(IterableDataset):
|
|
|
|
| 170 |
)
|
| 171 |
|
| 172 |
if accelerator.is_main_process:
|
| 173 |
+
try:
|
| 174 |
+
trackio.init(
|
| 175 |
+
project="SmolOmni-MLA",
|
| 176 |
+
name="Stage1-KD",
|
| 177 |
+
config=vars(args),
|
| 178 |
+
)
|
| 179 |
+
except Exception as e:
|
| 180 |
+
print(f"[WARN] Trackio init failed: {e}. Continuing without remote tracking.")
|
| 181 |
|
| 182 |
set_seed(args.seed)
|
| 183 |
|
|
|
|
| 314 |
print(f"Step {global_step}/{args.max_steps} | Loss: {avg_loss:.4f} | "
|
| 315 |
f"LR: {scheduler.get_last_lr()[0]:.2e} | "
|
| 316 |
f"Speed: {steps_per_sec:.1f} steps/s")
|
| 317 |
+
safe_trackio_log(metrics)
|
| 318 |
|
| 319 |
total_loss = 0.0
|
| 320 |
|
|
|
|
| 354 |
)
|
| 355 |
|
| 356 |
if accelerator.is_main_process:
|
| 357 |
+
try:
|
| 358 |
+
trackio.init(
|
| 359 |
+
project="SmolOmni-MLA",
|
| 360 |
+
name="Stage2-Joint",
|
| 361 |
+
config=vars(args),
|
| 362 |
+
)
|
| 363 |
+
except Exception as e:
|
| 364 |
+
print(f"[WARN] Trackio init failed: {e}. Continuing without remote tracking.")
|
| 365 |
|
| 366 |
set_seed(args.seed)
|
| 367 |
|
|
|
|
| 480 |
f"Loss: {total_loss/n:.4f} | "
|
| 481 |
f"AR: {total_ar_loss/n:.4f} | "
|
| 482 |
f"Flow: {total_flow_loss/n:.4f}")
|
| 483 |
+
safe_trackio_log(metrics)
|
| 484 |
total_loss = total_ar_loss = total_flow_loss = 0.0
|
| 485 |
|
| 486 |
if global_step % args.save_every == 0 and accelerator.is_main_process:
|