# unet_code/src/datamodule.py
# weatherforecast1024's picture
# Upload folder using huggingface_hub
# f3b050a verified
import os
from torch.utils.data import DataLoader, Dataset, random_split
import numpy as np
from datetime import datetime, timedelta
from torchvision import transforms
from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.cli import LightningCLI
from torch.utils.data import DataLoader
import pytorch_lightning as L
import torch
import torch.nn as nn
from typing import Tuple, Dict, List
# import optim
class DataReader(Dataset):
    """Dataset of paired radar/satellite ``.npz`` samples for one data split.

    Every sample is described by four entries, always in this order:
    radar input, satellite input, radar target, satellite target.
    ``ablation`` selects which sub-directories of ``<dir_data>/<type_data>``
    ("pred_rad", "pred_sat", "rad", "sat") provide the inputs; see the
    ``gen_list_img_*`` methods for the exact pairing of each mode.

    Args:
        dir_data: Root directory. It must contain ``rad_mean.npz``,
            ``rad_std.npz``, ``sat_mean.npz`` and ``sat_std.npz`` plus one
            sub-directory per split.
        type_data: Split name: ``"train"``, ``"test"`` or ``"val"``.
        rad_attribute: Key read from the radar input archives and from the
            radar statistics archives.
        sat_attribute: Key read from the satellite input archives and from
            the satellite statistics archives.
        hours_predicted: Hour offset between input and target time stamps
            used by the "rad", "sat" and "full" modes.
        rad_predicted: Key read from the radar target archives.
        sat_predicted: Key read from the satellite target archives.
        time_points_rad: Stored on the instance; not read by this class.
        time_points_sat: Stored on the instance; not read by this class.
        rad_size: Stored on the instance; not read by this class.
        sat_size: Stored on the instance; not read by this class.
        ablation: One of ``"no"``, ``"rad"``, ``"sat"``, ``"full"``,
            ``"time"``.

    Raises:
        ValueError: If ``type_data`` or ``ablation`` is not a supported value.

    NOTE(review): in "time" mode each of the four sample entries is a *list*
    of paths, but ``__getitem__`` hands each entry straight to ``np.load``,
    so indexing a "time" dataset fails. How the time points should be
    combined is not visible in this file - confirm before using that mode.
    """

    # Accepted values for ``type_data``; each is also the split's folder name.
    _SPLITS = ("train", "test", "val")
    # Only files whose name ends with one of these suffixes start a sample.
    _SEED_SUFFIXES = ("00.npz", "03.npz")

    def __init__(
        self,
        dir_data: str,
        type_data: str,
        rad_attribute: str,
        sat_attribute: str,
        hours_predicted: int,
        rad_predicted: str,
        sat_predicted: str,
        time_points_rad: int,
        time_points_sat: int,
        rad_size: int,
        sat_size: int,
        # Was ``ablation = str``: the type object ended up as the default
        # value, which could only ever reach the ValueError below.
        ablation: str,
    ):
        super().__init__()
        self.base_dir = dir_data
        self.type_data = type_data
        if type_data not in self._SPLITS:
            raise ValueError("Type must be train, test or val")
        self.dir_data = os.path.join(dir_data, type_data)

        self.sat_size = sat_size
        self.rad_size = rad_size
        self.hours_predicted = hours_predicted
        self.rad_attribute = rad_attribute
        self.sat_attribute = sat_attribute
        self.rad_predicted = rad_predicted
        self.sat_predicted = sat_predicted
        self.time_points_rad = time_points_rad
        self.time_points_sat = time_points_sat
        self.transform_rad = None
        self.transform_sat = None
        self.ablation = ablation

        # Normalisation statistics live next to the split folders.
        self.rad_mean = self._load_npz_entry(
            os.path.join(self.base_dir, "rad_mean.npz"), self.rad_attribute
        )
        self.rad_std = self._load_npz_entry(
            os.path.join(self.base_dir, "rad_std.npz"), self.rad_attribute
        )
        self.sat_mean = self._load_npz_entry(
            os.path.join(self.base_dir, "sat_mean.npz"), self.sat_attribute
        )
        self.sat_std = self._load_npz_entry(
            os.path.join(self.base_dir, "sat_std.npz"), self.sat_attribute
        )

        self.create_transform()

        # Pick the sample-list builder that matches the ablation mode.
        builders = {
            "no": self.gen_list_img_no,
            "rad": self.gen_list_img_rad,
            "sat": self.gen_list_img_sat,
            "full": self.gen_list_img_full,
            "time": self.gen_list_img_time,
        }
        if ablation not in builders:
            raise ValueError("Ablation must be no, rad, sat, full or time")
        self.list_img_dir = builders[ablation](self.dir_data)
        print(f"Number of {self.type_data} samples:", len(self.list_img_dir))

    def __len__(self):
        """Return the number of samples found for this split."""
        return len(self.list_img_dir)

    def __getitem__(self, idx):
        """Load and normalise sample ``idx``.

        Returns:
            ``(inp_rad, inp_sat, out_rad, out_sat)``; both satellite items
            are converted with ``.float()``. The satellite target is taken
            from index 0 along the first axis of the stored array.
        """
        rad_in_path, sat_in_path, rad_out_path, sat_out_path = self.list_img_dir[idx]
        inp_rad = self.transform_rad(
            self._load_npz_entry(rad_in_path, self.rad_attribute)
        )
        out_rad = self.transform_rad(
            self._load_npz_entry(rad_out_path, self.rad_predicted)
        )
        inp_sat = self.transform_sat(
            self._load_npz_entry(sat_in_path, self.sat_attribute)
        )
        out_sat = self.transform_sat(
            self._load_npz_entry(sat_out_path, self.sat_predicted)[0]
        )
        return inp_rad, inp_sat.float(), out_rad, out_sat.float()

    def create_transform(self):
        """Build the ToTensor + Normalize pipelines for radar and satellite."""
        self.transform_rad = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.rad_mean, self.rad_std),
        ])
        # Only the first entry of the satellite statistics is used.
        self.transform_sat = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.sat_mean[0], self.sat_std[0]),
        ])

    @staticmethod
    def _load_npz_entry(path, key):
        """Read one array from an ``.npz`` archive and close the archive."""
        # ``np.load`` keeps the zip file open until the NpzFile is closed;
        # without the context manager every call leaks a file handle.
        with np.load(path) as archive:
            return archive[key]

    def _seed_names(self, directory):
        """Return the sorted file names in ``directory`` that start a sample."""
        # Sorted so the sample order does not depend on the file system.
        return sorted(
            name
            for name in os.listdir(directory)
            if name.endswith(self._SEED_SUFFIXES)
        )

    @staticmethod
    def _hourly_name(name):
        """Drop the two characters before the 4-char extension (the minutes)."""
        return name[0:-6] + name[-4:]

    @staticmethod
    def _all_files(paths):
        """Return True when every path in ``paths`` is an existing file."""
        return all(os.path.isfile(candidate) for candidate in paths)

    def gen_list_img_no(self, path):
        """List samples for ablation "no".

        Inputs come from "pred_rad"/"pred_sat" and targets from "rad"/"sat",
        all for the same time stamp. A sample is kept only when the three
        companion files of the seed file exist.
        """
        pred_rad_dir = os.path.join(path, "pred_rad")
        pred_sat_dir = os.path.join(path, "pred_sat")
        gt_rad_dir = os.path.join(path, "rad")
        gt_sat_dir = os.path.join(path, "sat")

        samples = []
        for name in self._seed_names(pred_rad_dir):
            hourly = self._hourly_name(name)
            sample = [
                os.path.join(pred_rad_dir, name),
                os.path.join(pred_sat_dir, hourly),
                os.path.join(gt_rad_dir, name),
                os.path.join(gt_sat_dir, hourly),
            ]
            if self._all_files(sample[1:]):
                samples.append(sample)
        return samples

    def gen_list_img_rad(self, path):
        """List samples for ablation "rad".

        The radar input is the "rad" file at time t; the satellite input is
        the "pred_sat" file at t + ``hours_predicted``; both targets are the
        "rad"/"sat" files at t + ``hours_predicted``.
        """
        rad_in_dir = os.path.join(path, "rad")
        pred_sat_dir = os.path.join(path, "pred_sat")
        gt_rad_dir = os.path.join(path, "rad")
        gt_sat_dir = os.path.join(path, "sat")

        samples = []
        for name in self._seed_names(rad_in_dir):
            target_time = self.get_date_time(name) + timedelta(hours=self.hours_predicted)
            hourly = target_time.strftime('%Y%m%d%H') + '.npz'
            sample = [
                os.path.join(rad_in_dir, name),
                os.path.join(pred_sat_dir, hourly),
                os.path.join(gt_rad_dir, target_time.strftime('%Y%m%d%H%M') + '.npz'),
                os.path.join(gt_sat_dir, hourly),
            ]
            if self._all_files(sample[1:]):
                samples.append(sample)
        return samples

    def gen_list_img_sat(self, path):
        """List samples for ablation "sat".

        The radar input is the "pred_rad" file at time t; the satellite
        input is the "sat" file at t - ``hours_predicted``; both targets are
        the "rad"/"sat" files at t.
        """
        pred_rad_dir = os.path.join(path, "pred_rad")
        sat_in_dir = os.path.join(path, "sat")
        gt_rad_dir = os.path.join(path, "rad")
        gt_sat_dir = os.path.join(path, "sat")

        samples = []
        for name in self._seed_names(pred_rad_dir):
            source_time = self.get_date_time(name) - timedelta(hours=self.hours_predicted)
            sample = [
                os.path.join(pred_rad_dir, name),
                os.path.join(sat_in_dir, source_time.strftime('%Y%m%d%H') + '.npz'),
                os.path.join(gt_rad_dir, name),
                os.path.join(gt_sat_dir, self._hourly_name(name)),
            ]
            if self._all_files(sample[1:]):
                samples.append(sample)
        return samples

    def gen_list_img_full(self, path):
        """List samples for ablation "full".

        Inputs are the "rad"/"sat" files at time t; targets are the
        "rad"/"sat" files at t + ``hours_predicted``.
        """
        rad_in_dir = os.path.join(path, "rad")
        sat_in_dir = os.path.join(path, "sat")
        gt_rad_dir = os.path.join(path, "rad")
        gt_sat_dir = os.path.join(path, "sat")

        samples = []
        for name in self._seed_names(rad_in_dir):
            source_time = self.get_date_time(name)
            target_time = source_time + timedelta(hours=self.hours_predicted)
            sample = [
                os.path.join(rad_in_dir, name),
                os.path.join(sat_in_dir, source_time.strftime('%Y%m%d%H') + '.npz'),
                os.path.join(gt_rad_dir, target_time.strftime('%Y%m%d%H%M') + '.npz'),
                os.path.join(gt_sat_dir, target_time.strftime('%Y%m%d%H') + '.npz'),
            ]
            if self._all_files(sample[1:]):
                samples.append(sample)
        return samples

    def gen_list_img_time(self, path):
        """List samples for ablation "time".

        Each sample is ``[rad_inputs, sat_inputs, [rad_target], [sat_target]]``:

        * ``rad_inputs``: the four "rad" files at t-210, t-200, t-190 and
          t-180 minutes, followed by the "pred_rad" file at t;
        * ``sat_inputs``: the "sat" file at t-180 minutes, followed by the
          "pred_sat" file at t;
        * targets: the "rad" and "sat" files at t.

        A sample is kept only when every file except the seed itself exists.
        """
        pred_rad_dir = os.path.join(path, "pred_rad")
        pred_sat_dir = os.path.join(path, "pred_sat")
        gt_rad_dir = os.path.join(path, "rad")
        gt_sat_dir = os.path.join(path, "sat")

        samples = []
        # The suffix filter runs before the time stamp is parsed, so
        # unrelated files in the directory cannot break the scan.
        for name in self._seed_names(pred_rad_dir):
            stamp = self.get_date_time(name)
            hourly = self._hourly_name(name)
            rad_history = [
                os.path.join(
                    gt_rad_dir,
                    (stamp + timedelta(minutes=-210 + step * 10)).strftime('%Y%m%d%H%M') + '.npz',
                )
                for step in range(4)
            ]
            sat_history = [
                os.path.join(
                    gt_sat_dir,
                    (stamp + timedelta(minutes=-180)).strftime('%Y%m%d%H') + '.npz',
                )
            ]
            pred_sat_path = os.path.join(pred_sat_dir, hourly)
            gt_rad_path = os.path.join(gt_rad_dir, name)
            gt_sat_path = os.path.join(gt_sat_dir, hourly)

            required = rad_history + sat_history + [pred_sat_path, gt_rad_path, gt_sat_path]
            if self._all_files(required):
                samples.append([
                    rad_history + [os.path.join(pred_rad_dir, name)],
                    sat_history + [pred_sat_path],
                    [gt_rad_path],
                    [gt_sat_path],
                ])
        return samples

    def get_date_time(self, name):
        """Parse the leading ``YYYYMMDDHHMM`` characters of ``name``.

        Raises:
            ValueError: If any of the fixed-width fields is not an integer
                or does not form a valid date/time.
        """
        year = int(name[0:4])
        month = int(name[4:6])
        day = int(name[6:8])
        hour = int(name[8:10])
        minute = int(name[10:12])
        return datetime(year, month, day, hour, minute)
