"""Radar/satellite weather-forecast dataset and a (legacy) Lightning data module.

Samples are assembled from ``.npz`` files laid out as
``<dir_data>/<split>/{rad,sat}/<timestamp>.npz`` with per-variable normalisation
statistics stored in ``<dir_data>/{rad,sat}_{mean,std}.npz``.
"""
import os
from datetime import datetime, timedelta
from typing import Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import pytorch_lightning as L
import torch
import torch.nn as nn
from einops import rearrange

# NOTE(review): ``LightningDataModule`` is imported from the ``lightning`` package
# while ``LightningCLI`` / ``L`` come from ``pytorch_lightning``. These namespaces
# ship distinct classes; confirm the Trainer/CLI consuming this module uses the same one.
from lightning import LightningDataModule
from pytorch_lightning.cli import LightningCLI
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms


def _hour_stamp(moment: datetime) -> str:
    """Return the ``YYYYMMDDHH`` file stem for *moment*."""
    return f"{moment.year}{moment.month:02}{moment.day:02}{moment.hour:02}"


def _minute_stamp(moment: datetime) -> str:
    """Return the ``YYYYMMDDHHMM`` file stem for *moment*."""
    return _hour_stamp(moment) + f"{moment.minute:02}"


def _load_npz(path: str, keys: Optional[Iterable[str]] = None) -> Dict[str, np.ndarray]:
    """Read arrays from an ``.npz`` archive into a plain dict and close the file.

    ``np.load`` on an ``.npz`` returns a lazily-reading ``NpzFile`` that keeps an
    open file handle. Keeping one around leaks the handle and prevents the owning
    object from being pickled (as DataLoader workers do under the ``spawn`` start
    method), so the needed arrays are materialised here instead.

    Args:
        path: Path to the ``.npz`` archive.
        keys: Names of the arrays to read; ``None`` reads every array.

    Raises:
        KeyError: If a requested name is not stored in the archive.
    """
    with np.load(path) as archive:
        names = archive.files if keys is None else keys
        return {name: archive[name] for name in names}


