| import os
|
| from torch.utils.data import DataLoader, Dataset, random_split
|
| import numpy as np
|
| from datetime import datetime, timedelta
|
| from torchvision import transforms
|
| from pytorch_lightning import LightningDataModule, LightningModule
|
| from pytorch_lightning.cli import LightningCLI
|
| from torch.utils.data import DataLoader
|
| import pytorch_lightning as L
|
| import torch
|
| import torch.nn as nn
|
| from typing import Tuple, Dict, List
|
|
|
|
|
|
|
class DataReader(Dataset):
    """Dataset of paired radar / satellite ``.npz`` samples for one data split.

    Files live under ``<dir_data>/<type_data>/`` in the sub-directories
    ``pred_rad``, ``pred_sat``, ``rad`` and ``sat``. Radar files are named
    ``YYYYmmddHHMM.npz`` and satellite files ``YYYYmmddHH.npz``. Each sample is
    a list ``[inp_rad, inp_sat, out_rad, out_sat]`` whose layout depends on
    ``ablation`` (t is the timestamp of the anchor file, h is
    ``hours_predicted``):

    * ``"no"``   -- anchor ``pred_rad`` at t, ``pred_sat`` at t; targets
      ``rad``/``sat`` at t.
    * ``"rad"``  -- anchor ``rad`` at t, ``pred_sat`` at t + h; targets at t + h.
    * ``"sat"``  -- anchor ``pred_rad`` at t, ``sat`` at t - h; targets at t.
    * ``"full"`` -- anchor ``rad`` at t, ``sat`` at t; targets at t + h.
    * ``"time"`` -- like ``"no"`` but every slot is a list of paths: the radar
      input also holds four ``rad`` frames 210..180 minutes before t and the
      satellite input one ``sat`` frame 180 minutes before t.

    Normalisation statistics come from ``rad_mean.npz``, ``rad_std.npz``,
    ``sat_mean.npz`` and ``sat_std.npz`` in ``dir_data``.
    """

    # Anchor files are used only when the two characters before ".npz"
    # (the minute field of a YYYYmmddHHMM name) are "00" or "03".
    _KEPT_SUFFIXES = ("00.npz", "03.npz")

    def __init__(
        self,
        dir_data: str,
        type_data: str,
        rad_attribute: str,
        sat_attribute: str,
        hours_predicted: int,
        rad_predicted: str,
        sat_predicted: str,
        time_points_rad: int,
        time_points_sat: int,
        rad_size: int,
        sat_size: int,
        ablation: str = "no",
    ):
        """Index the samples of one split and build the normalising transforms.

        Args:
            dir_data: root directory holding the split sub-directories and the
                ``*_mean.npz`` / ``*_std.npz`` statistics files.
            type_data: split to read: "train", "test" or "val".
            rad_attribute: array key read from radar input files (also selects
                the radar statistics).
            sat_attribute: array key read from satellite input files (also
                selects the satellite statistics).
            hours_predicted: forecast offset in hours used by the "rad",
                "sat" and "full" layouts.
            rad_predicted: array key read from radar target files.
            sat_predicted: array key read from satellite target files.
            time_points_rad, time_points_sat, rad_size, sat_size: stored on the
                instance; not used inside this class.
            ablation: sample layout, one of "no", "rad", "sat", "full", "time".

        Raises:
            ValueError: if ``type_data`` or ``ablation`` is not a known value.
        """
        super().__init__()

        self.base_dir = dir_data
        self.type_data = type_data
        if type_data not in ("train", "test", "val"):
            raise ValueError("Type must be train, test or val")
        self.dir_data = os.path.join(dir_data, type_data)

        self.sat_size = sat_size
        self.rad_size = rad_size
        self.hours_predicted = hours_predicted
        self.rad_attribute = rad_attribute
        self.sat_attribute = sat_attribute
        self.rad_predicted = rad_predicted
        self.sat_predicted = sat_predicted
        self.time_points_rad = time_points_rad
        self.time_points_sat = time_points_sat
        self.transform_rad = None
        self.transform_sat = None
        self.ablation = ablation

        # Statistics live in the root directory, shared by every split.
        self.rad_mean = self._load_npz(os.path.join(self.base_dir, 'rad_mean.npz'), self.rad_attribute)
        self.rad_std = self._load_npz(os.path.join(self.base_dir, 'rad_std.npz'), self.rad_attribute)
        self.sat_mean = self._load_npz(os.path.join(self.base_dir, 'sat_mean.npz'), self.sat_attribute)
        self.sat_std = self._load_npz(os.path.join(self.base_dir, 'sat_std.npz'), self.sat_attribute)

        self.create_transform()

        builders = {
            "no": self.gen_list_img_no,
            "rad": self.gen_list_img_rad,
            "sat": self.gen_list_img_sat,
            "full": self.gen_list_img_full,
            "time": self.gen_list_img_time,
        }
        if self.ablation not in builders:
            raise ValueError("Ablation must be no, rad, sat, full or time")
        self.list_img_dir = builders[self.ablation](self.dir_data)

        print(f"Number of {self.type_data} samples:", len(self.list_img_dir))

    @staticmethod
    def _load_npz(path, key):
        """Return array ``key`` from the ``.npz`` archive at ``path``, closing the file afterwards."""
        with np.load(path) as archive:
            return archive[key]

    def __len__(self):
        """Number of indexed samples."""
        return len(self.list_img_dir)

    def __getitem__(self, idx):
        """Return ``(inp_rad, inp_sat, out_rad, out_sat)`` for sample ``idx``.

        Each array goes through the matching transform. Satellite tensors are
        additionally cast with ``.float()``. For the satellite target only
        element ``[0]`` of the loaded ``sat_predicted`` array is used.
        """
        # NOTE(review): samples built with ablation="time" hold lists of paths
        # in every slot, which np.load cannot open, so that layout cannot be
        # loaded here yet -- confirm how the history frames should be stacked.
        inp_rad_path, inp_sat_path, out_rad_path, out_sat_path = self.list_img_dir[idx]
        if self.transform_rad:
            inp_rad = self.transform_rad(self._load_npz(inp_rad_path, self.rad_attribute))
            out_rad = self.transform_rad(self._load_npz(out_rad_path, self.rad_predicted))
        if self.transform_sat:
            inp_sat = self.transform_sat(self._load_npz(inp_sat_path, self.sat_attribute))
            out_sat = self.transform_sat(self._load_npz(out_sat_path, self.sat_predicted)[0])
        return inp_rad, inp_sat.float(), out_rad, out_sat.float()

    def create_transform(self):
        """Build ToTensor + Normalize pipelines from the loaded statistics.

        The satellite pipeline normalises with only the first entry of the
        satellite mean / std arrays.
        """
        self.transform_rad = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.rad_mean, self.rad_std),
        ])
        self.transform_sat = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.sat_mean[0], self.sat_std[0]),
        ])

    def _collect_samples(self, anchor_dir, companions):
        """Build 4-path samples anchored on the files of ``anchor_dir``.

        ``companions(name)`` returns the other three paths (input satellite,
        target radar, target satellite) for anchor file ``name``. A sample is
        kept only if all three are existing files. Names are visited in sorted
        order so the sample order is deterministic.
        """
        samples = []
        for name in sorted(os.listdir(anchor_dir)):
            if not name.endswith(self._KEPT_SUFFIXES):
                continue
            others = companions(name)
            if all(os.path.isfile(other) for other in others):
                samples.append([os.path.join(anchor_dir, name), *others])
        return samples

    def gen_list_img_no(self, path):
        """Layout "no": ``pred_rad``/``pred_sat`` inputs, ``rad``/``sat`` targets, all at the anchor time."""
        def companions(name):
            hour_name = name[0:-6] + name[-4:]  # drop the minute digits
            return (
                os.path.join(path, "pred_sat", hour_name),
                os.path.join(path, "rad", name),
                os.path.join(path, "sat", hour_name),
            )
        return self._collect_samples(os.path.join(path, "pred_rad"), companions)

    def gen_list_img_rad(self, path):
        """Layout "rad": ``rad`` input at t, ``pred_sat`` input and targets at t + hours_predicted."""
        shift = timedelta(hours=self.hours_predicted)

        def companions(name):
            target = self.get_date_time(name) + shift
            return (
                os.path.join(path, "pred_sat", target.strftime('%Y%m%d%H') + '.npz'),
                os.path.join(path, "rad", target.strftime('%Y%m%d%H%M') + '.npz'),
                os.path.join(path, "sat", target.strftime('%Y%m%d%H') + '.npz'),
            )
        return self._collect_samples(os.path.join(path, "rad"), companions)

    def gen_list_img_sat(self, path):
        """Layout "sat": ``pred_rad`` input at t, ``sat`` input at t - hours_predicted, targets at t."""
        shift = timedelta(hours=self.hours_predicted)

        def companions(name):
            earlier = self.get_date_time(name) - shift
            return (
                os.path.join(path, "sat", earlier.strftime('%Y%m%d%H') + '.npz'),
                os.path.join(path, "rad", name),
                os.path.join(path, "sat", name[0:-6] + name[-4:]),
            )
        return self._collect_samples(os.path.join(path, "pred_rad"), companions)

    def gen_list_img_full(self, path):
        """Layout "full": ``rad``/``sat`` inputs at t, targets at t + hours_predicted."""
        shift = timedelta(hours=self.hours_predicted)

        def companions(name):
            now = self.get_date_time(name)
            target = now + shift
            return (
                os.path.join(path, "sat", now.strftime('%Y%m%d%H') + '.npz'),
                os.path.join(path, "rad", target.strftime('%Y%m%d%H%M') + '.npz'),
                os.path.join(path, "sat", target.strftime('%Y%m%d%H') + '.npz'),
            )
        return self._collect_samples(os.path.join(path, "rad"), companions)

    def gen_list_img_time(self, path):
        """Layout "time": like "no", plus earlier ``rad``/``sat`` frames in the input slots.

        Each sample is ``[rad_inputs(5), sat_inputs(2), [rad_target], [sat_target]]``
        and is kept only when every referenced history / target file exists.
        """
        pred_rad_dir = os.path.join(path, "pred_rad")
        pred_sat_dir = os.path.join(path, "pred_sat")
        gt_rad_dir = os.path.join(path, "rad")
        gt_sat_dir = os.path.join(path, "sat")
        samples = []
        for name in sorted(os.listdir(pred_rad_dir)):
            # Filter before parsing so stray files (e.g. notes, hidden files)
            # never reach get_date_time.
            if not name.endswith(self._KEPT_SUFFIXES):
                continue
            when = self.get_date_time(name)
            hour_name = name[0:-6] + name[-4:]
            # NOTE(review): the -210..-180 minute offsets are hard-coded and do
            # not follow hours_predicted -- confirm this is intended.
            rad_history = [
                os.path.join(gt_rad_dir, (when + timedelta(minutes=-210 + i * 10)).strftime('%Y%m%d%H%M') + '.npz')
                for i in range(4)
            ]
            sat_history = [
                os.path.join(gt_sat_dir, (when + timedelta(minutes=-180)).strftime('%Y%m%d%H') + '.npz')
            ]
            rad_inputs = rad_history + [os.path.join(pred_rad_dir, name)]
            sat_inputs = sat_history + [os.path.join(pred_sat_dir, hour_name)]
            rad_target = os.path.join(gt_rad_dir, name)
            sat_target = os.path.join(gt_sat_dir, hour_name)
            required = rad_history + sat_inputs + [rad_target, sat_target]
            if all(os.path.isfile(candidate) for candidate in required):
                samples.append([rad_inputs, sat_inputs, [rad_target], [sat_target]])
        return samples

    def get_date_time(self, name):
        """Parse the leading ``YYYYmmddHHMM`` characters of ``name`` into a naive datetime.

        Raises:
            ValueError: if those characters are not a valid timestamp.
        """
        return datetime(
            int(name[0:4]),
            int(name[4:6]),
            int(name[6:8]),
            int(name[8:10]),
            int(name[10:12]),
        )