class WeatherForecastDataModule(LightningDataModule):
    """Lightning data module that serves ``DataReader`` datasets.

    All constructor arguments are stored through ``save_hyperparameters`` and
    are afterwards available as ``self.hparams.<name>``.

    Args:
        dir_data: Root data directory; must contain ``rad_mean.npz``,
            ``rad_std.npz``, ``sat_mean.npz`` and ``sat_std.npz``.
        batch_size: Batch size used by every dataloader.
        hours_predicted: Forwarded to ``DataReader``.
        num_workers: Worker count used by every dataloader.
        pin_memory: ``pin_memory`` flag used by every dataloader.
        time_points_rad: Forwarded to ``DataReader``.
        time_points_sat: Forwarded to ``DataReader``.
        sat_inp_vars: Satellite input key (``DataReader.sat_attribute``).
        sat_out_vars: Satellite target key (``DataReader.sat_predicted``).
        sat_size: Forwarded to ``DataReader``.
        rad_inp_vars: Radar input key (``DataReader.rad_attribute``).
        rad_out_vars: Radar target key (``DataReader.rad_predicted``).
        rad_size: Forwarded to ``DataReader``.
        ablation: Forwarded to ``DataReader``.
    """

    def __init__(
        self,
        dir_data: str,
        batch_size: int,
        hours_predicted: int,
        num_workers: int,
        pin_memory: bool,
        time_points_rad: int,
        time_points_sat: int,
        sat_inp_vars: str,
        sat_out_vars: str,
        sat_size: int,
        rad_inp_vars: str,
        rad_out_vars: str,
        rad_size: int,
        ablation: str,
    ):
        super().__init__()
        # Must be called from __init__ itself: it captures the init
        # arguments and exposes them as ``self.hparams``.
        self.save_hyperparameters(logger=True)

        self.data_train = None
        self.data_test = None
        self.data_val = None

        self.rad_mean = self._load_stat("rad_mean.npz", self.hparams.rad_inp_vars)
        self.rad_std = self._load_stat("rad_std.npz", self.hparams.rad_inp_vars)
        self.sat_mean = self._load_stat("sat_mean.npz", self.hparams.sat_inp_vars)
        self.sat_std = self._load_stat("sat_std.npz", self.hparams.sat_inp_vars)

    def _load_stat(self, file_name, key):
        """Read ``key`` from ``<dir_data>/<file_name>`` and close the archive."""
        # The context manager releases the file handle that a bare
        # ``np.load(...)[key]`` would keep open.
        with np.load(os.path.join(self.hparams.dir_data, file_name)) as archive:
            return archive[key]

    def _make_reader(self, split):
        """Create the ``DataReader`` for ``split`` from the stored hparams."""
        return DataReader(
            dir_data=self.hparams.dir_data,
            type_data=split,
            rad_attribute=self.hparams.rad_inp_vars,
            sat_attribute=self.hparams.sat_inp_vars,
            hours_predicted=self.hparams.hours_predicted,
            rad_predicted=self.hparams.rad_out_vars,
            sat_predicted=self.hparams.sat_out_vars,
            time_points_rad=self.hparams.time_points_rad,
            time_points_sat=self.hparams.time_points_sat,
            sat_size=self.hparams.sat_size,
            rad_size=self.hparams.rad_size,
            ablation=self.hparams.ablation,
        )

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a ``DataLoader`` configured from the hparams."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            drop_last=False,
            pin_memory=self.hparams.pin_memory,
            shuffle=shuffle,
        )

    def prepare_data(self):
        """Nothing to download or pre-process."""
        pass

    def setup(self, stage=None):
        """Build the train, test and val datasets.

        ``stage`` is accepted for the Lightning hook signature but not
        inspected: all three splits are (re)built on every call.
        """
        self.data_train = self._make_reader("train")
        self.data_test = self._make_reader("test")
        self.data_val = self._make_reader("val")

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return self._make_loader(self.data_train, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation split."""
        return self._make_loader(self.data_val, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the test split."""
        return self._make_loader(self.data_test, shuffle=False)