# File: unet_code/src/module.py
# Author: weatherforecast1024
# Commit f3b050a (verified): "Upload folder using huggingface_hub"
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# credits: https://github.com/ashleve/lightning-hydra-template/blob/main/src/models/mnist_module.py
from typing import Any
import os
import numpy as np
import torch
from pytorch_lightning import LightningModule
from torchvision.transforms import transforms
from lr_scheduler import LinearWarmupCosineAnnealingLR
from arch import Network
from metric import (
MSE,RMSE,MAE,ACC,WMSE,WRMSE
)
class WeatherForecastModule(LightningModule):
    """Lightning module that jointly forecasts radar and satellite precipitation.

    Adapted from the ClimaX global-forecasting module. The wrapped ``net`` is
    called as ``net(inp_rad, inp_sat)`` and must return ``(pred_rad, pred_sat)``.
    Training and validation minimise the sum of the radar and satellite
    mean-squared errors on the data as delivered by the loader (presumably
    normalized). Testing reports the same losses after denormalization, plus
    extra metrics.

    ``test_step`` needs auxiliary statistics, loaded in this order:
    ``set_path`` -> ``set_size`` -> ``set_lat`` / ``set_clim`` /
    ``set_normalize`` -> ``set_denormalize``.

    Args:
        net: Network mapping ``(inp_rad, inp_sat)`` to ``(pred_rad, pred_sat)``.
        pretrained_path (str, optional): Path to a ``state_dict`` checkpoint
            for ``net``. No weights are loaded when empty.
        lr (float, optional): Learning rate.
        beta_1 (float, optional): Beta 1 for AdamW.
        beta_2 (float, optional): Beta 2 for AdamW.
        weight_decay (float, optional): Weight decay for AdamW. Embedding
            parameters are excluded; see ``configure_optimizers``.
        warmup_epochs (int, optional): Length of the linear warmup. The
            scheduler is stepped once per optimizer step (``interval="step"``),
            so this is presumably a step count. Confirm against
            ``LinearWarmupCosineAnnealingLR``.
        max_epochs (int, optional): Total schedule length, in the same unit.
        warmup_start_lr (float, optional): Starting learning rate for warmup.
        eta_min (float, optional): Minimum learning rate.
    """

    def __init__(
        self,
        net: Network,
        pretrained_path: str = "",
        lr: float = 5e-4,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        weight_decay: float = 1e-5,
        warmup_epochs: int = 10000,
        max_epochs: int = 200000,
        warmup_start_lr: float = 1e-8,
        eta_min: float = 1e-8,
    ):
        super().__init__()
        # Record every constructor argument except the network in
        # ``self.hparams``; configure_optimizers reads them back from there.
        self.save_hyperparameters(logger=True, ignore=["net"])
        self.net = net
        # An empty string (the default) or None means "no pretrained weights".
        if pretrained_path:
            self.load_pretrained_weights(pretrained_path)

    def load_pretrained_weights(self, pretrained_path):
        """Load a ``state_dict`` checkpoint from ``pretrained_path`` into ``self.net``.

        Tensors are mapped to CPU first, so a checkpoint saved on GPU can also be
        loaded on a CPU-only machine. ``load_state_dict`` then copies the values
        into the network's existing parameters.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(pretrained_path, map_location="cpu")
        self.net.load_state_dict(state_dict)

    # ------------------------------------------------------------------
    # Auxiliary statistics used by test_step
    # ------------------------------------------------------------------
    def _load_npz_array(self, filename, key):
        """Return array ``key`` from ``<self.path>/<filename>`` and close the archive."""
        with np.load(os.path.join(self.path, filename)) as archive:
            return archive[key]

    def set_path(self, path):
        """Set the directory containing the ``.npy``/``.npz`` statistics files."""
        self.path = path

    def set_size(self, rad_size, sat_size):
        """Record grid sizes.

        ``rad_size`` is passed to ``np.ones`` to build the weight argument of
        the radar metrics. ``sat_size`` is the number of latitude entries that
        ``set_lat`` keeps.
        """
        self.rad_size = rad_size
        self.sat_size = sat_size

    def set_lat(self):
        """Load ``sat_lat.npy`` and keep the centred ``sat_size`` entries.

        Requires ``set_path`` and ``set_size``.

        Raises:
            ValueError: If ``sat_size`` exceeds the number of stored latitudes.
        """
        lat = np.load(os.path.join(self.path, "sat_lat.npy"))
        n_lat = lat.shape[-1]
        if self.sat_size > n_lat:
            raise ValueError(
                f"sat_size={self.sat_size} exceeds the {n_lat} latitudes in sat_lat.npy"
            )
        # Centre crop. Ending at ``start + sat_size`` (not ``centre + sat_size // 2``)
        # keeps exactly ``sat_size`` entries for odd sizes; even sizes are unchanged.
        # NOTE(review): if the data loader crops the satellite field with the old
        # ``c - s//2 : c + s//2`` formula, it needs the same fix for odd sizes.
        start = n_lat // 2 - self.sat_size // 2
        self.sat_lat = lat[start:start + self.sat_size]

    def set_clim(self):
        """Load the radar and satellite ``*_clim.npz`` arrays as torch tensors.

        These are passed as the last argument to every test metric
        (presumably climatologies, used by ``ACC``). Requires ``set_path``.
        """
        self.rad_clim = torch.from_numpy(
            self._load_npz_array("rad_clim.npz", "precipitation")
        )
        self.sat_clim = torch.from_numpy(
            self._load_npz_array("sat_clim.npz", "total_precipitation")
        )

    def set_normalize(self):
        """Load per-source mean/std arrays (numpy) used by ``set_denormalize``."""
        self.rad_mean = self._load_npz_array("rad_mean.npz", "precipitation")
        self.rad_std = self._load_npz_array("rad_std.npz", "precipitation")
        self.sat_mean = self._load_npz_array("sat_mean.npz", "total_precipitation")
        self.sat_std = self._load_npz_array("sat_std.npz", "total_precipitation")

    def set_denormalize(self):
        """Build transforms that invert ``(x - mean) / std``.

        ``Normalize(m', s')`` computes ``(x - m') / s'``. With ``m' = -mean/std``
        and ``s' = 1/std`` this equals ``x * std + mean``. Requires
        ``set_normalize``.
        """
        self.rad_denormalization = transforms.Normalize(
            -self.rad_mean / self.rad_std, 1 / self.rad_std
        )
        self.sat_denormalization = transforms.Normalize(
            -self.sat_mean / self.sat_std, 1 / self.sat_std
        )

    # ------------------------------------------------------------------
    # Lightning hooks
    # ------------------------------------------------------------------
    def _log_losses(self, stage, loss_rad, loss_sat):
        """Log ``<stage>/rad``, ``<stage>/sat`` and ``<stage>/mse``; return the sum."""
        loss_tot = loss_rad + loss_sat
        self.log(f"{stage}/rad", loss_rad, prog_bar=True, logger=True)
        self.log(f"{stage}/sat", loss_sat, prog_bar=True, logger=True)
        self.log(f"{stage}/mse", loss_tot, prog_bar=True, logger=True)
        return loss_tot

    def training_step(self, batch: Any, batch_idx: int):
        """Return the summed radar + satellite MSE for one training batch."""
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        return self._log_losses(
            "train",
            torch.nn.functional.mse_loss(pred_rad, out_rad),
            torch.nn.functional.mse_loss(pred_sat, out_sat),
        )

    def validation_step(self, batch: Any, batch_idx: int):
        """Log and return the summed radar + satellite MSE for one validation batch."""
        inp_rad, inp_sat, out_rad, out_sat = batch
        with torch.no_grad():
            pred_rad, pred_sat = self.net(inp_rad, inp_sat)
            return self._log_losses(
                "val",
                torch.nn.functional.mse_loss(pred_rad, out_rad),
                torch.nn.functional.mse_loss(pred_sat, out_sat),
            )

    def test_step(self, batch: Any, batch_idx: int):
        """Log MSE losses and extra metrics on denormalized predictions/targets.

        Requires all ``set_*`` methods to have been called (see class docstring).
        Returns the summed radar + satellite MSE in physical units.
        """
        inp_rad, inp_sat, out_rad, out_sat = batch
        with torch.no_grad():
            pred_rad, pred_sat = self.net(inp_rad, inp_sat)
            # Denormalize each tensor once. Normalize is not in-place, so these
            # are new tensors. Reusing them assumes the metrics do not modify
            # their inputs in place.
            pred_rad = self.rad_denormalization(pred_rad)
            out_rad = self.rad_denormalization(out_rad)
            pred_sat = self.sat_denormalization(pred_sat)
            out_sat = self.sat_denormalization(out_sat)

            loss_tot = self._log_losses(
                "test",
                torch.nn.functional.mse_loss(pred_rad, out_rad),
                torch.nn.functional.mse_loss(pred_sat, out_sat),
            )

            # Radar metrics get all-ones in the latitude slot.
            rad_weights = np.ones(self.rad_size)
            for met in (MSE, RMSE, ACC, MAE):
                value = met(pred_rad, out_rad, rad_weights, self.rad_clim)
                self.log(f"test/rad_{met.__name__}", value, prog_bar=True, logger=True)

            for met in (MSE, WMSE, RMSE, WRMSE, ACC, MAE):
                value = met(pred_sat, out_sat, self.sat_lat, self.sat_clim)
                self.log(f"test/sat_{met.__name__}", value, prog_bar=True, logger=True)
        return loss_tot

    def configure_optimizers(self):
        """Build AdamW (two parameter groups) plus a warmup + cosine LR schedule.

        Parameters whose name contains ``var_embed``, ``pos_embed`` or
        ``time_pos_embed`` get no weight decay. All others use
        ``hparams.weight_decay``.
        """
        no_decay_keys = ("var_embed", "pos_embed", "time_pos_embed")
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if any(key in name for key in no_decay_keys):
                no_decay.append(param)
            else:
                decay.append(param)

        betas = (self.hparams.beta_1, self.hparams.beta_2)
        optimizer = torch.optim.AdamW(
            [
                {
                    "params": decay,
                    "lr": self.hparams.lr,
                    "betas": betas,
                    "weight_decay": self.hparams.weight_decay,
                },
                {
                    "params": no_decay,
                    "lr": self.hparams.lr,
                    "betas": betas,
                    "weight_decay": 0,
                },
            ]
        )
        lr_scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            self.hparams.warmup_epochs,
            self.hparams.max_epochs,
            self.hparams.warmup_start_lr,
            self.hparams.eta_min,
        )
        # Lightning steps this scheduler after every optimizer step, not every epoch.
        scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}
        return {"optimizer": optimizer, "lr_scheduler": scheduler}