import os
from torch.utils.data import DataLoader, Dataset, random_split
import numpy as np
from datetime import datetime, timedelta
from torchvision import transforms
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.cli import LightningCLI
from torch.utils.data import DataLoader  # duplicate of the import above; kept from the original file
import pytorch_lightning as L
import torch
import torch.nn as nn
from typing import Tuple, Dict, List

# import optim


def _load_npz_entry(path: str, key: str) -> np.ndarray:
    """Return the array stored under ``key`` in the ``.npz`` archive at ``path``.

    The archive is opened in a ``with`` block so its file handle is closed as
    soon as the array has been read, instead of lingering until garbage
    collection (which matters when this runs once per sample).
    """
    with np.load(path) as archive:
        return archive[key]


class DataReader(Dataset):
    """Dataset yielding ``(inp_rad, inp_sat, out_rad, out_sat)`` tensors from ``.npz`` files.

    Files are looked up under ``<dir_data>/<type_data>/`` in the sub-folders
    ``rad``, ``sat``, ``pred_rad`` and ``pred_sat`` (presumably radar and
    satellite observations and their predictions -- confirm against the data
    layout). Normalisation statistics are read from ``rad_mean.npz``,
    ``rad_std.npz``, ``sat_mean.npz`` and ``sat_std.npz`` directly under
    ``dir_data``.

    The ``ablation`` mode selects which folders provide the two inputs:

    * ``"no"``   -- ``pred_rad`` / ``pred_sat`` files sharing the target's name.
    * ``"rad"``  -- ``rad`` at time ``t`` and ``pred_sat`` at ``t + hours_predicted``.
    * ``"sat"``  -- ``pred_rad`` at ``t`` and ``sat`` at ``t - hours_predicted``.
    * ``"full"`` -- ``rad`` and ``sat`` at ``t``; targets at ``t + hours_predicted``.
    * ``"time"`` -- like ``"no"`` plus a short history of ``rad`` / ``sat`` files.

    Only files whose name ends in ``00.npz`` or ``03.npz`` are considered, and
    a sample is kept only when every file it needs exists.

    NOTE(review): ``__getitem__`` expects each sample to be four plain paths.
    The ``"time"`` mode produces four *lists* of paths, which ``__getitem__``
    cannot load -- confirm whether that mode is consumed elsewhere.
    NOTE(review): ``time_points_rad`` / ``time_points_sat`` are stored but not
    used by any code in this file.
    """

    # strftime patterns of the file stems in the rad-like and sat-like folders.
    _RAD_STAMP = '%Y%m%d%H%M'
    _SAT_STAMP = '%Y%m%d%H'

    # Accepted ``ablation`` values mapped to the method that builds the sample list.
    _LIST_BUILDERS = {
        "no": "gen_list_img_no",
        "rad": "gen_list_img_rad",
        "sat": "gen_list_img_sat",
        "full": "gen_list_img_full",
        "time": "gen_list_img_time",
    }

    def __init__(
        self,
        dir_data: str,
        type_data: str,
        rad_attribute: str,
        sat_attribute: str,
        hours_predicted: int,
        rad_predicted: str,
        sat_predicted: str,
        time_points_rad: int,
        time_points_sat: int,
        rad_size: int,
        sat_size: int,
        ablation: str,
    ):
        """Index the samples of one split and prepare the normalisation transforms.

        Args:
            dir_data: Root folder holding the split folders and the statistics files.
            type_data: Split name; one of ``"train"``, ``"test"`` or ``"val"``.
            rad_attribute: Key of the input array inside rad-like ``.npz`` files
                (also used to select the rad mean/std).
            sat_attribute: Key of the input array inside sat-like ``.npz`` files
                (also used to select the sat mean/std).
            hours_predicted: Offset in hours between input and target time
                (used by the ``"rad"``, ``"sat"`` and ``"full"`` modes).
            rad_predicted: Key of the target array inside ``rad`` files.
            sat_predicted: Key of the target array inside ``sat`` files.
            time_points_rad: Stored on the instance; unused in this file.
            time_points_sat: Stored on the instance; unused in this file.
            rad_size: Stored on the instance; unused in this file.
            sat_size: Stored on the instance; unused in this file.
            ablation: One of ``"no"``, ``"rad"``, ``"sat"``, ``"full"``, ``"time"``.
                This used to be written ``ablation = str`` (a default equal to
                the ``str`` type, which always ended in the ``ValueError``
                below); it is now a proper annotated, required argument.

        Raises:
            ValueError: If ``type_data`` or ``ablation`` is not an accepted value.
        """
        super().__init__()
        self.base_dir = dir_data
        self.type_data = type_data
        if self.type_data not in ("train", "test", "val"):
            raise ValueError("Type must be train, test or val")
        # The split folder carries the same name as the split itself.
        self.dir_data = os.path.join(dir_data, self.type_data)

        self.sat_size = sat_size
        self.rad_size = rad_size
        self.hours_predicted = hours_predicted
        self.rad_attribute = rad_attribute
        self.sat_attribute = sat_attribute
        self.rad_predicted = rad_predicted
        self.sat_predicted = sat_predicted
        self.time_points_rad = time_points_rad
        self.time_points_sat = time_points_sat
        self.transform_rad = None
        self.transform_sat = None
        self.ablation = ablation

        # Per-attribute normalisation statistics stored next to the split folders.
        self.rad_mean = _load_npz_entry(os.path.join(self.base_dir, 'rad_mean.npz'), self.rad_attribute)
        self.rad_std = _load_npz_entry(os.path.join(self.base_dir, 'rad_std.npz'), self.rad_attribute)
        self.sat_mean = _load_npz_entry(os.path.join(self.base_dir, 'sat_mean.npz'), self.sat_attribute)
        self.sat_std = _load_npz_entry(os.path.join(self.base_dir, 'sat_std.npz'), self.sat_attribute)

        self.create_transform()

        # Build the sample index with the builder matching the ablation mode.
        builder_name = self._LIST_BUILDERS.get(self.ablation)
        if builder_name is None:
            raise ValueError("Ablation must be no,rad,sat,full,time")
        self.list_img_dir = getattr(self, builder_name)(self.dir_data)
        print(f"Number of {self.type_data } samples:", len(self.list_img_dir))

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.list_img_dir)

    def __getitem__(self, idx):
        """Load, normalise and return sample ``idx``.

        Returns:
            ``(inp_rad, inp_sat, out_rad, out_sat)``; the two sat tensors are
            cast with ``.float()``, the rad tensors keep the dtype produced by
            the transform.
        """
        # NOTE(review): entries built by gen_list_img_time are lists of paths,
        # which np.load cannot open -- see the class docstring.
        rad_in_path, sat_in_path, rad_gt_path, sat_gt_path = self.list_img_dir[idx]

        # Both transforms are always set by create_transform(), so they are
        # applied unconditionally.
        inp_rad = self.transform_rad(_load_npz_entry(rad_in_path, self.rad_attribute))
        out_rad = self.transform_rad(_load_npz_entry(rad_gt_path, self.rad_predicted))
        inp_sat = self.transform_sat(_load_npz_entry(sat_in_path, self.sat_attribute))
        # Only the first entry along the leading axis of the sat target is used.
        out_sat = self.transform_sat(_load_npz_entry(sat_gt_path, self.sat_predicted)[0])
        return inp_rad, inp_sat.float(), out_rad, out_sat.float()

    def create_transform(self):
        """Create the ``ToTensor`` + ``Normalize`` pipelines for rad and sat arrays."""
        self.transform_rad = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.rad_mean, self.rad_std),
        ])
        # Only the first entry of the sat statistics is used for normalisation.
        self.transform_sat = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.sat_mean[0], self.sat_std[0]),
        ])

    @staticmethod
    def _is_sample_name(name):
        """Return True for file names ending in ``00.npz`` or ``03.npz``."""
        return name.endswith(("00.npz", "03.npz"))

    @staticmethod
    def _hour_name(name):
        """Drop the two characters before the 4-char extension (the minute digits)."""
        return name[0:-6] + name[-4:]

    @staticmethod
    def _complete_sample(first_path, *other_paths):
        """Return ``[first_path, *other_paths]`` if all ``other_paths`` exist, else None.

        ``first_path`` is not checked: callers obtain it from ``os.listdir``.
        """
        if all(os.path.isfile(p) for p in other_paths):
            return [first_path, *other_paths]
        return None

    def _shifted(self, moment, pattern, hours):
        """Return the ``.npz`` file name for ``moment`` shifted by ``hours``."""
        return (moment + timedelta(hours=hours)).strftime(pattern) + '.npz'

    def gen_list_img_no(self, path):
        """Pair each ``pred_rad`` file with same-time ``pred_sat``, ``rad`` and ``sat`` files.

        Returns:
            List of ``[pred_rad, pred_sat, rad, sat]`` paths, in sorted file-name order.
        """
        pred_rad_dir = os.path.join(path, "pred_rad")
        pred_sat_dir = os.path.join(path, "pred_sat")
        GT_rad_dir = os.path.join(path, "rad")
        GT_sat_dir = os.path.join(path, "sat")

        samples = []
        # sorted(): os.listdir order is arbitrary; sorting keeps runs reproducible.
        for name in sorted(os.listdir(pred_rad_dir)):
            if not self._is_sample_name(name):
                continue
            hour_name = self._hour_name(name)
            sample = self._complete_sample(
                os.path.join(pred_rad_dir, name),
                os.path.join(pred_sat_dir, hour_name),
                os.path.join(GT_rad_dir, name),
                os.path.join(GT_sat_dir, hour_name),
            )
            if sample is not None:
                samples.append(sample)
        return samples

    def gen_list_img_rad(self, path):
        """Use ``rad`` at time ``t`` as rad input; everything else at ``t + hours_predicted``.

        Returns:
            List of ``[rad(t), pred_sat(t+h), rad(t+h), sat(t+h)]`` paths.
        """
        pred_rad_dir = os.path.join(path, "rad")
        pred_sat_dir = os.path.join(path, "pred_sat")
        GT_rad_dir = os.path.join(path, "rad")
        GT_sat_dir = os.path.join(path, "sat")

        samples = []
        for name in sorted(os.listdir(pred_rad_dir)):
            if not self._is_sample_name(name):
                continue
            moment = self.get_date_time(name)
            ahead = self.hours_predicted
            sample = self._complete_sample(
                os.path.join(pred_rad_dir, name),
                os.path.join(pred_sat_dir, self._shifted(moment, self._SAT_STAMP, ahead)),
                os.path.join(GT_rad_dir, self._shifted(moment, self._RAD_STAMP, ahead)),
                os.path.join(GT_sat_dir, self._shifted(moment, self._SAT_STAMP, ahead)),
            )
            if sample is not None:
                samples.append(sample)
        return samples

    def gen_list_img_sat(self, path):
        """Use ``sat`` at ``t - hours_predicted`` as sat input; the rest share the file's time.

        Returns:
            List of ``[pred_rad(t), sat(t-h), rad(t), sat(t)]`` paths.
        """
        pred_rad_dir = os.path.join(path, "pred_rad")
        pred_sat_dir = os.path.join(path, "sat")
        GT_rad_dir = os.path.join(path, "rad")
        GT_sat_dir = os.path.join(path, "sat")

        samples = []
        for name in sorted(os.listdir(pred_rad_dir)):
            if not self._is_sample_name(name):
                continue
            moment = self.get_date_time(name)
            sample = self._complete_sample(
                os.path.join(pred_rad_dir, name),
                os.path.join(pred_sat_dir, self._shifted(moment, self._SAT_STAMP, -self.hours_predicted)),
                os.path.join(GT_rad_dir, name),
                os.path.join(GT_sat_dir, self._hour_name(name)),
            )
            if sample is not None:
                samples.append(sample)
        return samples

    def gen_list_img_full(self, path):
        """Use ``rad`` and ``sat`` at time ``t`` as inputs and at ``t + hours_predicted`` as targets.

        Returns:
            List of ``[rad(t), sat(t), rad(t+h), sat(t+h)]`` paths.
        """
        pred_rad_dir = os.path.join(path, "rad")
        pred_sat_dir = os.path.join(path, "sat")
        GT_rad_dir = os.path.join(path, "rad")
        GT_sat_dir = os.path.join(path, "sat")

        samples = []
        for name in sorted(os.listdir(pred_rad_dir)):
            if not self._is_sample_name(name):
                continue
            moment = self.get_date_time(name)
            ahead = self.hours_predicted
            sample = self._complete_sample(
                os.path.join(pred_rad_dir, name),
                os.path.join(pred_sat_dir, self._shifted(moment, self._SAT_STAMP, 0)),
                os.path.join(GT_rad_dir, self._shifted(moment, self._RAD_STAMP, ahead)),
                os.path.join(GT_sat_dir, self._shifted(moment, self._SAT_STAMP, ahead)),
            )
            if sample is not None:
                samples.append(sample)
        return samples

    def gen_list_img_time(self, path):
        """Like ``gen_list_img_no`` but with a short history prepended to each input.

        For a file at time ``t`` the rad input is the ``rad`` files at
        ``t-210``, ``t-200``, ``t-190`` and ``t-180`` minutes followed by the
        ``pred_rad`` file; the sat input is the ``sat`` file for ``t-180``
        minutes followed by the ``pred_sat`` file. A sample is kept only when
        all of these and both targets exist.

        Returns:
            List of ``[rad_inputs, sat_inputs, [rad_target], [sat_target]]``,
            each element being a list of paths.
        """
        pred_rad_dir = os.path.join(path, "pred_rad")
        pred_sat_dir = os.path.join(path, "pred_sat")
        GT_rad_dir = os.path.join(path, "rad")
        GT_sat_dir = os.path.join(path, "sat")

        samples = []
        for name in sorted(os.listdir(pred_rad_dir)):
            # Filter before parsing so unrelated files cannot make
            # get_date_time raise (the other builders do the same).
            if not self._is_sample_name(name):
                continue
            moment = self.get_date_time(name)

            # Four rad history frames, 10 minutes apart, ending 180 minutes before t.
            rad_inputs = []
            for step in range(4):
                stamp = (moment + timedelta(minutes=-210 + step * 10)).strftime(self._RAD_STAMP)
                candidate = os.path.join(GT_rad_dir, stamp + '.npz')
                if os.path.isfile(candidate):
                    rad_inputs.append(candidate)

            # One sat history frame, 180 minutes before t.
            sat_inputs = []
            for step in range(1):
                stamp = (moment + timedelta(minutes=-180 + step * 10)).strftime(self._SAT_STAMP)
                candidate = os.path.join(GT_sat_dir, stamp + '.npz')
                if os.path.isfile(candidate):
                    sat_inputs.append(candidate)

            rad_inputs.append(os.path.join(pred_rad_dir, name))
            hour_name = self._hour_name(name)
            pred_sat_path = os.path.join(pred_sat_dir, hour_name)
            GT_rad_path = os.path.join(GT_rad_dir, name)
            GT_sat_path = os.path.join(GT_sat_dir, hour_name)
            if os.path.isfile(pred_sat_path):
                sat_inputs.append(pred_sat_path)
            rad_targets = [GT_rad_path] if os.path.isfile(GT_rad_path) else []
            sat_targets = [GT_sat_path] if os.path.isfile(GT_sat_path) else []

            # Keep only samples for which every expected file was found.
            if (len(rad_inputs) == 5 and len(sat_inputs) == 2
                    and len(rad_targets) == 1 and len(sat_targets) == 1):
                samples.append([rad_inputs, sat_inputs, rad_targets, sat_targets])
        return samples

    def get_date_time(self, name):
        """Parse the leading ``YYYYMMDDHHMM`` digits of ``name`` into a ``datetime``.

        Raises:
            ValueError: If the first 12 characters are not a valid timestamp.
        """
        year = int(name[0:4])
        month = int(name[4:6])
        day = int(name[6:8])
        hour = int(name[8:10])
        minute = int(name[10:12])
        return datetime(year, month, day, hour, minute)


