| from datetime import timedelta
|
| import gc
|
| import gzip
|
| import os
|
| import pickle
|
|
|
| import numpy as np
|
|
|
| from ldcast.features import batch, patches, split, transform
|
|
|
|
|
# Root directory that setup_data() joins with the per-variable data
# directories and the split-chunks file. split()[0] (equivalent to dirname)
# drops the trailing separator, leaving ".../ldcast".
file_dir = os.path.split("/data/data_WF/ldcast_precipitation/ldcast/")[0]
|
|
|
def setup_data(
    use_obs=True,
    use_nwp=False,
    obs_vars=("RZC",),
    nwp_vars=(
        "cape", "cin", "rate-cp", "rate-tp", "t2m",
        "tclw", "tcwv", "u", "v"
    ),
    nwp_lags=(0, 12),
    target_var="RZC",
    batch_size=8,
    past_timesteps=4,
    future_timesteps=20,
    timestep_secs=300,
    nwp_timestep_secs=3600,
    sampler_file=None,
    chunks_file="./data/split_chunks.pkl.gz",
    sample_shape=(4, 4)
):
    """Build the ``split.DataModule`` used to train/evaluate the nowcaster.

    Loads the raw patch archives of every required variable from
    ``<file_dir>/data/<var>/``, splits them into train/valid/test with the
    pregenerated chunk file, attaches a normalizing transform to every
    variable and wraps everything in a ``split.DataModule``.

    Args:
        use_obs: Include the observed variables (``<obs_var>-O``) as predictors.
        use_nwp: Include the combined ``"nwp"`` predictor and load its raw
            ``<nwp_var>-<lag>`` archives.
        obs_vars: Raw observation variables used as predictors.
        nwp_vars: Raw NWP variables combined into the ``"nwp"`` predictor.
            Each must have an entry in the transform table below.
        nwp_lags: Lags appended to each NWP variable to form the raw archive
            names (``f"{var}-{lag}"``).
        target_var: Raw variable used as the prediction target (its
            predictor name is ``target_var + "-T"``).
        batch_size: Batch size passed to the DataModule.
        past_timesteps: Number of input timesteps; offsets
            ``-past_timesteps+1 .. 0``.
        future_timesteps: Number of target timesteps; offsets
            ``1 .. future_timesteps``.
        timestep_secs: Length of one observation/target timestep in seconds.
        nwp_timestep_secs: Length of one NWP timestep in seconds.
        sampler_file: Mapping of split name -> sampler cache path. Defaults to
            ``<file_dir>/cache/sampler_nowcaster_{train,valid,test}.pkl``.
        chunks_file: Path, joined onto ``file_dir``, of the gzipped pickle
            holding the train/valid/test chunks.
        sample_shape: Passed through to the DataModule unchanged.

    Returns:
        The configured ``split.DataModule``.
    """
    target = target_var + "-T"
    predictors_obs = [v + "-O" for v in obs_vars]
    predictors = []
    if use_obs:
        predictors += predictors_obs
    if use_nwp:
        predictors.append("nwp")

    # Per-variable config: which raw archives feed it and at which timestep
    # offsets it is sampled.
    variables = {
        target: {
            "sources": [target_var],
            "timesteps": np.arange(1, future_timesteps + 1),
        }
    }
    for (pred_var, raw_var) in zip(predictors_obs, obs_vars):
        variables[pred_var] = {
            "sources": [raw_var],
            "timesteps": np.arange(-past_timesteps + 1, 1),
        }
    # Enough NWP steps to span the forecast horizon, plus a margin of 2.
    nwp_t1 = int(np.ceil(future_timesteps * timestep_secs / nwp_timestep_secs)) + 2
    variables["nwp"] = {
        "sources": nwp_vars,
        "timesteps": np.arange(nwp_t1),
        "timestep_secs": nwp_timestep_secs,
    }

    # Raw archives to load: sources of the target and observed predictors,
    # plus one archive per (NWP variable, lag) when NWP is used.
    raw_vars = set.union(
        *(set(variables[v]["sources"]) for v in predictors_obs + [target])
    )
    if use_nwp:
        for raw_var_base in variables["nwp"]["sources"]:
            raw_vars.update(f"{raw_var_base}-{lag}" for lag in nwp_lags)
    raw = {
        var: patches.load_all_patches(
            os.path.join(file_dir, f"./data/{var}/"), var
        )
        for var in raw_vars
    }

    # Load the pregenerated train/valid/test split. This is a pickle, so
    # chunks_file must only ever point at trusted files.
    with gzip.open(os.path.join(file_dir, chunks_file), 'rb') as f:
        chunks = pickle.load(f)
    # The primary variable must be a key of `raw`; the target's raw variable
    # always is (it is in raw_vars by construction), regardless of obs_vars.
    (raw, _) = split.train_valid_test_split(raw, target_var, chunks=chunks)

    # NOTE(review): the rain scale is always read from "RZC", matching the
    # hard-coded "RZC-T"/"RZC-O" keys below; other target/obs variables have
    # no entry in this transform table.
    rain_scale = raw["train"]["RZC"]["scale"]

    # Factories: each call builds a fresh transform object, so variables
    # sharing a normalization still get their own instance.
    def transform_rain():
        return transform.default_rainrate_transform(rain_scale)

    def transform_cape():
        return transform.normalize_threshold(
            log=True,
            threshold=1.0, fill_value=1.0,
            mean=1.530, std=0.859,
        )

    def transform_rate_tp():
        return transform.normalize_threshold(
            log=True,
            threshold=1e-5, fill_value=1e-5,
            mean=-3.831, std=0.650,
        )

    def transform_wind():
        return transform.normalize(std=9.44)

    transforms = {
        "RZC-T": transform_rain(),
        "RZC-O": transform_rain(),
        "cape": transform_cape(),
        "cin": transform_cape(),
        "rate-tp": transform_rate_tp(),
        "rate-cp": transform_rate_tp(),
        "t2m": transform.normalize(mean=286.069, std=7.323),
        "tclw": transform.normalize_threshold(
            log=True,
            threshold=0.001, fill_value=0.001,
            mean=-1.486, std=0.638,
        ),
        "tcwv": transform.normalize(std=17.307),
        "u": transform_wind(),
        "v": transform_wind(),
    }
    transforms["nwp"] = transform.combine([transforms[v] for v in nwp_vars])
    for (var_name, var_data) in variables.items():
        var_data["transform"] = transforms[var_name]

    if sampler_file is None:
        # Derived from file_dir so the cache lives under the same root as the data.
        sampler_file = {
            split_name: os.path.join(
                file_dir, "cache", f"sampler_nowcaster_{split_name}.pkl"
            )
            for split_name in ("train", "valid", "test")
        }

    # 10 log-spaced sampling bin edges from 0.2 to 50 (presumably rain-rate
    # thresholds in the units of the RZC data — confirm).
    bins = np.exp(np.linspace(np.log(0.2), np.log(50), 10))
    datamodule = split.DataModule(
        variables, raw, predictors, target, target,
        forecast_raw_vars=nwp_vars,
        interval=timedelta(seconds=timestep_secs),
        batch_size=batch_size, sampling_bins=bins,
        time_range_sampling=(-past_timesteps + 1, future_timesteps + 1),
        sampler_file=sampler_file,
        sample_shape=sample_shape,
        valid_seed=1234, test_seed=2345,
    )

    # Reclaim memory held by intermediate objects from loading/splitting.
    gc.collect()
    return datamodule
|
|
|