import os
from typing import Any

import numpy as np
import torch
from pytorch_lightning import LightningModule
from torchvision import transforms

from arch import Network
from lr_scheduler import LinearWarmupCosineAnnealingLR
from metric import ACC, MAE, MSE, RMSE, WMSE, WRMSE


class WeatherForecastModule(LightningModule):
    """Lightning module for global forecasting with the ClimaX model.

    Args:
        net: Deep learning model.
        pretrained_path (str, optional): Path to a pre-trained checkpoint.
        lr (float, optional): Learning rate.
        beta_1 (float, optional): Beta 1 for AdamW.
        beta_2 (float, optional): Beta 2 for AdamW.
        weight_decay (float, optional): Weight decay for AdamW.
        warmup_epochs (int, optional): Number of warmup epochs.
        max_epochs (int, optional): Number of total epochs.
        warmup_start_lr (float, optional): Starting learning rate for warmup.
        eta_min (float, optional): Minimum learning rate.
    """

    def __init__(
        self,
        net: Network,
        pretrained_path: str = "",
        lr: float = 5e-4,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        weight_decay: float = 1e-5,
        warmup_epochs: int = 10000,
        max_epochs: int = 200000,
        warmup_start_lr: float = 1e-8,
        eta_min: float = 1e-8,
    ):
        super().__init__()
        self.save_hyperparameters(logger=True, ignore=["net"])
        self.net = net
        if len(pretrained_path) > 0:
            self.load_pretrained_weights(pretrained_path)

    def load_pretrained_weights(self, pretrained_path):
        self.net.load_state_dict(torch.load(pretrained_path))

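    # Note: this loader assumes the file holds a raw state dict saved via
    # torch.save(net.state_dict(), path). If the checkpoint was written by
    # Lightning, the weights live under checkpoint["state_dict"] (with a
    # "net." prefix on keys) and would need to be unwrapped first.
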
    def set_path(self, path):
        self.path = path

    def set_size(self, rad_size, sat_size):
        self.rad_size = rad_size
        self.sat_size = sat_size

    def set_lat(self):
        # Center-crop the satellite latitude grid to the configured patch size.
        lat = np.load(os.path.join(self.path, "sat_lat.npy"))
        center = lat.shape[-1] // 2
        self.sat_lat = lat[center - self.sat_size // 2 : center + self.sat_size // 2]

    def set_clim(self):
        # Climatologies serve as the reference state for the ACC metric.
        rad_clim = np.load(os.path.join(self.path, "rad_clim.npz"))["precipitation"]
        sat_clim = np.load(os.path.join(self.path, "sat_clim.npz"))["total_precipitation"]
        self.rad_clim = torch.from_numpy(rad_clim)
        self.sat_clim = torch.from_numpy(sat_clim)

    def set_normalize(self):
        self.rad_mean = np.load(os.path.join(self.path, "rad_mean.npz"))["precipitation"]
        self.rad_std = np.load(os.path.join(self.path, "rad_std.npz"))["precipitation"]
        self.sat_mean = np.load(os.path.join(self.path, "sat_mean.npz"))["total_precipitation"]
        self.sat_std = np.load(os.path.join(self.path, "sat_std.npz"))["total_precipitation"]

    def set_denormalize(self):
        # Build the inverse of the training-time normalization.
        self.rad_denormalization = transforms.Normalize(-self.rad_mean / self.rad_std, 1 / self.rad_std)
        self.sat_denormalization = transforms.Normalize(-self.sat_mean / self.sat_std, 1 / self.sat_std)

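    # Why Normalize(-mean/std, 1/std) inverts the normalization (a worked
    # sketch, assuming per-channel scalar mean/std): the forward transform
    # computes x_norm = (x - mean) / std, and Normalize(-mean/std, 1/std)
    # applied to x_norm gives
    #     (x_norm - (-mean/std)) / (1/std) = x_norm * std + mean = x.
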
    def training_step(self, batch: Any, batch_idx: int):
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        loss = torch.nn.MSELoss()
        loss_rad = loss(pred_rad, out_rad)
        loss_sat = loss(pred_sat, out_sat)
        loss_tot = loss_rad + loss_sat
        self.log("train/rad", loss_rad, prog_bar=True, logger=True)
        self.log("train/sat", loss_sat, prog_bar=True, logger=True)
        self.log("train/mse", loss_tot, prog_bar=True, logger=True)
        return loss_tot

    def validation_step(self, batch: Any, batch_idx: int):
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        loss = torch.nn.MSELoss()
        with torch.no_grad():
            loss_rad = loss(pred_rad, out_rad)
            loss_sat = loss(pred_sat, out_sat)
            loss_tot = loss_rad + loss_sat
            self.log("val/rad", loss_rad, prog_bar=True, logger=True)
            self.log("val/sat", loss_sat, prog_bar=True, logger=True)
            self.log("val/mse", loss_tot, prog_bar=True, logger=True)
        return loss_tot

    def test_step(self, batch: Any, batch_idx: int):
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        loss = torch.nn.MSELoss()
        rad_metrics = [MSE, RMSE, ACC, MAE]
        sat_metrics = [MSE, WMSE, RMSE, WRMSE, ACC, MAE]

        with torch.no_grad():
            # Denormalize once; all test metrics are computed in physical units.
            pred_rad, out_rad = self.rad_denormalization(pred_rad), self.rad_denormalization(out_rad)
            pred_sat, out_sat = self.sat_denormalization(pred_sat), self.sat_denormalization(out_sat)
            loss_rad = loss(pred_rad, out_rad)
            loss_sat = loss(pred_sat, out_sat)
            loss_tot = loss_rad + loss_sat
            self.log("test/rad", loss_rad, prog_bar=True, logger=True)
            self.log("test/sat", loss_sat, prog_bar=True, logger=True)
            self.log("test/mse", loss_tot, prog_bar=True, logger=True)
            for met in rad_metrics:
                # Radar metrics use uniform latitude weights.
                metric_rad = met(pred_rad, out_rad, np.ones(self.rad_size), self.rad_clim)
                self.log(f"test/rad_{met.__name__}", metric_rad, prog_bar=True, logger=True)
            for met in sat_metrics:
                metric_sat = met(pred_sat, out_sat, self.sat_lat, self.sat_clim)
                self.log(f"test/sat_{met.__name__}", metric_sat, prog_bar=True, logger=True)
        return loss_tot

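    # The weighted metrics (WMSE, WRMSE) come from metric.py, which is not
    # shown here. A minimal sketch of a latitude-weighted MSE, assuming the
    # common convention of cos(latitude) weights normalized to mean 1 (the
    # actual definitions in metric.py may differ):
    #
    #     w = np.cos(np.deg2rad(lat))
    #     w = w / w.mean()
    #     wmse = (w[None, :, None] * (pred - target) ** 2).mean()
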
    def configure_optimizers(self):
        # Exclude embedding parameters from weight decay, as is common for
        # transformer-style position/variable embeddings.
        decay = []
        no_decay = []
        for name, m in self.named_parameters():
            if "var_embed" in name or "pos_embed" in name or "time_pos_embed" in name:
                no_decay.append(m)
            else:
                decay.append(m)

        optimizer = torch.optim.AdamW(
            [
                {
                    "params": decay,
                    "lr": self.hparams.lr,
                    "betas": (self.hparams.beta_1, self.hparams.beta_2),
                    "weight_decay": self.hparams.weight_decay,
                },
                {
                    "params": no_decay,
                    "lr": self.hparams.lr,
                    "betas": (self.hparams.beta_1, self.hparams.beta_2),
                    "weight_decay": 0,
                },
            ]
        )

        lr_scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            self.hparams.warmup_epochs,
            self.hparams.max_epochs,
            self.hparams.warmup_start_lr,
            self.hparams.eta_min,
        )
        # The scheduler is stepped once per optimizer step, so warmup_epochs
        # and max_epochs are effectively counted in steps here.
        scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}

        return {"optimizer": optimizer, "lr_scheduler": scheduler}
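
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the paths, sizes, and Network
# constructor arguments below are assumptions, not part of this module):
#
#     from pytorch_lightning import Trainer
#
#     net = Network(...)  # hypothetical: fill in your architecture config
#     module = WeatherForecastModule(net, lr=5e-4)
#     module.set_path("/data/preprocessed")  # must contain sat_lat.npy,
#                                            # *_clim.npz, *_mean.npz, *_std.npz
#     module.set_size(rad_size=128, sat_size=64)  # hypothetical crop sizes
#     module.set_lat()
#     module.set_clim()
#     module.set_normalize()
#     module.set_denormalize()
#
#     trainer = Trainer(max_steps=200000)
#     trainer.fit(module, datamodule=...)  # batches of
#                                          # (inp_rad, inp_sat, out_rad, out_sat)
# ---------------------------------------------------------------------------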