class WeatherForecastDataModule(LightningDataModule):
    """Lightning data module wrapping ``DataReader`` for the train, val and test splits.

    All constructor arguments are stored through ``save_hyperparameters`` and
    read back from ``self.hparams``.
    """

    def __init__(
        self,
        dir_data: str,
        batch_size: int,
        hours_predicted: int,
        num_workers: int,
        pin_memory: bool,
        time_points_rad: int,
        time_points_sat: int,
        sat_inp_vars: str,
        sat_out_vars: str,
        sat_size: int,
        rad_inp_vars: str,
        rad_out_vars: str,
        rad_size: int,
        ablation: str,
    ):
        """Store the hyper-parameters and load the normalisation statistics.

        Args:
            dir_data: Root data folder, forwarded to ``DataReader``.
            batch_size: Batch size of every dataloader.
            hours_predicted: Forwarded to ``DataReader``.
            num_workers: Worker count of every dataloader.
            pin_memory: ``pin_memory`` flag of every dataloader.
            time_points_rad: Forwarded to ``DataReader``.
            time_points_sat: Forwarded to ``DataReader``.
            sat_inp_vars: Sat input key (``DataReader``'s ``sat_attribute``).
            sat_out_vars: Sat target key (``DataReader``'s ``sat_predicted``).
            sat_size: Forwarded to ``DataReader``.
            rad_inp_vars: Rad input key (``DataReader``'s ``rad_attribute``).
            rad_out_vars: Rad target key (``DataReader``'s ``rad_predicted``).
            rad_size: Forwarded to ``DataReader``.
            ablation: Ablation mode, forwarded to ``DataReader``.
        """
        super().__init__()

        # Makes the init arguments available through ``self.hparams``.
        self.save_hyperparameters(logger=True)

        self.data_train = None
        self.data_test = None
        self.data_val = None

        stats_dir = self.hparams.dir_data
        self.rad_mean = _load_npz_entry(os.path.join(stats_dir, 'rad_mean.npz'), self.hparams.rad_inp_vars)
        self.rad_std = _load_npz_entry(os.path.join(stats_dir, 'rad_std.npz'), self.hparams.rad_inp_vars)
        self.sat_mean = _load_npz_entry(os.path.join(stats_dir, 'sat_mean.npz'), self.hparams.sat_inp_vars)
        self.sat_std = _load_npz_entry(os.path.join(stats_dir, 'sat_std.npz'), self.hparams.sat_inp_vars)

    def prepare_data(self):
        """Nothing to download or pre-process."""
        pass

    def _build_dataset(self, type_data):
        """Create the ``DataReader`` for split ``type_data`` from the stored hyper-parameters."""
        hp = self.hparams
        return DataReader(
            dir_data=hp.dir_data,
            type_data=type_data,
            rad_attribute=hp.rad_inp_vars,
            sat_attribute=hp.sat_inp_vars,
            hours_predicted=hp.hours_predicted,
            rad_predicted=hp.rad_out_vars,
            sat_predicted=hp.sat_out_vars,
            time_points_rad=hp.time_points_rad,
            time_points_sat=hp.time_points_sat,
            sat_size=hp.sat_size,
            rad_size=hp.rad_size,
            ablation=hp.ablation,
        )

    def _build_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` using the stored loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            drop_last=False,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def setup(self, stage):
        """Build the train, test and val datasets; ``stage`` is ignored."""
        self.data_train = self._build_dataset("train")
        self.data_test = self._build_dataset("test")
        self.data_val = self._build_dataset("val")

    def train_dataloader(self):
        """Return the shuffled loader over the training set."""
        return self._build_loader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation set."""
        return self._build_loader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over the test set."""
        return self._build_loader(self.data_test, shuffle=False)