| ''' |
| Clean uncessary information in the weight (*.pth) |
| ''' |
| import torch |
|
|
|
|
| if __name__ == "__main__": |
| weight_path = "saved_models/esrgan_best_generator.pth" |
| store_path = "1x_APISR_RRDB_GAN_generator.pth" |
|
|
| |
| checkpoint_g = torch.load(weight_path) |
| keys = [] |
| for key in checkpoint_g: |
| keys.append(key) |
| print(key) |
| for key in keys: |
| if key != "model_state_dict": |
| del checkpoint_g[key] |
| |
|
|
| |
| old_keys = [key for key in checkpoint_g['model_state_dict']] |
| for old_key in old_keys: |
| if old_key[:10] == "_orig_mod.": |
| new_key = old_key[10:] |
| checkpoint_g['model_state_dict'][new_key] = checkpoint_g['model_state_dict'][old_key] |
| del checkpoint_g['model_state_dict'][old_key] |
|
|
| torch.save(checkpoint_g, store_path) |
|
|
|
|
|
|
|
|
|
|