| import glob |
| import os |
| import pickle |
|
|
| import torch |
|
|
|
|
| def _remove_files(files): |
| for f in files: |
| return os.remove(f) |
|
|
|
|
def assert_dir_exits(path):
    """Ensure directory ``path`` exists, creating it and any parents if needed.

    The (misspelled) name is kept for compatibility with existing callers.

    Args:
        path: directory to create. An empty string denotes the current
            directory and is treated as already existing.
    """
    if not path:
        # os.makedirs('') raises FileNotFoundError; '' means "current dir".
        return
    # exist_ok avoids the race where another process creates the directory
    # between an existence check and the makedirs call.
    os.makedirs(path, exist_ok=True)
|
|
|
|
def save_model(model, epoch, out_path):
    """Save ``model``'s state dict for ``epoch`` and delete other checkpoints.

    The checkpoint is written to ``out_path + str(epoch) + '.pth'``;
    ``out_path`` is used as a plain string prefix, so it normally ends with a
    path separator. The new file is written atomically (temp file, then
    ``os.replace``). Only after that succeeds are the other ``*.pth`` files
    sharing the prefix deleted. A failed save therefore never leaves the run
    without a usable checkpoint.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        epoch: value embedded in the file name via ``str(epoch)``.
        out_path: path prefix; created as a directory if missing.

    Returns:
        The path of the checkpoint that was written.
    """
    assert_dir_exits(out_path)
    model_file = out_path + str(epoch) + '.pth'
    # The '.tmp' suffix does not match '*.pth', so a half-written file is
    # never picked up as a checkpoint.
    tmp_file = model_file + '.tmp'
    try:
        torch.save(model.state_dict(), tmp_file)
        os.replace(tmp_file, model_file)
    finally:
        # Only still present if saving or renaming failed.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)

    # Escape the prefix so '[', '*' or '?' in out_path match literally;
    # compare absolute paths so the file just written is never deleted.
    keep = os.path.abspath(model_file)
    stale_files = [
        f for f in glob.glob(glob.escape(out_path) + '*.pth')
        if os.path.abspath(f) != keep
    ]
    _remove_files(stale_files)
    print('model saved for epoch: {}'.format(epoch))
    return model_file
|
|
|
|
def save_objects(obj, epoch, out_path):
    """Pickle ``obj`` for ``epoch`` and delete other ``*.dat`` snapshots.

    The data is written to ``out_path + str(epoch) + '.dat'``; ``out_path`` is
    used as a plain string prefix. The write is atomic (temp file, then
    ``os.replace``). Only after it succeeds are the other ``*.dat`` files
    sharing the prefix deleted, so a failed dump never destroys the previous
    snapshot.

    Args:
        obj: any picklable object.
        epoch: value embedded in the file name via ``str(epoch)``.
        out_path: path prefix; created as a directory if missing.

    Returns:
        The path of the data file that was written.
    """
    assert_dir_exits(out_path)
    data_file = out_path + str(epoch) + '.dat'
    # The '.tmp' suffix does not match '*.dat', so a partial dump is never restored.
    tmp_file = data_file + '.tmp'
    try:
        with open(tmp_file, 'wb') as output:
            pickle.dump(obj, output)
        os.replace(tmp_file, data_file)
    finally:
        # Only still present if dumping or renaming failed.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)

    # Escape the prefix so glob metacharacters in out_path match literally;
    # compare absolute paths so the file just written is never deleted.
    keep = os.path.abspath(data_file)
    stale_files = [
        f for f in glob.glob(glob.escape(out_path) + '*.dat')
        if os.path.abspath(f) != keep
    ]
    _remove_files(stale_files)
    print('objects saved for epoch: {}'.format(epoch))
    return data_file
|
|
|
|
def restore_model(model, out_path):
    """Load the newest ``*.pth`` checkpoint matching ``out_path`` into ``model``.

    ``out_path`` is used as a plain string prefix, matching the
    ``out_path + str(epoch) + '.pth'`` naming used when saving. If several
    checkpoints match, the most recently modified one is used. If none match,
    ``model`` is returned unchanged.

    Args:
        model: object with ``load_state_dict`` (e.g. a ``torch.nn.Module``).
        out_path: path prefix to search for checkpoints.

    Returns:
        ``model``, modified in place when a checkpoint was found.
    """
    # Escape the prefix so '[', '*' or '?' in out_path match literally.
    chk_files = glob.glob(glob.escape(out_path) + '*.pth')
    if not chk_files:
        print('Model not found, using untrained model')
        return model

    # glob order is arbitrary; if stale checkpoints linger, take the newest.
    chk_file = max(chk_files, key=os.path.getmtime)
    print('found model {}, restoring'.format(chk_file))
    # Load onto the CPU so checkpoints written on a GPU can be restored on a
    # CPU-only host; load_state_dict copies values into the model's own
    # parameters.
    model.load_state_dict(torch.load(chk_file, map_location='cpu'))
    return model
|
|
|
|
def restore_objects(out_path, default):
    """Return the object pickled in the newest ``*.dat`` file matching ``out_path``.

    ``out_path`` is used as a plain string prefix, matching the
    ``out_path + str(epoch) + '.dat'`` naming used when saving. If several
    files match, the most recently modified one is used.

    Args:
        out_path: path prefix to search for data files.
        default: value returned when no matching file exists.

    Returns:
        The unpickled object, or ``default`` if nothing was found.
    """
    # Escape the prefix so glob metacharacters in out_path match literally.
    data_files = glob.glob(glob.escape(out_path) + '*.dat')
    if not data_files:
        return default

    # glob order is arbitrary; if stale snapshots linger, take the newest.
    data_file = max(data_files, key=os.path.getmtime)
    print('found data {}, restoring'.format(data_file))
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only restore data produced by a trusted process.
    with open(data_file, 'rb') as input_:
        return pickle.load(input_)
|
|