# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# credits: https://github.com/ashleve/lightning-hydra-template/blob/main/src/models/mnist_module.py
import os
from typing import Any

import numpy as np
import torch
from pytorch_lightning import LightningModule
from torch.nn import functional as F
from torchvision import transforms

from arch import Network
from lr_scheduler import LinearWarmupCosineAnnealingLR
from metric import ACC, MAE, MSE, RMSE, WMSE, WRMSE


class WeatherForecastModule(LightningModule):
    """Lightning module for global forecasting with the ClimaX model.

    Args:
        net: Deep learning model.
        pretrained_path (str, optional): Path to a pre-trained checkpoint.
        lr (float, optional): Learning rate.
        beta_1 (float, optional): Beta 1 for AdamW.
        beta_2 (float, optional): Beta 2 for AdamW.
        weight_decay (float, optional): Weight decay for AdamW.
        warmup_epochs (int, optional): Number of warmup epochs.
        max_epochs (int, optional): Number of total epochs.
        warmup_start_lr (float, optional): Starting learning rate for warmup.
        eta_min (float, optional): Minimum learning rate.

    Note:
        The LR scheduler is stepped once per training step (see
        ``configure_optimizers``), so ``warmup_epochs`` and ``max_epochs``
        are counted in optimizer steps.
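
    Example (a minimal sketch; the ``Network`` constructor arguments, data
    directory, and grid sizes below are placeholders for your own setup):

        >>> net = Network()  # hypothetical: pass your actual architecture args
        >>> module = WeatherForecastModule(net, lr=5e-4)
        >>> module.set_path("/path/to/data")
        >>> module.set_size(rad_size=128, sat_size=128)
        >>> module.set_lat(); module.set_clim()
        >>> module.set_normalize(); module.set_denormalize()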
"""
def __init__(
self,
net: Network,
pretrained_path: str = "",
lr: float = 5e-4,
beta_1: float = 0.9,
beta_2: float = 0.99,
weight_decay: float = 1e-5,
warmup_epochs: int = 10000,
max_epochs: int = 200000,
warmup_start_lr: float = 1e-8,
eta_min: float = 1e-8,
):
super().__init__()
self.save_hyperparameters(logger=True, ignore=["net"])
self.net = net
if len(pretrained_path) > 0:
self.load_pretrained_weights(pretrained_path)

    def load_pretrained_weights(self, pretrained_path):
        # Load on CPU first; Lightning moves parameters to the training device.
        self.net.load_state_dict(torch.load(pretrained_path, map_location="cpu"))

    def set_path(self, path):
        self.path = path

    def set_size(self, rad_size, sat_size):
        self.rad_size = rad_size
        self.sat_size = sat_size

    def set_lat(self):
        # Keep only the central `sat_size` latitudes of the satellite grid.
        lat = np.load(os.path.join(self.path, "sat_lat.npy"))
        center = lat.shape[-1] // 2
        self.sat_lat = lat[center - self.sat_size // 2 : center + self.sat_size // 2]

    def set_clim(self):
        rad_clim = np.load(os.path.join(self.path, "rad_clim.npz"))["precipitation"]
        sat_clim = np.load(os.path.join(self.path, "sat_clim.npz"))["total_precipitation"]
        self.rad_clim = torch.from_numpy(rad_clim)
        self.sat_clim = torch.from_numpy(sat_clim)

    def set_normalize(self):
        self.rad_mean = np.load(os.path.join(self.path, "rad_mean.npz"))["precipitation"]
        self.rad_std = np.load(os.path.join(self.path, "rad_std.npz"))["precipitation"]
        self.sat_mean = np.load(os.path.join(self.path, "sat_mean.npz"))["total_precipitation"]
        self.sat_std = np.load(os.path.join(self.path, "sat_std.npz"))["total_precipitation"]

    def set_denormalize(self):
        # Normalize(-mean/std, 1/std) inverts Normalize(mean, std):
        # applied to y = (x - mean) / std it yields (y + mean/std) * std = x.
        self.rad_denormalization = transforms.Normalize(-self.rad_mean / self.rad_std, 1 / self.rad_std)
        self.sat_denormalization = transforms.Normalize(-self.sat_mean / self.sat_std, 1 / self.sat_std)

    def training_step(self, batch: Any, batch_idx: int):
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        loss_rad = F.mse_loss(pred_rad, out_rad)
        loss_sat = F.mse_loss(pred_sat, out_sat)
        loss_tot = loss_rad + loss_sat
        self.log("train/rad", loss_rad, prog_bar=True, logger=True)
        self.log("train/sat", loss_sat, prog_bar=True, logger=True)
        self.log("train/mse", loss_tot, prog_bar=True, logger=True)
        return loss_tot

    def validation_step(self, batch: Any, batch_idx: int):
        # Lightning already runs validation under torch.no_grad().
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        loss_rad = F.mse_loss(pred_rad, out_rad)
        loss_sat = F.mse_loss(pred_sat, out_sat)
        loss_tot = loss_rad + loss_sat
        self.log("val/rad", loss_rad, prog_bar=True, logger=True)
        self.log("val/sat", loss_sat, prog_bar=True, logger=True)
        self.log("val/mse", loss_tot, prog_bar=True, logger=True)
        return loss_tot

    def test_step(self, batch: Any, batch_idx: int):
        # Lightning already runs testing under torch.no_grad().
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        rad_metrics = [MSE, RMSE, ACC, MAE]
        sat_metrics = [MSE, WMSE, RMSE, WRMSE, ACC, MAE]
        loss_rad = F.mse_loss(self.rad_denormalization(pred_rad), self.rad_denormalization(out_rad))
        loss_sat = F.mse_loss(self.sat_denormalization(pred_sat), self.sat_denormalization(out_sat))
        loss_tot = loss_rad + loss_sat
        self.log("test/rad", loss_rad, prog_bar=True, logger=True)
        self.log("test/sat", loss_sat, prog_bar=True, logger=True)
        self.log("test/mse", loss_tot, prog_bar=True, logger=True)
        for met in rad_metrics:
            # Radar metrics use uniform latitude weights.
            metric_rad = met(
                self.rad_denormalization(pred_rad),
                self.rad_denormalization(out_rad),
                np.ones(self.rad_size),
                self.rad_clim,
            )
            self.log(f"test/rad_{met.__name__}", metric_rad, prog_bar=True, logger=True)
        for met in sat_metrics:
            # Satellite metrics are latitude-weighted.
            metric_sat = met(
                self.sat_denormalization(pred_sat),
                self.sat_denormalization(out_sat),
                self.sat_lat,
                self.sat_clim,
            )
            self.log(f"test/sat_{met.__name__}", metric_sat, prog_bar=True, logger=True)
        return loss_tot

    def configure_optimizers(self):
        # Exclude embedding parameters from weight decay.
        decay = []
        no_decay = []
        for name, m in self.named_parameters():
            if "var_embed" in name or "pos_embed" in name or "time_pos_embed" in name:
                no_decay.append(m)
            else:
                decay.append(m)
        optimizer = torch.optim.AdamW(
            [
                {
                    "params": decay,
                    "lr": self.hparams.lr,
                    "betas": (self.hparams.beta_1, self.hparams.beta_2),
                    "weight_decay": self.hparams.weight_decay,
                },
                {
                    "params": no_decay,
                    "lr": self.hparams.lr,
                    "betas": (self.hparams.beta_1, self.hparams.beta_2),
                    "weight_decay": 0,
                },
            ]
        )
        lr_scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            self.hparams.warmup_epochs,
            self.hparams.max_epochs,
            self.hparams.warmup_start_lr,
            self.hparams.eta_min,
        )
        scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}
        return {"optimizer": optimizer, "lr_scheduler": scheduler}