|
|
| from datetime import datetime, timedelta |
| import glob |
| import os |
|
|
| from fire import Fire |
| import h5py |
| from matplotlib import pyplot as plt |
| import numpy as np |
| import pathlib |
| import cartopy |
| import wradlib as wrl |
| import dask.array as da |
| import pandas as pd |
| import xarray as xr |
| import re |
| from datetime import datetime, timedelta |
| from ldcast import forecast |
| from ldcast.visualization import plots |
| import torch |
# Cartopy Plate Carree (equirectangular longitude/latitude) map projection.
# NOTE(review): not referenced anywhere in the visible code; presumably kept
# for map plotting (cf. the unused `plots` import) — confirm or remove.
MAP_PROJECTION = cartopy.crs.PlateCarree()
|
|
|
|
def get_date_time(name):
    """Parse the leading ``YYYYMMDDHHMM`` timestamp of *name* into a naive datetime.

    Only the first 12 characters are read, so a trailing extension such as
    ``.npy`` is ignored.

    Raises:
        ValueError: if a field is not an integer or is out of range for
            ``datetime``.
    """
    # (start, stop) slices for year, month, day, hour, minute — in that order.
    field_spans = ((0, 4), (4, 6), (6, 8), (8, 10), (10, 12))
    yr, mon, dd, hh, mm = (int(name[lo:hi]) for lo, hi in field_spans)
    return datetime(yr, mon, dd, hh, mm)
|
|
def _list_anchor_files(data_dir):
    """Return the sorted file names in *data_dir* that anchor a forecast.

    A file is kept when its name ends in ``00.npy`` or ``03.npy`` and does
    not start with ``2020``.
    """
    # NOTE(review): the suffix filter and the 2020 exclusion look like a
    # data split keyed on the timestamp in the file name — confirm intent.
    return sorted(
        name
        for name in os.listdir(data_dir)
        if name.endswith(("00.npy", "03.npy")) and not name.startswith("2020")
    )


def _input_paths(data_dir, anchor_time, past_timesteps, interval):
    """Return existing paths of the *past_timesteps* frames ending at *anchor_time*.

    Frames are spaced by *interval*, ordered oldest first, and the last one is
    at *anchor_time* itself. Frames missing on disk are left out, so the result
    may be shorter than *past_timesteps*.
    """
    paths = []
    for k in range(past_timesteps):
        frame_time = anchor_time - (past_timesteps - 1 - k) * interval
        path = os.path.join(data_dir, frame_time.strftime("%Y%m%d%H%M") + ".npy")
        if os.path.exists(path):
            paths.append(path)
    return paths


def demo(
    ldm_weights_fn="/data/data_WF/ablation/ablation_time/genforecast-radaronly-256x256-20step.pt",
    autoenc_weights_fn="/home/mmhk20/weather_forecast/ldcast/models/autoenc/autoenc-32-0.01.pt",
    num_diffusion_iters=50,
    out_dir="/data/data_WF/ablation/ablation_time/train/predict_radar",
    data_dir="/data/data_WF/ablation/ablation_time/train/GT_radar",
    t0=datetime(2020, 12, 31, 21, 40),
    interval=timedelta(minutes=10),
    past_timesteps=4,
    crop_box=((0, 640), (0, 640)),
    draw_border=False,
    ensemble_members=6,
    start_index=27600,
    end_index=27750,
):
    """Run the LDCast forecaster over a slice of radar files and save one lead time.

    The anchor files in *data_dir* (see ``_list_anchor_files``) are sorted by
    name. For each anchor in ``[start_index, end_index)``, the function loads
    the *past_timesteps* ``.npy`` frames that end at the anchor's timestamp and
    are spaced by *interval*. It stacks them, runs the forecaster, and saves
    ``R_pred[-3]`` to ``out_dir/<anchor + 180 min>.npy``. Anchors with any
    missing input frame are skipped.

    Args:
        ldm_weights_fn: Path to the latent-diffusion model weights.
        autoenc_weights_fn: Path to the autoencoder weights.
        num_diffusion_iters: Number of diffusion sampling iterations per forecast.
        out_dir: Directory for the saved predictions (created if missing).
        data_dir: Directory of input ``YYYYMMDDHHMM.npy`` frames.
        t0: Accepted for CLI compatibility; unused.
        interval: Spacing between consecutive input frames. It must be a
            ``timedelta``, so it is only settable from Python, not the CLI.
        past_timesteps: Number of input frames per forecast.
            NOTE(review): presumably must match what the loaded model expects
            (4 by default) — confirm.
        crop_box, draw_border, ensemble_members: Accepted for CLI
            compatibility; unused.
        start_index: First index into the sorted anchor list (inclusive).
        end_index: Last index into the sorted anchor list (exclusive). It is
            clamped to the number of available files.
    """
    sorted_files = _list_anchor_files(data_dir)
    # Create the output directory up front, so a bad path fails before the
    # (expensive) model load and first forecast rather than after them.
    os.makedirs(out_dir, exist_ok=True)

    fc1 = forecast.Forecast(
        ldm_weights_fn=ldm_weights_fn,
        autoenc_weights_fn=autoenc_weights_fn,
        future_timesteps=20,
        gpu=0,
    )
    # NOTE(review): R_pred[-3] is presumably the 18th of 20 output frames,
    # i.e. +180 min at 10-min spacing, assuming the leading axis of R_pred is
    # lead time — confirm against Forecast's output layout.
    lead_time = timedelta(minutes=180)

    stop = min(end_index, len(sorted_files))
    for idx in range(start_index, stop):
        fname = sorted_files[idx]
        print(idx)
        print(fname)
        anchor_time = get_date_time(fname)

        input_paths = _input_paths(data_dir, anchor_time, past_timesteps, interval)
        if len(input_paths) != past_timesteps:
            continue  # incomplete input window: skip this anchor

        with torch.no_grad():
            R_past = torch.stack(
                [torch.tensor(np.load(p)) for p in input_paths], dim=0
            )
            # The original `R_past.to('cuda:0')` discarded its result (Tensor.to
            # is not in-place), so it was a no-op and has been removed. Device
            # placement is presumably done by Forecast (built with gpu=0) — confirm.
            R_pred = fc1(R_past, num_diffusion_iters=num_diffusion_iters)

        pred_time = anchor_time + lead_time
        path_save = os.path.join(out_dir, pred_time.strftime("%Y%m%d%H%M") + ".npy")
        np.save(path_save, R_pred[-3])
| |
if __name__ == "__main__":
    # Expose demo()'s parameters as command-line flags via python-fire.
    Fire(component=demo)
| |
|
|