# prediff_code/scripts/train_diffusion/train_sevirlr_prediff.py
import warnings
import os
import argparse
from collections import OrderedDict

import yaml
import torch
from omegaconf import OmegaConf
from lightning.pytorch import Trainer, seed_everything
from dotenv import load_dotenv

from .prediff_lightning_module import PreDiffSEVIRPLModule
from utils.path import default_pretrained_earthformerunet_dir, pretrained_sevirlr_earthformer_unet_dir
from utils.pl_checkpoint import pl_load
from datamodule import WeatherForecastDataModuleOld

# File name for the plain-PyTorch state dict exported after training.
pytorch_state_dict_name = "sevirlr_earthformerunet.pt"

# Load environment variables from a local .env file, if present.
_ = load_dotenv('./.env')


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save', default='tmp_sevirlr_diffusion', type=str,
                        help='Directory for saving checkpoints and logs.')
    parser.add_argument('--nodes', default=1, type=int,
                        help='Number of nodes in DDP training.')
    # Default to a single GPU (index 0) so `len(args.gpus)` is always valid.
    parser.add_argument('--gpus', nargs='+', type=int, default=[0],
                        help='GPU indices to use on each node in DDP training.')
    parser.add_argument('--cfg', default=None, type=str,
                        help='Path to a YAML config file overriding the defaults.')
    parser.add_argument('--test', action='store_true',
                        help='Run evaluation instead of training.')
    parser.add_argument('--ckpt_name', default=None, type=str,
                        help='The model checkpoint trained on SEVIR-LR.')
    parser.add_argument('--pretrained', action='store_true',
                        help='Load pretrained checkpoints for test.')
    return parser
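

# Example invocations (illustrative; assumes the script is run as a module so the
# relative import of `prediff_lightning_module` resolves):
#   python -m scripts.train_diffusion.train_sevirlr_prediff --gpus 0 1 --cfg cfg.yaml
#   python -m scripts.train_diffusion.train_sevirlr_prediff --pretrained --gpus 0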


def main():
    parser = get_parser()
    args = parser.parse_args()
    if args.pretrained:
        args.cfg = os.path.abspath(os.path.join(os.path.dirname(__file__), "cfg.yaml"))
        assert os.path.exists(os.path.join(default_pretrained_earthformerunet_dir,
                                           pretrained_sevirlr_earthformer_unet_dir)), \
            "Pretrained weights for the Earthformer-UNet do not exist."
    if args.cfg is not None:
        # Read training hyperparameters from the YAML config.
        oc_from_file = OmegaConf.load(args.cfg)
        dataset_cfg = OmegaConf.to_object(oc_from_file.dataset)
        total_batch_size = oc_from_file.optim.total_batch_size
        micro_batch_size = oc_from_file.optim.micro_batch_size
        max_epochs = oc_from_file.optim.max_epochs
        seed = oc_from_file.optim.seed
        float32_matmul_precision = oc_from_file.optim.float32_matmul_precision
    else:
        # No config given: fall back to the module's built-in defaults.
        dataset_cfg = OmegaConf.to_object(PreDiffSEVIRPLModule.get_dataset_config())
        micro_batch_size = 1
        total_batch_size = int(micro_batch_size * args.nodes * len(args.gpus))
        max_epochs = None
        seed = 0
        float32_matmul_precision = "high"
    torch.set_float32_matmul_precision(float32_matmul_precision)
    seed_everything(seed, workers=True)
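    # Build the SEVIR-LR datamodule and materialize the dataset splits.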
    dm = PreDiffSEVIRPLModule.get_sevir_datamodule(
        dataset_cfg=dataset_cfg,
        micro_batch_size=micro_batch_size,
        num_workers=4,
    )
    dm.prepare_data()
    dm.setup()
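    # Gradient accumulation: with world_size = nodes * len(gpus) processes, each
    # stepping micro_batch_size samples, the effective batch per optimizer step is
    # micro_batch_size * world_size * accumulate_grad_batches == total_batch_size.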
    accumulate_grad_batches = total_batch_size // (micro_batch_size * args.nodes * len(args.gpus))
    # Convert the epoch budget into a total number of optimizer steps.
    total_num_steps = PreDiffSEVIRPLModule.get_total_num_steps(
        epoch=max_epochs,
        num_samples=dm.num_train_samples,
        total_batch_size=total_batch_size,
    )
    pl_module = PreDiffSEVIRPLModule(
        total_num_steps=total_num_steps,
        save_dir=args.save,
        oc_file=args.cfg,
    )
    trainer_kwargs = pl_module.set_trainer_kwargs(
        devices=args.gpus,
        num_nodes=args.nodes,
        accumulate_grad_batches=accumulate_grad_batches,
    )
    trainer = Trainer(**trainer_kwargs)
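    # Three run modes follow: evaluate the pretrained Earthformer-UNet weights,
    # evaluate a locally trained checkpoint, or train (optionally resuming from
    # a checkpoint under `<save_dir>/checkpoints/`).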
    if args.pretrained:
        # Load the pretrained Earthformer-UNet weights (a plain PyTorch state
        # dict) directly into the inner nn.Module, then evaluate.
        earthformerunet_ckpt_path = os.path.join(
            default_pretrained_earthformerunet_dir,
            pretrained_sevirlr_earthformer_unet_dir,
        )
        state_dict = torch.load(earthformerunet_ckpt_path,
                                map_location=torch.device("cpu"))
        pl_module.torch_nn_module.load_state_dict(state_dict=state_dict)
        trainer.test(model=pl_module,
                     datamodule=dm)
    elif args.test:
        if args.ckpt_name is not None:
            ckpt_path = os.path.join(pl_module.save_dir, "checkpoints", args.ckpt_name)
            pl_ckpt = pl_load(path_or_url=ckpt_path,
                              map_location=torch.device("cpu"))
            # pl_state_dict = pl_ckpt["state_dict"]  # pl 1.x
            pl_state_dict = pl_ckpt
            # Strip the Lightning wrapper prefix so keys match the inner nn.Module.
            model_key = "torch_nn_module."
            model_state_dict = OrderedDict()
            for key, val in pl_state_dict.items():
                if key.startswith(model_key):
                    model_state_dict[key.replace(model_key, "")] = val
            pl_module.torch_nn_module.load_state_dict(model_state_dict)
        trainer.test(model=pl_module,
                     datamodule=dm)
    else:
        if args.ckpt_name is not None:
            ckpt_path = os.path.join(pl_module.save_dir, "checkpoints", args.ckpt_name)
            if not os.path.exists(ckpt_path):
                warnings.warn(f"ckpt {ckpt_path} does not exist! Starting training from epoch 0.")
                ckpt_path = None
        else:
            ckpt_path = None
        trainer.fit(model=pl_module,
                    datamodule=dm,
                    ckpt_path=ckpt_path)
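        # `best_model_path` is populated by the `ModelCheckpoint` callback that is
        # presumably configured inside `set_trainer_kwargs` above.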
        # Save a plain-PyTorch state dict of the latent diffusion model,
        # extracted from the best Lightning checkpoint.
        pl_ckpt = pl_load(path_or_url=trainer.checkpoint_callback.best_model_path,
                          map_location=torch.device("cpu"))
        # pl_state_dict = pl_ckpt["state_dict"]  # pl 1.x
        pl_state_dict = pl_ckpt
        model_key = "torch_nn_module."
        state_dict = OrderedDict()
        unexpected_dict = OrderedDict()
        for key, val in pl_state_dict.items():
            if key.startswith(model_key):
                state_dict[key.replace(model_key, "")] = val
            else:
                # Keys outside the model namespace are collected but not saved.
                unexpected_dict[key] = val
        torch.save(state_dict, os.path.join(pl_module.save_dir, "checkpoints", pytorch_state_dict_name))
        # Test with the best checkpoint.
        trainer.test(ckpt_path="best",
                     datamodule=dm)


if __name__ == "__main__":
    main()