| import torch |
|
|
def load_model_checkpoint(model, ckpt_path):
    """Load weights from a checkpoint file into ``model.net``.

    Checkpoint handling is chosen by file extension:

    * ``.pth``: the file is loaded onto CPU and its ``'state_dict'`` entry is
      read. A leading ``'net.'`` prefix is stripped from each key, and the
      result is loaded into ``model.net``. The returned path is ``None``
      because the weights have already been consumed here.
    * ``.ckpt``: nothing is loaded. The path is returned unchanged so the
      caller can deal with it (presumably a training framework resuming
      from it; confirm against callers).

    :param model: The model to load the state dict into. It must expose a
        ``net`` attribute with a ``load_state_dict`` method.
    :param ckpt_path: The path to the checkpoint file, or ``None`` to skip
        loading. Any object whose ``str()`` ends in ``.pth`` or ``.ckpt`` is
        accepted, e.g. a ``str`` or a ``pathlib.Path``.
    :return: A ``(model, remaining_ckpt_path)`` tuple.
        ``remaining_ckpt_path`` is ``None`` when no further checkpoint
        handling is needed; otherwise it is the ``.ckpt`` path as passed in.
    :raises ValueError: If ``ckpt_path`` has neither a ``.pth`` nor a
        ``.ckpt`` extension.
    """
    if ckpt_path is None:
        return model, None

    # str() lets path-like objects (e.g. pathlib.Path) pass the suffix checks;
    # the original object is still what gets loaded / returned.
    path_str = str(ckpt_path)

    if path_str.endswith(".pth"):
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
        prefix = 'net.'
        # Strip only a *leading* "net.". str.replace would also mangle nested
        # module names, e.g. "net.subnet.weight" -> "subweight".
        net_params = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint['state_dict'].items()
        }
        model.net.load_state_dict(net_params)
        # Weights are already in the model; nothing left for the caller to load.
        return model, None

    if path_str.endswith(".ckpt"):
        # Deliberately not loaded here; handed back to the caller as-is.
        return model, ckpt_path

    raise ValueError(f"ckpt_path {ckpt_path} is not a valid checkpoint file.")