Spaces:
Running on Zero
Running on Zero
File size: 678 Bytes
0afe769 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 | import torch
import torch.optim as optim
def unwrap(m):
    """Return the inner module of a wrapper exposing ``.module``, else ``m`` itself.

    Wrappers that keep the real network under a ``module`` attribute
    (presumably DataParallel / DistributedDataParallel here) are peeled one
    level; any other object is passed through untouched.
    """
    if not hasattr(m, "module"):
        # Nothing to peel: hand back the object as-is.
        return m
    return m.module
def continue_training(checkpoint_path, model, ema_model, optimizer: optim.Optimizer) -> "tuple[int, int, int]":
    """Restore model, EMA model and optimizer state from one checkpoint file.

    Args:
        checkpoint_path: Checkpoint location, passed unchanged to ``torch.load``.
        model: Training model. If it wraps another module under a ``.module``
            attribute (presumably a DataParallel/DDP wrapper), the inner
            module receives the weights.
        ema_model: EMA copy of the model; its ``load_state_dict`` is called
            directly, without unwrapping.
        optimizer: Optimizer whose state is restored from the checkpoint.

    Returns:
        ``(step, next_shard, epoch)``. ``next_shard`` is the saved shard index
        plus one, i.e. the shard to resume from. ``step`` and ``epoch`` are
        returned exactly as stored (the epoch is NOT advanced).

    Raises:
        KeyError: if the checkpoint lacks any of ``"step"``, ``"shard"``,
            ``"epoch"``, ``"model"``, ``"ema_model"`` or ``"optimizer"``.
        RuntimeError: from ``load_state_dict`` when the stored keys do not
            exactly match the module's (``strict=True``).
    """
    # NOTE(review): torch.load unpickles the file, so only point this at
    # trusted checkpoints. map_location="cpu" materialises every tensor on
    # CPU; load_state_dict then copies the values into the modules' existing
    # parameters, and the optimizer casts its state to its params' devices.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    step = ckpt["step"]
    shard = ckpt["shard"]
    epoch = ckpt["epoch"]
    # The checkpoint is assumed to hold the unwrapped model's state_dict
    # (no "module." key prefix), hence loading into unwrap(model).
    unwrap(model).load_state_dict(ckpt["model"], strict=True)
    ema_model.load_state_dict(ckpt["ema_model"], strict=True)
    optimizer.load_state_dict(ckpt["optimizer"])
    print(f'resume model and optimizer from {epoch} epoch, {shard} shard')
    # The saved shard has already been trained on; resume with the next one.
    return step, shard + 1, epoch
|