class DataReader(Dataset):
    """Dataset of radar frame sequences (optionally with satellite fields) from ``.npz`` files.

    Two layouts are supported:

    * ``servir_format=False``: every sample is anchored on a satellite file
      ``sat/YYYYMMDDHH.npz``. Inputs are ``time_points_radar`` radar frames stepping
      back from (approximately) that time plus the anchor satellite file. Targets are
      the radar frame and satellite file ``hours_predicted`` hours later.
    * ``servir_format=True``: radar only; each sample is ``time_points_radar`` radar
      frames and no satellite data is read.
    """

    def __init__(
        self,
        dir_data: str,
        type_data: str,
        radar_attribute: List[str],
        sat_attribute: List[str],
        hours_predicted: int,
        rad_predicted: str,
        sat_predicted: str,
        time_points_radar: int,
        time_points_sat: int,
        short_timestep: bool,
        augmentation: bool,
        servir_format: bool = False,
    ):
        """Index the available samples of one split.

        Args:
            dir_data: Root data folder containing ``train``/``test``/``val`` sub-folders
                and the ``{rad,sat}_{mean,std}.npz`` statistics files.
            type_data: Which split to read: ``"train"``, ``"test"`` or ``"val"``.
            radar_attribute: Radar fields (keys inside each radar ``.npz``) used as input.
            sat_attribute: Satellite fields (keys inside each satellite ``.npz``) used as input.
            hours_predicted: Lead time, in hours, between the anchor and the target files.
            rad_predicted: Radar field returned as the radar target.
            sat_predicted: Satellite field returned as the satellite target.
            time_points_radar: Number of input radar frames per sample.
            time_points_sat: Number of input satellite files per sample. Only 1
                produces samples with the current indexing (see ``gen_list_img``).
            short_timestep: Use the interleaved 3/7-minute frame offsets instead of a
                uniform 10-minute spacing (see ``get_input_radar``).
            augmentation: Also add a copy of every sample with radar inputs shifted
                by +3 minutes (non-SERVIR layout only).
            servir_format: Use the radar-only SERVIR layout.

        Raises:
            ValueError: If ``type_data`` is not one of the three split names.

        Note (from the original author): radar files hold precipitation and velocity
        fields; satellite files hold 10m_u_component_of_wind, 10m_v_component_of_wind,
        2m_dewpoint_temperature, 2m_temperature and total_precipitation.
        """
        super().__init__()
        if type_data not in ("train", "test", "val"):
            raise ValueError("Type must be train, test or val")
        self.base_dir = dir_data
        self.type_data = type_data
        self.dir_data = os.path.join(dir_data, type_data)

        self.augmentation = augmentation
        self.short_timestep = short_timestep
        self.hours_predicted = hours_predicted
        self.radar_attribute = radar_attribute
        self.sat_attribute = sat_attribute
        self.rad_predicted = rad_predicted
        self.sat_predicted = sat_predicted
        self.time_points_radar = time_points_radar
        self.time_points_sat = time_points_sat
        self.servir_format = servir_format
        self.list_transform_radar = None
        self.list_transform_satellite = None

        self.dir_img_radar = os.path.join(self.dir_data, "rad")
        self.dir_img_satellite = os.path.join(self.dir_data, "sat")

        # Per-variable normalisation statistics, materialised into plain dicts so no
        # file handle stays open on the dataset object.
        self.rad_mean = _load_npz(os.path.join(self.base_dir, "rad_mean.npz"))
        self.rad_std = _load_npz(os.path.join(self.base_dir, "rad_std.npz"))
        self.sat_mean = _load_npz(os.path.join(self.base_dir, "sat_mean.npz"))
        self.sat_std = _load_npz(os.path.join(self.base_dir, "sat_std.npz"))

        self.create_transform()
        self.list_img_radar, self.list_img_satellite = self.gen_list_img(
            self.dir_img_radar, self.dir_img_satellite
        )

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.list_img_radar)

    def __getitem__(
        self, idx: int
    ) -> Union[
        torch.Tensor,
        List[torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        """Load and normalise sample *idx*.

        Returns:
            * SERVIR layout, one radar attribute: that attribute's frames
              (concatenated along dim 0) with a trailing singleton axis appended
              (``"t h w -> t h w 1"``).
            * SERVIR layout, several attributes: a list with one frame tensor per
              attribute, in ``radar_attribute`` order.
            * Otherwise: ``(input_radar, input_satellite, output_radar,
              output_satellite)``, where ``input_radar`` stacks the per-attribute
              frame tensors along a new dim 0 and ``input_satellite`` concatenates
              the per-attribute satellite tensors along dim 0.
        """
        radar_paths = self.list_img_radar[idx]
        frames = self._load_radar_sequence(radar_paths[: self.time_points_radar])

        if self.servir_format:
            if len(self.radar_attribute) == 1:
                sequence = frames[self.radar_attribute[0]]
                # Append a trailing channel axis to the (T, H, W) frame stack.
                return rearrange(sequence, "t h w -> t h w 1").contiguous()
            return [frames[key] for key in self.radar_attribute]

        input_radar = torch.stack([frames[key] for key in self.radar_attribute], dim=0)

        sat_paths = self.list_img_satellite[idx]
        sat_in = _load_npz(sat_paths[0], self.sat_attribute)
        # NOTE(review): ToTensor treats its input as HWC and returns CHW; the
        # permute(1, 2, 0) undoes that. Presumably satellite arrays are stored
        # channel-first, e.g. (1, H, W) — confirm against the data files.
        input_satellite = torch.cat(
            [
                self.list_transform_satellite[key](sat_in[key]).permute(1, 2, 0)
                for key in self.sat_attribute
            ],
            dim=0,
        ).float()

        rad_out = _load_npz(radar_paths[-1], [self.rad_predicted])
        output_radar = self.list_transform_radar[self.rad_predicted](
            rad_out[self.rad_predicted]
        ).float()

        sat_out = _load_npz(sat_paths[-1], [self.sat_predicted])
        output_satellite = (
            self.list_transform_satellite[self.sat_predicted](sat_out[self.sat_predicted])
            .permute(1, 2, 0)
            .float()
        )
        return input_radar, input_satellite, output_radar, output_satellite

    def _load_radar_sequence(self, paths: List[str]) -> Dict[str, torch.Tensor]:
        """Load radar frames and return, per radar attribute, the transformed frames
        concatenated along dim 0 as a float tensor."""
        per_attribute: Dict[str, List[torch.Tensor]] = {key: [] for key in self.radar_attribute}
        for path in paths:
            arrays = _load_npz(path, self.radar_attribute)
            for key in self.radar_attribute:
                per_attribute[key].append(self.list_transform_radar[key](arrays[key]))
        return {
            key: torch.cat(tensors, dim=0).float() for key, tensors in per_attribute.items()
        }

    def create_transform(self) -> None:
        """Build one transform pipeline per variable present in the statistics files.

        Radar: ToTensor -> Normalize(mean, std) -> Resize to 200x200.
        Satellite: ToTensor -> Normalize(mean, std) (no resize).
        """
        self.list_transform_radar = {
            key: transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(self.rad_mean[key], self.rad_std[key]),
                    transforms.Resize((200, 200)),
                ]
            )
            for key in self.rad_mean.keys()
        }
        self.list_transform_satellite = {
            key: transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(self.sat_mean[key], self.sat_std[key]),
                ]
            )
            for key in self.sat_mean.keys()
        }

    def gen_list_img(self, dir_img_radar: str, dir_img_satellite: str):
        """Index every sample whose required files all exist.

        Directory listings are sorted so sample order (and therefore any
        index-based split) is reproducible; non-``.npz`` entries are ignored.

        Returns:
            ``(radar_samples, satellite_samples)``: parallel lists of file-path lists.
            In the SERVIR layout ``satellite_samples`` is always empty.
        """
        if self.servir_format:
            radar_samples = []
            # NOTE(review): get_date_time keeps only YYYYMMDDHH, so radar files from
            # the same hour map to the same anchor and yield identical samples —
            # confirm whether minutes should be parsed or anchors de-duplicated.
            for name in sorted(os.listdir(dir_img_radar)):
                if not name.endswith(".npz"):
                    continue
                anchor = self.get_date_time(name)
                # get_input_radar does not use this end time in the SERVIR layout.
                end_time = anchor + timedelta(minutes=10 * self.time_points_radar)
                frames = self.get_input_radar(anchor, end_time)
                if len(frames) == self.time_points_radar:
                    radar_samples.append(frames)
            return radar_samples, []

        if not os.path.isdir(dir_img_radar):
            raise FileNotFoundError(f"Radar directory not found: {dir_img_radar}")

        radar_samples = []
        satellite_samples = []
        lead_time = timedelta(hours=self.hours_predicted)
        # Augmentation adds a second variant with radar inputs shifted by +3 minutes.
        # NOTE(review): this is applied to every split (val/test too), and is a
        # no-op duplicate when short_timestep is on (time_add is ignored there).
        radar_shifts = (0, 3) if self.augmentation else (0,)
        for name in sorted(os.listdir(dir_img_satellite)):
            if not name.endswith(".npz"):
                continue
            anchor = self.get_date_time(name)
            target_time = anchor + lead_time
            target_sat = os.path.join(dir_img_satellite, _hour_stamp(target_time) + ".npz")
            sat_pair = (
                [os.path.join(dir_img_satellite, name), target_sat]
                if os.path.exists(target_sat)
                else []
            )
            # sat_pair holds 0 or 2 paths, so only time_points_sat == 1 can match.
            if len(sat_pair) != self.time_points_sat + 1:
                continue
            for shift in radar_shifts:
                frames = self.get_input_radar(anchor, target_time, time_add=shift)
                if len(frames) == self.time_points_radar + 1:
                    radar_samples.append(frames)
                    satellite_samples.append(sat_pair)
        return radar_samples, satellite_samples

    def get_date_time(self, name: str) -> datetime:
        """Parse the leading ``YYYYMMDDHH`` of a file name (minutes are ignored)."""
        year = int(name[0:4])
        month = int(name[4:6])
        day = int(name[6:8])
        hour = int(name[8:10])
        return datetime(year, month, day, hour)

    def get_input_radar(self, name_datetime: datetime, out_datetime: datetime, time_add: int = 0):
        """Collect existing radar frame paths for one sample, newest first.

        Frame ``i`` is looked up at ``name_datetime - back_i + fwd_i`` where:

        * ``short_timestep``: ``back_i = 10*(i//2) + 7*(i%2)`` and ``fwd_i`` is 3 min
          for even ``i`` and 7 min for odd ``i`` (``time_add`` is ignored), giving
          offsets +3, 0, -7, -10, -17, -20, ... minutes;
        * otherwise: ``back_i = 10*i`` and ``fwd_i = time_add`` minutes.

        Collection stops at the first missing frame. Outside the SERVIR layout the
        target ``<YYYYMMDDHH>00.npz`` for ``out_datetime + time_add`` minutes is then
        appended if it exists; ``out_datetime`` is unused in the SERVIR layout.
        """
        frames = []
        for i in range(self.time_points_radar):
            if self.short_timestep:
                back = timedelta(minutes=10 * (i // 2) + 7 * (i % 2))
                forward = timedelta(minutes=3 if i % 2 == 0 else 7)
            else:
                back = timedelta(minutes=10 * i)
                forward = timedelta(minutes=time_add)
            frame_time = name_datetime - back + forward
            frame_path = os.path.join(self.dir_img_radar, _minute_stamp(frame_time) + ".npz")
            if not os.path.exists(frame_path):
                break
            frames.append(frame_path)

        if self.servir_format:
            return frames

        # The target file name always uses minute "00" of the (shifted) hour.
        target_time = out_datetime + timedelta(minutes=time_add)
        target_path = os.path.join(self.dir_img_radar, _hour_stamp(target_time) + "00" + ".npz")
        if os.path.exists(target_path):
            frames.append(target_path)
        return frames


class WeatherForecastDataModuleOld(LightningDataModule):
    """Lightning data module wrapping ``DataReader`` for the train/val/test splits."""

    def __init__(
        self,
        dir_data: str,
        batch_size: int,
        hours_predicted: int,
        num_workers: int,
        pin_memory: bool,
        time_points_radar: int,
        time_points_sat: int,
        sat_inp_vars: List[str],
        sat_out_vars: str,
        rad_inp_vars: List[str],
        rad_out_vars: str,
        short_timestep: bool,
        augmentation: bool,
        rebuild_val: bool,
        servir_format: bool,
        split_seed: int = 42,
    ):
        """Store configuration; datasets are built in ``setup``.

        Args:
            dir_data: Root data folder passed to every ``DataReader``.
            batch_size: Batch size for all dataloaders.
            hours_predicted, time_points_radar, time_points_sat, short_timestep,
                augmentation, servir_format: Forwarded to ``DataReader``.
            num_workers, pin_memory: Forwarded to every ``DataLoader``.
            sat_inp_vars / sat_out_vars: Satellite input fields / target field.
            rad_inp_vars / rad_out_vars: Radar input fields / target field.
            rebuild_val: Ignore the ``val`` folder and split the ``test`` split into
                test (5/6) and validation (1/6) subsets instead.
            split_seed: Seed for that split, so it is identical on every ``setup``
                call and every process.
        """
        super().__init__()
        # Lets the init params be accessed through ``self.hparams``.
        self.save_hyperparameters(logger=False)

        self.data_train = None
        self.data_test = None
        self.data_val = None

        self.dir_data = dir_data
        self.batch_size = batch_size
        self.hours_predicted = hours_predicted
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.time_points_radar = time_points_radar
        self.time_points_sat = time_points_sat
        self.sat_inp_vars = sat_inp_vars
        self.sat_out_vars = sat_out_vars
        self.rad_inp_vars = rad_inp_vars
        self.rad_out_vars = rad_out_vars
        self.short_timestep = short_timestep
        self.augmentation = augmentation
        self.rebuild_val = rebuild_val
        self.servir_format = servir_format
        self.split_seed = split_seed

        # Kept as public attributes; not used inside this class.
        self.rad_mean = _load_npz(os.path.join(self.dir_data, "rad_mean.npz"))
        self.rad_std = _load_npz(os.path.join(self.dir_data, "rad_std.npz"))
        self.sat_mean = _load_npz(os.path.join(self.dir_data, "sat_mean.npz"))
        self.sat_std = _load_npz(os.path.join(self.dir_data, "sat_std.npz"))

    def _make_reader(self, type_data: str) -> DataReader:
        """Build a ``DataReader`` for one split using this module's configuration."""
        return DataReader(
            dir_data=self.dir_data,
            type_data=type_data,
            radar_attribute=self.rad_inp_vars,
            sat_attribute=self.sat_inp_vars,
            hours_predicted=self.hours_predicted,
            rad_predicted=self.rad_out_vars,
            sat_predicted=self.sat_out_vars,
            time_points_radar=self.time_points_radar,
            time_points_sat=self.time_points_sat,
            short_timestep=self.short_timestep,
            augmentation=self.augmentation,
            servir_format=self.servir_format,
        )

    def setup(self, stage=None):
        """Build the train, test and validation datasets.

        Lightning may invoke ``setup`` once per stage and on every process, so the
        optional test/val split uses a fixed-seed generator to stay identical across
        calls (otherwise validation samples could reappear in the test set).
        """
        self.data_train = self._make_reader("train")
        if self.rebuild_val:
            self.old_data_test = self._make_reader("test")
            total_length = len(self.old_data_test)
            length_val = total_length // 6
            length_test = total_length - length_val
            generator = torch.Generator().manual_seed(self.split_seed)
            self.data_test, self.data_val = random_split(
                self.old_data_test, [length_test, length_val], generator=generator
            )
        else:
            self.data_test = self._make_reader("test")
            self.data_val = self._make_reader("val")

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap *dataset* in a DataLoader using this module's loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            drop_last=False,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return self._make_loader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return self._make_loader(self.data_test, shuffle=False)