TinmanLabSL commited on
Commit
1fe8697
·
verified ·
1 Parent(s): 445411d

Fix: bf16 casting after checkpoint load for Stage 2

Browse files
Files changed (1) hide show
  1. train.py +6 -1
train.py CHANGED
@@ -385,7 +385,8 @@ def train_stage2(args, config):
385
  if args.checkpoint:
386
  ckpt_path = os.path.join(args.checkpoint, "model.pt")
387
  if os.path.exists(ckpt_path):
388
- model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
 
389
  print(f"Loaded Stage 1 checkpoint from {ckpt_path}")
390
  else:
391
  print("No Stage 1 checkpoint found, training from scratch")
@@ -393,6 +394,10 @@ def train_stage2(args, config):
393
  else:
394
  model = initialize_mla_from_pretrained(model, config.base_model, config)
395
 
 
 
 
 
396
  # Dataset
397
  dataset = ImageTextDataset(
398
  tokenizer, vae,
 
385
  if args.checkpoint:
386
  ckpt_path = os.path.join(args.checkpoint, "model.pt")
387
  if os.path.exists(ckpt_path):
388
+ state = torch.load(ckpt_path, map_location="cpu")
389
+ model.load_state_dict(state, strict=False)
390
  print(f"Loaded Stage 1 checkpoint from {ckpt_path}")
391
  else:
392
  print("No Stage 1 checkpoint found, training from scratch")
 
394
  else:
395
  model = initialize_mla_from_pretrained(model, config.base_model, config)
396
 
397
+ # Cast to bf16 AFTER loading checkpoint (ckpt weights may be fp32)
398
+ model = model.to(torch.bfloat16)
399
+ print("Model cast to bfloat16")
400
+
401
  # Dataset
402
  dataset = ImageTextDataset(
403
  tokenizer, vae,