File size: 7,434 Bytes
f3b050a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# credits: https://github.com/ashleve/lightning-hydra-template/blob/main/src/models/mnist_module.py
from typing import Any
import os
import numpy as np
import torch
from pytorch_lightning import LightningModule
from torchvision.transforms import transforms
from lr_scheduler import LinearWarmupCosineAnnealingLR
from arch import Network
from metric import (
    MSE,RMSE,MAE,ACC,WMSE,WRMSE
)
class WeatherForecastModule(LightningModule):
    """Lightning module that jointly forecasts radar and satellite precipitation.

    The wrapped network maps ``(inp_rad, inp_sat)`` to ``(pred_rad, pred_sat)``.
    Training and validation minimise the sum of the radar and satellite MSE in
    normalized space; testing denormalizes both streams and reports MSE plus
    the metrics imported from ``metric``.

    Before ``test_step`` can run, configure the module with ``set_path`` and
    ``set_size`` first, then ``set_lat``, ``set_clim`` and ``set_normalize``,
    and finally ``set_denormalize`` (which needs the statistics loaded by
    ``set_normalize``).

    Args:

        net: Network mapping ``(inp_rad, inp_sat)`` to ``(pred_rad, pred_sat)``.

        pretrained_path (str, optional): Path to a state dict for ``net``.
            Empty (or ``None``) skips loading.

        lr (float, optional): Learning rate.

        beta_1 (float, optional): Beta 1 for AdamW.

        beta_2 (float, optional): Beta 2 for AdamW.

        weight_decay (float, optional): Weight decay for AdamW (not applied to
            embedding parameters, see ``configure_optimizers``).

        warmup_epochs (int, optional): Length of the linear warmup. The
            scheduler is stepped every optimizer step (``interval="step"``),
            so this is presumably counted in steps rather than epochs.

        max_epochs (int, optional): Total schedule length (same unit as
            ``warmup_epochs``).

        warmup_start_lr (float, optional): Starting learning rate for warmup.

        eta_min (float, optional): Minimum learning rate.

    """

    def __init__(
        self,
        net: Network,
        pretrained_path: str = "",
        lr: float = 5e-4,
        beta_1: float = 0.9,
        beta_2: float = 0.99,
        weight_decay: float = 1e-5,
        warmup_epochs: int = 10000,
        max_epochs: int = 200000,
        warmup_start_lr: float = 1e-8,
        eta_min: float = 1e-8,
    ):
        super().__init__()
        # Every constructor argument except the network is exposed as self.hparams.
        self.save_hyperparameters(logger=True, ignore=["net"])
        self.net = net
        # Truthiness check also tolerates pretrained_path=None (len(None) raises).
        if pretrained_path:
            self.load_pretrained_weights(pretrained_path)

    def load_pretrained_weights(self, pretrained_path):
        """Load a state dict from ``pretrained_path`` into ``self.net``.

        The checkpoint is mapped to CPU so a GPU-saved file can be loaded on
        any machine; ``load_state_dict`` then copies the values into the
        network's existing parameters.
        """
        # torch.load unpickles the file: only point this at trusted checkpoints.
        state_dict = torch.load(pretrained_path, map_location="cpu")
        self.net.load_state_dict(state_dict)

    def set_path(self, path):
        """Set the directory holding the lat/climatology/normalization files."""
        self.path = path

    def set_size(self, rad_size, sat_size):
        """Record the radar and satellite field sizes used at test time."""
        self.rad_size = rad_size
        self.sat_size = sat_size

    def _load_npz_array(self, filename, key):
        """Return array ``key`` from ``self.path/filename``, closing the archive."""
        with np.load(os.path.join(self.path, filename)) as archive:
            return archive[key]

    def set_lat(self):
        """Load ``sat_lat.npy`` and keep a centred window of ``sat_size`` entries.

        Requires ``set_path`` and ``set_size``. The window is centred on
        ``lat.shape[-1] // 2`` and sliced along the first axis, so this
        assumes ``lat`` is 1-D.
        """
        lat = np.load(os.path.join(self.path, "sat_lat.npy"))
        centre = lat.shape[-1] // 2
        half = self.sat_size // 2
        # NOTE(review): for odd sat_size this slice has sat_size - 1 entries;
        # confirm it matches how the satellite fields themselves are cropped.
        self.sat_lat = lat[centre - half:centre + half]

    def set_clim(self):
        """Load radar and satellite climatologies (as torch tensors) from ``self.path``."""
        self.rad_clim = torch.from_numpy(
            self._load_npz_array("rad_clim.npz", "precipitation")
        )
        self.sat_clim = torch.from_numpy(
            self._load_npz_array("sat_clim.npz", "total_precipitation")
        )

    def set_normalize(self):
        """Load the per-stream mean/std (numpy arrays) used for normalization."""
        self.rad_mean = self._load_npz_array("rad_mean.npz", "precipitation")
        self.rad_std = self._load_npz_array("rad_std.npz", "precipitation")
        self.sat_mean = self._load_npz_array("sat_mean.npz", "total_precipitation")
        self.sat_std = self._load_npz_array("sat_std.npz", "total_precipitation")

    def set_denormalize(self):
        """Build transforms that undo ``(x - mean) / std``; requires ``set_normalize``."""
        # Normalize(m', s') computes (x - m') / s'; with m' = -mean/std and
        # s' = 1/std that equals x * std + mean, i.e. the inverse transform.
        self.rad_denormalization = transforms.Normalize(
            -self.rad_mean / self.rad_std, 1 / self.rad_std
        )
        self.sat_denormalization = transforms.Normalize(
            -self.sat_mean / self.sat_std, 1 / self.sat_std
        )

    @staticmethod
    def _mse_losses(pred_rad, out_rad, pred_sat, out_sat):
        """Return (radar MSE, satellite MSE, their sum), each mean-reduced."""
        loss_rad = torch.nn.functional.mse_loss(pred_rad, out_rad)
        loss_sat = torch.nn.functional.mse_loss(pred_sat, out_sat)
        return loss_rad, loss_sat, loss_rad + loss_sat

    def _log_losses(self, stage, loss_rad, loss_sat, loss_tot):
        """Log the three losses as ``{stage}/rad``, ``{stage}/sat``, ``{stage}/mse``."""
        self.log(f"{stage}/rad", loss_rad, prog_bar=True, logger=True)
        self.log(f"{stage}/sat", loss_sat, prog_bar=True, logger=True)
        self.log(f"{stage}/mse", loss_tot, prog_bar=True, logger=True)

    def training_step(self, batch: Any, batch_idx: int):
        """Return the summed normalized-space radar + satellite MSE for one batch."""
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        loss_rad, loss_sat, loss_tot = self._mse_losses(pred_rad, out_rad, pred_sat, out_sat)
        self._log_losses("train", loss_rad, loss_sat, loss_tot)
        return loss_tot

    def validation_step(self, batch: Any, batch_idx: int):
        """Log and return the summed normalized-space MSE for one validation batch."""
        # Lightning runs the validation loop with gradients disabled, so no
        # explicit torch.no_grad() is needed here.
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)
        loss_rad, loss_sat, loss_tot = self._mse_losses(pred_rad, out_rad, pred_sat, out_sat)
        self._log_losses("val", loss_rad, loss_sat, loss_tot)
        return loss_tot

    def test_step(self, batch: Any, batch_idx: int):
        """Evaluate one batch in denormalized (physical) units.

        Logs ``test/rad``, ``test/sat`` and ``test/mse`` (MSE of denormalized
        fields) plus ``test/rad_<metric>`` and ``test/sat_<metric>`` for each
        metric below, and returns the summed denormalized MSE. Requires all
        ``set_*`` methods to have been called (see the class docstring).
        """
        inp_rad, inp_sat, out_rad, out_sat = batch
        pred_rad, pred_sat = self.net(inp_rad, inp_sat)

        # Denormalize each tensor once; the MSE and every metric reuse them.
        pred_rad = self.rad_denormalization(pred_rad)
        out_rad = self.rad_denormalization(out_rad)
        pred_sat = self.sat_denormalization(pred_sat)
        out_sat = self.sat_denormalization(out_sat)

        loss_rad, loss_sat, loss_tot = self._mse_losses(pred_rad, out_rad, pred_sat, out_sat)
        self._log_losses("test", loss_rad, loss_sat, loss_tot)

        # Radar metrics get uniform weights; satellite metrics receive the
        # cropped latitudes from set_lat (presumably for latitude weighting).
        rad_weights = np.ones(self.rad_size)
        for metric_fn in (MSE, RMSE, ACC, MAE):
            value = metric_fn(pred_rad, out_rad, rad_weights, self.rad_clim)
            self.log(f"test/rad_{metric_fn.__name__}", value, prog_bar=True, logger=True)
        for metric_fn in (MSE, WMSE, RMSE, WRMSE, ACC, MAE):
            value = metric_fn(pred_sat, out_sat, self.sat_lat, self.sat_clim)
            self.log(f"test/sat_{metric_fn.__name__}", value, prog_bar=True, logger=True)
        return loss_tot

    def configure_optimizers(self):
        """Build AdamW with two parameter groups and a per-step warmup/cosine schedule.

        Parameters whose name contains ``var_embed``, ``pos_embed`` or
        ``time_pos_embed`` get no weight decay; all others use
        ``hparams.weight_decay``. The scheduler is stepped every optimizer step.
        """
        no_decay_keys = ("var_embed", "pos_embed", "time_pos_embed")
        decay, no_decay = [], []
        for name, param in self.named_parameters():
            if any(key in name for key in no_decay_keys):
                no_decay.append(param)
            else:
                decay.append(param)

        betas = (self.hparams.beta_1, self.hparams.beta_2)
        optimizer = torch.optim.AdamW(
            [
                {
                    "params": decay,
                    "lr": self.hparams.lr,
                    "betas": betas,
                    "weight_decay": self.hparams.weight_decay,
                },
                {
                    "params": no_decay,
                    "lr": self.hparams.lr,
                    "betas": betas,
                    "weight_decay": 0,
                },
            ]
        )

        lr_scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            self.hparams.warmup_epochs,
            self.hparams.max_epochs,
            self.hparams.warmup_start_lr,
            self.hparams.eta_min,
        )
        scheduler = {"scheduler": lr_scheduler, "interval": "step", "frequency": 1}

        return {"optimizer": optimizer, "lr_scheduler": scheduler}