Fix: cast the model to bf16 after the checkpoint load in Stage 2
Browse files
train.py
CHANGED
|
@@ -385,7 +385,8 @@ def train_stage2(args, config):
|
|
| 385 |
if args.checkpoint:
|
| 386 |
ckpt_path = os.path.join(args.checkpoint, "model.pt")
|
| 387 |
if os.path.exists(ckpt_path):
|
| 388 |
-
|
|
|
|
| 389 |
print(f"Loaded Stage 1 checkpoint from {ckpt_path}")
|
| 390 |
else:
|
| 391 |
print("No Stage 1 checkpoint found, training from scratch")
|
|
@@ -393,6 +394,10 @@ def train_stage2(args, config):
|
|
| 393 |
else:
|
| 394 |
model = initialize_mla_from_pretrained(model, config.base_model, config)
|
| 395 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
# Dataset
|
| 397 |
dataset = ImageTextDataset(
|
| 398 |
tokenizer, vae,
|
|
|
|
| 385 |
if args.checkpoint:
|
| 386 |
ckpt_path = os.path.join(args.checkpoint, "model.pt")
|
| 387 |
if os.path.exists(ckpt_path):
|
| 388 |
+
state = torch.load(ckpt_path, map_location="cpu")
|
| 389 |
+
model.load_state_dict(state, strict=False)
|
| 390 |
print(f"Loaded Stage 1 checkpoint from {ckpt_path}")
|
| 391 |
else:
|
| 392 |
print("No Stage 1 checkpoint found, training from scratch")
|
|
|
|
| 394 |
else:
|
| 395 |
model = initialize_mla_from_pretrained(model, config.base_model, config)
|
| 396 |
|
| 397 |
+
# Cast to bf16 AFTER loading checkpoint (ckpt weights may be fp32)
|
| 398 |
+
model = model.to(torch.bfloat16)
|
| 399 |
+
print("Model cast to bfloat16")
|
| 400 |
+
|
| 401 |
# Dataset
|
| 402 |
dataset = ImageTextDataset(
|
| 403 |
tokenizer, vae,
|