| import os |
|
|
| import torch |
|
|
|
|
def save_weights(model, filename, path="./saved_models"):
    """Write ``model.state_dict()`` to ``<path>/<filename>`` via ``torch.save``.

    Only the state dict is saved, not the module object. The target
    directory (and any missing parents) is created first; an existing
    directory is not an error.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        filename: File name to write inside ``path``.
        path: Destination directory.

    Returns:
        None.
    """
    os.makedirs(path, exist_ok=True)
    destination = os.path.join(path, filename)
    weights = model.state_dict()
    torch.save(weights, destination)
|
|
def save_checkpoint(model, optimizer, epoch, filename, root="./checkpoints"):
    """Save a training checkpoint dict to ``<root>/<filename>``.

    The saved object is ``{"model": model.state_dict(),
    "optimizer": optimizer.state_dict(), "epoch": epoch}``.

    Args:
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        optimizer: Object exposing ``state_dict()`` (e.g. a torch optimizer).
        epoch: Value stored under the ``"epoch"`` key.
        filename: File name to write inside ``root``.
        root: Destination directory; created (with parents) if missing.

    Returns:
        None.
    """
    # exist_ok=True avoids the isdir()/makedirs() race, where another process
    # creating the directory in between made makedirs raise FileExistsError.
    os.makedirs(root, exist_ok=True)

    fpath = os.path.join(root, filename)
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        },
        fpath,
    )
|
|
def load_weights(model, filename, path="./saved_models", map_location=None):
    """Load a state dict from ``<path>/<filename>`` into ``model``.

    Args:
        model: Object exposing ``load_state_dict()`` (e.g. a ``torch.nn.Module``).
        filename: File name inside ``path``.
        path: Directory containing the file.
        map_location: Forwarded to ``torch.load``. The default ``None``
            keeps the original behavior; pass e.g. ``'cpu'`` to load
            tensors saved on a GPU on a machine without one.

    Returns:
        The same ``model`` object, after ``load_state_dict`` has been applied.
    """
    fpath = os.path.join(path, filename)
    # NOTE(review): torch.load unpickles the file -- only load trusted files.
    state_dict = torch.load(fpath, map_location=map_location)
    model.load_state_dict(state_dict)
    return model
|
|
def _strip_prefix(key, prefix):
    """Return ``key`` with a single leading ``prefix`` removed, else unchanged."""
    return key[len(prefix):] if key.startswith(prefix) else key


# (old prefix, new prefix) renames applied to checkpoint keys before loading,
# presumably mapping legacy AdaBins parameter names to the ones the model
# defines. The first matching prefix wins.
_CKPT_KEY_RENAMES = (
    ('adaptive_bins_layer.embedding_conv.', 'adaptive_bins_layer.conv3x3.'),
    ('adaptive_bins_layer.patch_transformer.embedding_encoder',
     'adaptive_bins_layer.patch_transformer.embedding_convPxP'),
)


def _rename_ckpt_key(key):
    """Replace the first matching old prefix in ``key`` with its new name."""
    for old, new in _CKPT_KEY_RENAMES:
        if key.startswith(old):
            return new + key[len(old):]
    return key


def load_checkpoint(fpath, model, optimizer=None):
    """Load a checkpoint file into ``model`` (and optionally ``optimizer``).

    The file may be either a dict with ``"model"`` / ``"optimizer"`` /
    ``"epoch"`` entries (as written by ``save_checkpoint``) or a bare model
    state dict. Tensors are loaded onto the CPU (``map_location='cpu'``).
    Before loading, each key has one leading ``'module.'`` stripped
    (presumably from a DataParallel-wrapped model) and is then passed
    through the prefix renames in ``_CKPT_KEY_RENAMES``.

    Args:
        fpath: Path of the checkpoint file.
        model: Object exposing ``load_state_dict()``; loaded strictly.
        optimizer: If given, ``ckpt['optimizer']`` is loaded into it.

    Returns:
        ``(model, optimizer, epoch)``. When ``optimizer`` was ``None``, the
        returned optimizer is the raw saved optimizer state
        (``ckpt.get('optimizer')``, possibly ``None``). ``epoch`` is
        ``ckpt['epoch']``, or ``None`` when the checkpoint has no epoch
        entry (e.g. a bare state dict).

    Raises:
        Exception: If ``torch.load`` returns ``None``.
        KeyError: If ``optimizer`` is given but the checkpoint has no
            ``'optimizer'`` entry.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted files.
    ckpt = torch.load(fpath, map_location='cpu')
    if ckpt is None:
        raise Exception(f"\nERROR Loading AdaBins_nyu.pt. Read this for a fix:\nhttps://github.com/deforum-art/deforum-for-automatic1111-webui/wiki/FAQ-&-Troubleshooting#3d-animation-mode-is-not-working-only-2d-works")
    if optimizer is None:
        optimizer = ckpt.get('optimizer', None)
    else:
        optimizer.load_state_dict(ckpt['optimizer'])
    # .get: a bare state dict (supported below) carries no epoch entry.
    epoch = ckpt.get('epoch')

    state_dict = ckpt['model'] if 'model' in ckpt else ckpt
    # Strip only the *leading* 'module.'; str.replace would also remove
    # inner occurrences (e.g. a submodule literally named 'module').
    modified = {
        _rename_ckpt_key(_strip_prefix(k, 'module.')): v
        for k, v in state_dict.items()
    }

    model.load_state_dict(modified)
    return model, optimizer, epoch