|
|
|
|
|
|
|
|
|
class WeatherForecastDataModule(LightningDataModule):
    """Lightning data module serving ``DataReader`` train / val / test splits.

    Constructor arguments are recorded with ``save_hyperparameters`` and
    forwarded to ``DataReader``: ``rad_inp_vars`` / ``sat_inp_vars`` become the
    input array keys and ``rad_out_vars`` / ``sat_out_vars`` the target keys.
    The input normalisation statistics are also exposed as ``rad_mean``,
    ``rad_std``, ``sat_mean`` and ``sat_std``.
    """

    def __init__(
        self,
        dir_data: str,
        batch_size: int,
        hours_predicted: int,
        num_workers: int,
        pin_memory: bool,
        time_points_rad: int,
        time_points_sat: int,
        sat_inp_vars: str,
        sat_out_vars: str,
        sat_size: int,
        rad_inp_vars: str,
        rad_out_vars: str,
        rad_size: int,
        ablation: str,
    ):
        super().__init__()
        self.save_hyperparameters(logger=True)

        # Datasets are created lazily in setup().
        self.data_train = None
        self.data_test = None
        self.data_val = None

        self.rad_mean = self._load_stat('rad_mean.npz', self.hparams.rad_inp_vars)
        self.rad_std = self._load_stat('rad_std.npz', self.hparams.rad_inp_vars)
        self.sat_mean = self._load_stat('sat_mean.npz', self.hparams.sat_inp_vars)
        self.sat_std = self._load_stat('sat_std.npz', self.hparams.sat_inp_vars)

    def _load_stat(self, file_name, key):
        """Read array ``key`` from ``<dir_data>/<file_name>``, closing the archive afterwards."""
        with np.load(os.path.join(self.hparams.dir_data, file_name)) as archive:
            return archive[key]

    def prepare_data(self):
        """Nothing to prepare: data is read directly from ``dir_data``."""
        pass

    def _make_reader(self, type_data):
        """Build a ``DataReader`` for split ``type_data`` from the stored hyperparameters."""
        hp = self.hparams
        return DataReader(
            dir_data=hp.dir_data,
            type_data=type_data,
            rad_attribute=hp.rad_inp_vars,
            sat_attribute=hp.sat_inp_vars,
            hours_predicted=hp.hours_predicted,
            rad_predicted=hp.rad_out_vars,
            sat_predicted=hp.sat_out_vars,
            time_points_rad=hp.time_points_rad,
            time_points_sat=hp.time_points_sat,
            sat_size=hp.sat_size,
            rad_size=hp.rad_size,
            ablation=hp.ablation,
        )

    def setup(self, stage=None):
        """Create the datasets required by ``stage``.

        "fit" builds train and val, "validate" builds val, and "test" builds
        test. Any other value (including ``None``) builds all three splits.
        """
        build_all = stage not in ("fit", "validate", "test")
        if build_all or stage == "fit":
            self.data_train = self._make_reader("train")
        if build_all or stage == "test":
            self.data_test = self._make_reader("test")
        if build_all or stage in ("fit", "validate"):
            self.data_val = self._make_reader("val")

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` using the stored batch / worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            drop_last=False,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return self._make_loader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Ordered loader over the test split."""
        return self._make_loader(self.data_test, shuffle=False)
|
|
|