# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# credits: https://github.com/ashleve/lightning-hydra-template/blob/main/src/models/mnist_module.py

from typing import Any
import os

import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torchvision.transforms import transforms

from lr_scheduler import LinearWarmupCosineAnnealingLR
from arch import Network
from metric import MSE, RMSE, MAE, ACC, WMSE, WRMSE


class WeatherForecastModule(LightningModule):
    """Lightning module for global forecasting with the ClimaX model.

    Trains a dual-stream network that predicts radar precipitation and
    satellite total precipitation jointly; the training objective is the
    sum of the two MSE losses.

    Args:
        net: Deep-learning model (radar + satellite two-headed network).
        pretrained_path (str, optional): Path to pre-trained checkpoint.
            Empty string means "train from scratch".
        lr (float, optional): Learning rate.
        beta_1 (float, optional): Beta 1 for AdamW.
        beta_2 (float, optional): Beta 2 for AdamW.
        weight_decay (float, optional): Weight decay for AdamW.
        warmup_epochs (int, optional): Number of warmup steps for the
            LR schedule (scheduler interval is "step").
        max_epochs (int, optional): Total number of scheduled steps.
        warmup_start_lr (float, optional): Starting learning rate for warmup.
        eta_min (float, optional): Minimum learning rate.
    """

    def __init__(
        self,
        net: Network,
        pretrained_path: str = "",
        lr: float = 5e-4,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        weight_decay: float = 1e-5,
        warmup_epochs: int = 10000,
        max_epochs: int = 200000,
        warmup_start_lr: float = 1e-8,
        eta_min: float = 1e-8,
    ):
        super().__init__()
        # "net" is excluded from the hparams dump (it is a full module, not a scalar).
        self.save_hyperparameters(logger=True, ignore=["net"])
        self.net = net
        if len(pretrained_path) > 0:
            self.load_pretrained_weights(pretrained_path)

    def load_pretrained_weights(self, pretrained_path: str) -> None:
        """Load a state dict from ``pretrained_path`` into the network.

        NOTE(review): no ``map_location`` is given — the checkpoint is
        restored onto whatever device it was saved from.
        """
        self.net.load_state_dict(torch.load(pretrained_path))

    def set_path(self, path: str) -> None:
        """Remember the directory holding the statistics/climatology .np* files."""
        self.path = path

    def set_size(self, rad_size: int, sat_size: int) -> None:
        """Store the spatial crop sizes of the radar and satellite grids."""
        self.rad_size = rad_size
        self.sat_size = sat_size

    def set_lat(self) -> None:
        """Load satellite latitudes and center-crop them to ``sat_size``."""
        lat = np.load(os.path.join(self.path, 'sat_lat.npy'))
        mid = lat.shape[-1] // 2
        half = self.sat_size // 2
        self.sat_lat = lat[mid - half:mid + half]

    def set_clim(self) -> None:
        """Load radar/satellite climatologies (used by the ACC metric)."""
        rad_clim = np.load(os.path.join(self.path, 'rad_clim.npz'))['precipitation']
        sat_clim = np.load(os.path.join(self.path, 'sat_clim.npz'))['total_precipitation']
        self.rad_clim = torch.from_numpy(rad_clim)
        self.sat_clim = torch.from_numpy(sat_clim)

    def set_normalize(self) -> None:
        """Load per-channel normalization statistics for both data streams."""
        self.rad_mean = np.load(os.path.join(self.path, 'rad_mean.npz'))['precipitation']
        self.rad_std = np.load(os.path.join(self.path, 'rad_std.npz'))['precipitation']
        self.sat_mean = np.load(os.path.join(self.path, 'sat_mean.npz'))['total_precipitation']
        self.sat_std = np.load(os.path.join(self.path, 'sat_std.npz'))['total_precipitation']

    def set_denormalize(self) -> None:
        """Build inverse-normalization transforms: x_orig = x * std + mean.

        Implemented as Normalize(-mean/std, 1/std), the standard inverse of
        Normalize(mean, std). Requires :meth:`set_normalize` to be called first.
        """
        self.rad_denormalization = transforms.Normalize(-self.rad_mean / self.rad_std, 1 / self.rad_std)
        self.sat_denormalization = transforms.Normalize(-self.sat_mean / self.sat_std, 1 / self.sat_std)

    def training_step(self, batch: Any, batch_idx: int):
        """One optimization step: summed MSE of radar and satellite heads."""
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        loss_rad = F.mse_loss(pred_rad, out_rad)
        loss_sat = F.mse_loss(pred_sat, out_sat)
        loss_tot = loss_rad + loss_sat
        self.log("train/rad", loss_rad, prog_bar=True, logger=True)
        self.log("train/sat", loss_sat, prog_bar=True, logger=True)
        self.log("train/mse", loss_tot, prog_bar=True, logger=True)
        return loss_tot

    def validation_step(self, batch: Any, batch_idx: int):
        """Validation step: same summed-MSE objective, logged under val/*.

        Lightning already runs validation with gradients disabled, so no
        explicit ``torch.no_grad()`` is needed here.
        """
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        loss_rad = F.mse_loss(pred_rad, out_rad)
        loss_sat = F.mse_loss(pred_sat, out_sat)
        loss_tot = loss_rad + loss_sat
        self.log("val/rad", loss_rad, prog_bar=True, logger=True)
        self.log("val/sat", loss_sat, prog_bar=True, logger=True)
        self.log("val/mse", loss_tot, prog_bar=True, logger=True)
        return loss_tot

    def test_step(self, batch: Any, batch_idx: int):
        """Test step: log denormalized MSE plus a suite of per-stream metrics.

        Predictions and targets are mapped back to physical units before
        metric computation. Radar metrics use uniform (all-ones) latitude
        weights; satellite metrics use the true latitudes from set_lat().
        Requires set_path/set_size/set_lat/set_clim/set_normalize/
        set_denormalize to have been called beforehand.
        """
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        rad_metric = [MSE, RMSE, ACC, MAE]
        sat_metric = [MSE, WMSE, RMSE, WRMSE, ACC, MAE]
        loss_rad = F.mse_loss(
            self.rad_denormalization(pred_rad), self.rad_denormalization(out_rad)
        )
        loss_sat = F.mse_loss(
            self.sat_denormalization(pred_sat), self.sat_denormalization(out_sat)
        )
        loss_tot = loss_rad + loss_sat
        self.log("test/rad", loss_rad, prog_bar=True, logger=True)
        self.log("test/sat", loss_sat, prog_bar=True, logger=True)
        self.log("test/mse", loss_tot, prog_bar=True, logger=True)
        for met in rad_metric:
            rad_score = met(
                self.rad_denormalization(pred_rad),
                self.rad_denormalization(out_rad),
                np.ones(self.rad_size),  # uniform weights: radar grid is not lat-weighted
                self.rad_clim,
            )
            self.log(f"test/rad_{met.__name__}", rad_score, prog_bar=True, logger=True)
        for met in sat_metric:
            sat_score = met(
                self.sat_denormalization(pred_sat),
                self.sat_denormalization(out_sat),
                self.sat_lat,
                self.sat_clim,
            )
            self.log(f"test/sat_{met.__name__}", sat_score, prog_bar=True, logger=True)
        return loss_tot

    def configure_optimizers(self):
        """AdamW with two param groups + linear-warmup cosine-annealing LR.

        Embedding parameters (var/pos/time_pos) are excluded from weight
        decay, the standard practice for learned positional embeddings.
        The scheduler steps per optimizer step, not per epoch.
        """
        decay = []
        no_decay = []
        for name, param in self.named_parameters():
            if "var_embed" in name or "pos_embed" in name or "time_pos_embed" in name:
                no_decay.append(param)
            else:
                decay.append(param)

        optimizer = torch.optim.AdamW(
            [
                {
                    "params": decay,
                    "lr": self.hparams.lr,
                    "betas": (self.hparams.beta_1, self.hparams.beta_2),
                    "weight_decay": self.hparams.weight_decay,
                },
                {
                    "params": no_decay,
                    "lr": self.hparams.lr,
                    "betas": (self.hparams.beta_1, self.hparams.beta_2),
                    "weight_decay": 0,
                },
            ]
        )

        lr_scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            self.hparams.warmup_epochs,
            self.hparams.max_epochs,
            self.hparams.warmup_start_lr,
            self.hparams.eta_min,
        )
        scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}
        return {"optimizer": optimizer, "lr_scheduler": scheduler}