import warnings
from collections import OrderedDict
from omegaconf import OmegaConf
import os
import argparse

import torch
from lightning.pytorch import Trainer, seed_everything

from .prediff_lightning_module import PreDiffSEVIRPLModule
from utils.path import default_pretrained_earthformerunet_dir, pretrained_sevirlr_earthformer_unet_dir
from utils.pl_checkpoint import pl_load
from dotenv import load_dotenv

# Filename used when exporting the best checkpoint's bare PyTorch state dict.
pytorch_state_dict_name = "sevirlr_earthformerunet.pt"

# Load environment variables (e.g. data and checkpoint paths) from a local .env file.
_ = load_dotenv('./.env')


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save', default='tmp_sevirlr_diffusion', type=str,
                        help='Directory for saving experiment outputs and checkpoints.')
    parser.add_argument('--nodes', default=1, type=int,
                        help='Number of nodes in DDP training.')
    parser.add_argument('--gpus', nargs='+', type=int,
                        help='GPU indices to use on each node in DDP training.')
    parser.add_argument('--cfg', default=None, type=str,
                        help='Path to a YAML config file.')
    parser.add_argument('--test', action='store_true',
                        help='Run testing instead of training.')
    parser.add_argument('--ckpt_name', default=None, type=str,
                        help='The model checkpoint trained on SEVIR-LR.')
    parser.add_argument('--pretrained', action='store_true',
                        help='Load pretrained checkpoints for test.')
    return parser
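

# Example invocations (illustrative only; the script name and paths depend on your setup):
#   python train_sevirlr_earthformerunet.py --gpus 0 1 --cfg cfg.yaml --save my_experiment
#   python train_sevirlr_earthformerunet.py --gpus 0 --pretrained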


def main():
    parser = get_parser()
    args = parser.parse_args()
    if args.pretrained:
        args.cfg = os.path.abspath(os.path.join(os.path.dirname(__file__), "cfg.yaml"))
        assert os.path.exists(
            os.path.join(default_pretrained_earthformerunet_dir, pretrained_sevirlr_earthformer_unet_dir)
        ), "Pretrained weights for the Earthformer UNet do not exist."
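
    # Read training hyperparameters from the YAML config if one is given;
    # otherwise fall back to defaults.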
    if args.cfg is not None:
        oc_from_file = OmegaConf.load(args.cfg)
        dataset_cfg = OmegaConf.to_object(oc_from_file.dataset)
        total_batch_size = oc_from_file.optim.total_batch_size
        micro_batch_size = oc_from_file.optim.micro_batch_size
        max_epochs = oc_from_file.optim.max_epochs
        seed = oc_from_file.optim.seed
        float32_matmul_precision = oc_from_file.optim.float32_matmul_precision
    else:
        dataset_cfg = OmegaConf.to_object(PreDiffSEVIRPLModule.get_dataset_config())
        micro_batch_size = 1
        total_batch_size = int(micro_batch_size * args.nodes * len(args.gpus))
        max_epochs = None
        seed = 0
        float32_matmul_precision = "high"
    torch.set_float32_matmul_precision(float32_matmul_precision)
    seed_everything(seed, workers=True)
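
    # Build the SEVIR-LR datamodule exposed by the Lightning module.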
    dm = PreDiffSEVIRPLModule.get_sevir_datamodule(
        dataset_cfg=dataset_cfg,
        micro_batch_size=micro_batch_size,
        num_workers=4,
    )
    dm.prepare_data()
    dm.setup()
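    # Gradient accumulation makes up the difference between the effective (total)
    # batch size and micro_batch_size * #devices processed per optimizer step.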
    accumulate_grad_batches = total_batch_size // (micro_batch_size * args.nodes * len(args.gpus))
    total_num_steps = PreDiffSEVIRPLModule.get_total_num_steps(
        epoch=max_epochs,
        num_samples=dm.num_train_samples,
        total_batch_size=total_batch_size,
    )
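
    # Instantiate the Lightning module and build the Trainer from its kwargs.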
    pl_module = PreDiffSEVIRPLModule(
        total_num_steps=total_num_steps,
        save_dir=args.save,
        oc_file=args.cfg,
    )
    trainer_kwargs = pl_module.set_trainer_kwargs(
        devices=args.gpus,
        num_nodes=args.nodes,
        accumulate_grad_batches=accumulate_grad_batches,
    )
    trainer = Trainer(**trainer_kwargs)
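
    # Run mode: --pretrained tests the released weights; --test evaluates a local
    # checkpoint from save_dir; otherwise train, export the best weights, and test.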
    if args.pretrained:
        earthformerunet_ckpt_path = os.path.join(
            default_pretrained_earthformerunet_dir,
            pretrained_sevirlr_earthformer_unet_dir,
        )
        state_dict = torch.load(
            earthformerunet_ckpt_path,
            map_location=torch.device("cpu"),
        )
        pl_module.torch_nn_module.load_state_dict(state_dict=state_dict)
        trainer.test(model=pl_module, datamodule=dm)
    elif args.test:
        if args.ckpt_name is not None:
            ckpt_path = os.path.join(pl_module.save_dir, "checkpoints", args.ckpt_name)
            pl_ckpt = pl_load(path_or_url=ckpt_path,
                              map_location=torch.device("cpu"))
            # Lightning stores model weights under "state_dict", prefixed with the
            # attribute name; strip the prefix before loading into the bare module.
            pl_state_dict = pl_ckpt["state_dict"]
            model_key = "torch_nn_module."
            model_state_dict = OrderedDict()
            for key, val in pl_state_dict.items():
                if key.startswith(model_key):
                    model_state_dict[key[len(model_key):]] = val
            pl_module.torch_nn_module.load_state_dict(model_state_dict)
        trainer.test(model=pl_module, datamodule=dm)
    else:
        if args.ckpt_name is not None:
            ckpt_path = os.path.join(pl_module.save_dir, "checkpoints", args.ckpt_name)
            if not os.path.exists(ckpt_path):
                warnings.warn(f"Checkpoint {ckpt_path} does not exist! Starting training from epoch 0.")
                ckpt_path = None
        else:
            ckpt_path = None
        trainer.fit(model=pl_module,
                    datamodule=dm,
                    ckpt_path=ckpt_path)
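
        # After training, export the best checkpoint's bare PyTorch weights for standalone reuse.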
        pl_ckpt = pl_load(path_or_url=trainer.checkpoint_callback.best_model_path,
                          map_location=torch.device("cpu"))
        pl_state_dict = pl_ckpt["state_dict"]
        model_key = "torch_nn_module."
        state_dict = OrderedDict()
        unexpected_dict = OrderedDict()
        for key, val in pl_state_dict.items():
            if key.startswith(model_key):
                state_dict[key[len(model_key):]] = val
            else:
                unexpected_dict[key] = val
        torch.save(state_dict, os.path.join(pl_module.save_dir, "checkpoints", pytorch_state_dict_name))
        # Evaluate the best checkpoint tracked by the checkpoint callback.
        trainer.test(ckpt_path="best",
                     datamodule=dm)


if __name__ == "__main__":
    main()