# ldcast_code/ldcast/features/patches.py
# (uploaded via huggingface_hub by weatherforecast1024, revision d2f661a verified)
from datetime import datetime, timedelta
import os
import dask
import netCDF4
import numpy as np
from .utils import average_pool
def patch_locations(
    time_range,
    patch_box,
    patch_shape=(32,32),
    interval=timedelta(minutes=5),
    epoch=(1970,1,1)
):
    """Build the full grid of patch coordinates for every timestep.

    Parameters:
        time_range: (start, end) datetimes; timesteps are generated from
            start (inclusive) to end (exclusive).
        patch_box: ((i0, i1), (j0, j1)) half-open index ranges of the
            patch grid.
        patch_shape: unused here; kept for interface compatibility.
        interval: spacing between consecutive timesteps.
        epoch: unused here; kept for interface compatibility.

    Returns:
        dict mapping each timestep datetime to an (N, 2) int array of
        (i, j) patch coordinates.
    """
    # The coordinate grid is identical for every timestep, so build the
    # list once instead of re-running the double loop per timestep.
    coords = [
        (pi, pj)
        for pi in range(patch_box[0][0], patch_box[0][1])
        for pj in range(patch_box[1][0], patch_box[1][1])
    ]
    patches = {}
    t = time_range[0]
    while t < time_range[1]:
        # fresh array per timestep so values do not alias each other
        patches[t] = np.array(coords)
        t += interval
    return patches
def save_patches_radar(
    patches, archive_path, out_dir,
    variables=("RZC", "CPCH"),
    suffix="2020",
    **kwargs
):
    """Extract and save patches for MeteoSwiss radar products.

    Parameters:
        patches: dict mapping datetimes to patch coordinate arrays.
        archive_path: path to the MCH radar archive.
        out_dir: directory where patch netCDF files are written.
        variables: radar products to extract (default RZC and CPCH).
        suffix: appended to output filenames (e.g. the year).
        **kwargs: forwarded to save_patches_all.
    """
    from ..datasets import mchradar
    mchradar_reader = mchradar.MCHRadarReader(
        archive_path=archive_path,
        variables=variables,
        phys_values=False
    )
    # Removed unused local `ezc_nonzero_count_func` (dead code).
    # Raw (non-physical) values: anything > 1 counts as nonzero rain.
    nonzero_count_func = {
        "RZC": lambda x: np.count_nonzero(x > 1),
        "CPCH": lambda x: np.count_nonzero(x > 1)
    }
    # Raw value 1 encodes "zero rain" for RZC / CPCH.
    zero_value = {v: 0 for v in variables}
    zero_value["RZC"] = 1
    zero_value["CPCH"] = 1
    save_patches_all(
        mchradar_reader, patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix,
        source_vars={}, min_nonzeros_to_include=5,
        **kwargs
    )
def save_patches_dwdradar(
    patches, archive_path, out_dir,
    variables=("RV",),
    suffix="2022",
    patch_shape=(32,32),
    **kwargs
):
    """Extract and save patches from the DWD radar archive.

    Patch locations whose RV data are missing or contain non-finite
    values are dropped before extraction.

    Parameters:
        patches: dict mapping datetimes to (i, j) patch coordinates.
        archive_path: path to the DWD radar archive.
        out_dir: directory where patch netCDF files are written.
        variables: variables to extract (default RV only).
        suffix: appended to output filenames (e.g. the year).
        patch_shape: (height, width) of each patch in pixels.
        **kwargs: forwarded to save_patches_all.
    """
    from ..datasets import dwdradar
    reader = dwdradar.DWDRadarReader(
        archive_path=archive_path,
        variables=variables
    )

    (ph, pw) = patch_shape
    valid_patches = {}
    for t in sorted(patches):
        if (t.hour == 0) and (t.minute == 0):
            print(t)  # daily progress indicator
        try:
            data = reader.variable_for_time(t, "RV")
        except FileNotFoundError:
            # no RV file for this timestep; drop all of its patches
            continue
        kept = [
            (pi, pj) for (pi, pj) in patches[t]
            if np.isfinite(data[pi*ph:(pi+1)*ph, pj*pw:(pj+1)*pw]).all()
        ]
        if kept:
            valid_patches[t] = np.array(kept)
    print(len(patches), len(valid_patches))
    patches = valid_patches

    nonzero_count_func = {"RV": np.count_nonzero}
    zero_value = {v: 0 for v in variables}
    save_patches_all(
        reader, patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix,
        source_vars={}, min_nonzeros_to_include=5,
        **kwargs
    )
def save_patches_ifs(
    patches, archive_path, out_dir,
    variables=(
        #"rate-tp", "rate-cp", "t2m", "cape", "cin",
        #"tclw", "tcwv", #, "rate-tpe"
        #"u", "v"
        "cin",
    ),
    suffix="2020",
    lags=(0,12)
):
    """Extract and save patches from the IFS NWP dataset.

    Each variable is extracted once per forecast lag (as "<var>-<lag>"),
    average-pooled by a factor of 8, and written via save_patches_all.

    Parameters:
        patches: dict mapping datetimes to patch coordinate arrays;
            only on-the-hour timesteps are used (IFS data is hourly).
        archive_path: path to the IFS archive.
        out_dir: directory where patch netCDF files are written.
        variables: base IFS variable names to extract.
        suffix: appended to output filenames (e.g. the year).
        lags: forecast lags (hours) to extract for each variable.
    """
    from ..datasets import ifsnwp
    from .. import projection

    proj = projection.GridProjection(projection.ccs4_swiss_grid_area)
    reader = ifsnwp.IFSNWPReader(
        proj,
        archive_path=archive_path,
        variables=variables,
        lags=lags,
    )

    # IFS data is only available hourly, so keep on-the-hour timesteps.
    ifs_patches = {}
    for (dt, pset) in patches.items():
        if dt.minute == 0 and dt.second == 0 and dt.microsecond == 0:
            ifs_patches[dt] = pset

    # one entry per (variable, lag) pair, lag-major order
    variables_with_lag = [
        f"{v}-{lag}" for lag in reader.lags for v in variables
    ]

    def count_positive(x):
        return np.count_nonzero(x > 0)

    def all_nonzero(x):
        # treat every element as nonzero (variable is never "empty")
        return np.prod(x.shape)

    base_count_func = {
        "rate-tp": count_positive,
        "rate-cp": count_positive,
        "t2m": all_nonzero,
        "cape": count_positive,
        "cin": count_positive,
        "tclw": count_positive,
        "tcwv": count_positive,
        "u": all_nonzero,
        "v": all_nonzero,
        "rate-tpe": count_positive,
    }
    # map each lagged name to the count function of its base variable
    nonzero_count_func = {}
    for v in variables_with_lag:
        nonzero_count_func[v] = base_count_func[v.rsplit("-", 1)[0]]

    # CIN fields may contain NaNs; replace them with zero in place
    def fill_nans(x):
        return np.nan_to_num(x, nan=0.0, copy=False)
    postproc = {f"cin-{lag}": fill_nans for lag in lags}

    zero_value = dict.fromkeys(variables_with_lag, 0)

    def avg_pool_8(x):
        return average_pool(x, factor=8, missing=np.nan)
    pool = dict.fromkeys(variables_with_lag, avg_pool_8)

    save_patches_all(
        reader, ifs_patches, variables_with_lag,
        nonzero_count_func, zero_value, out_dir, suffix,
        pool=pool, postproc=postproc
    )
def save_patches_cosmo(patches, archive_path, out_dir, suffix="2020"):
    """Extract and save patches from the COSMO NWP dataset.

    COSMO data is hourly, so each input timestep contributes its patch
    coordinates to both the surrounding full hours.

    Parameters:
        patches: dict mapping datetimes to iterables of (i, j) patch
            coordinates.
        archive_path: path to the COSMO archive.
        out_dir: directory where patch netCDF files are written.
        suffix: appended to output filenames (e.g. the year).
    """
    from ..datasets import cosmonwp
    cosmonwp_reader = cosmonwp.COSMOCCS4Reader(
        archive_path=archive_path, cache_size=6000)
    # we only get data for every hour, so modify patches
    cosmo_patches = {}
    for (dt, pset) in patches.items():
        dt0 = datetime(dt.year, dt.month, dt.day, dt.hour)
        dt1 = dt0 + timedelta(hours=1)
        for dth in (dt0, dt1):
            if dth not in cosmo_patches:
                cosmo_patches[dth] = set()
            # Convert rows to tuples so they are hashable even when pset
            # is a NumPy array of (i, j) rows (ndarray rows cannot be
            # added to a set directly).
            cosmo_patches[dth].update(map(tuple, pset))
    variables = [
        "CAPE_MU", "CIN_MU", "SLI",
        "HZEROCL", "LCL_ML", "MCONV", "OMEGA",
        "T_2M", "T_SO", "SOILTYP"
    ]
    count_positive = lambda x: np.count_nonzero(x > 0)
    all_nonzero = lambda x: np.prod(x.shape)
    nonzero_count_func = {
        "CAPE_MU": count_positive,
        "CIN_MU": count_positive,
        "SLI": all_nonzero,
        "HZEROCL": count_positive,
        "LCL_ML": count_positive,
        "MCONV": all_nonzero,
        "OMEGA": all_nonzero,
        "T_2M": all_nonzero,
        "T_SO": all_nonzero,
        "SOILTYP": lambda x: np.count_nonzero(x != 5)
    }
    zero_value = {v: 0 for v in variables}
    zero_value["SOILTYP"] = 5
    # Bug fix: the original passed pool=pool, but `pool` is undefined in
    # this scope and raised NameError; COSMO patches are saved unpooled.
    save_patches_all(cosmonwp_reader, cosmo_patches, variables,
        nonzero_count_func, zero_value, out_dir, suffix)
def save_patches_all(
    reader, patches, variables, nonzero_count_func, zero_value,
    out_dir, suffix, epoch=datetime(1970,1,1), postproc=None, scale=None,
    pool=None, source_vars=None, parallel=False, min_nonzeros_to_include=1
):
    """Extract patches for each variable with get_patches and write one
    netCDF file per variable with save_patches.

    Parameters:
        reader: dataset reader providing variable_for_time (and
            optionally get_scale).
        patches: dict mapping datetimes to patch coordinates.
        variables: variable names to process.
        nonzero_count_func: dict of per-variable functions counting
            "nonzero" pixels in a patch.
        zero_value: dict of per-variable fill values for zero patches.
        out_dir: base output directory (one subdirectory per variable).
        suffix: appended to output filenames (e.g. the year).
        epoch: reference time for integer patch timestamps.
        postproc: optional dict of per-variable postprocessing functions.
        scale: optional dict of per-variable scales, used when the reader
            does not provide get_scale.
        pool: optional dict of per-variable pooling functions.
        source_vars: optional dict mapping output names to reader names.
        parallel: if True, process variables in parallel dask threads.
        min_nonzeros_to_include: patches with fewer nonzero pixels are
            stored as zero patches (coordinates/times only).
    """
    # None sentinels instead of shared mutable default arguments
    postproc = {} if postproc is None else postproc
    pool = {} if pool is None else pool
    source_vars = {} if source_vars is None else source_vars

    def save_var(var_name):
        # Extract patches for a single variable and write them to disk.
        src_name = source_vars.get(var_name, var_name)
        (patch_data, patch_coords, patch_times,
            zero_patch_coords, zero_patch_times) = get_patches(
            reader, src_name, patches,
            nonzero_count_func=nonzero_count_func[var_name],
            postproc=postproc.get(var_name),
            pool=pool.get(var_name),
            # bug fix: forward the threshold (it was silently dropped
            # before, so callers passing e.g. 5 got the default of 1)
            min_nonzeros_to_include=min_nonzeros_to_include
        )
        try:
            time = epoch + timedelta(seconds=int(patch_times[0]))
            var_scale = reader.get_scale(time, var_name)
        except (AttributeError, KeyError):
            # reader provides no scale; fall back to the caller-supplied one
            var_scale = None if (scale is None) else scale[var_name]
        # Bug fix: use a separate name for the filesystem-safe variable
        # name; the original rebound var_name before the zero_value
        # lookup, which raised KeyError for names containing "_".
        out_var_name = var_name.replace("_", "-")
        out_fn = f"patches_{out_var_name}_{suffix}.nc"
        out_path = os.path.join(out_dir, out_var_name)
        os.makedirs(out_path, exist_ok=True)
        out_fn = os.path.join(out_path, out_fn)
        save_patches(
            patch_data, patch_coords, patch_times,
            zero_patch_coords, zero_patch_times, out_fn,
            zero_value=zero_value[var_name], scale=var_scale
        )

    if parallel:
        save_var = dask.delayed(save_var)
    jobs = [save_var(v) for v in variables]
    if parallel:
        dask.compute(jobs, scheduler='threads')
def get_patches(
    reader, variable, patches,
    patch_shape=(32,32), nonzero_count_func=None,
    epoch=datetime(1970,1,1), postproc=None,
    pool=None, min_nonzeros_to_include=1
):
    """Extract fixed-size patches for one variable at the given locations.

    Patches whose nonzero-pixel count (per nonzero_count_func) falls
    below min_nonzeros_to_include are not stored; only their coordinates
    and timestamps are recorded as "zero patches".

    Parameters:
        reader: object providing variable_for_time(t, variable).
        variable: variable name passed to the reader.
        patches: dict mapping datetimes to iterables of (i, j) coords.
        patch_shape: (height, width) of each patch in pixels.
        nonzero_count_func: counts nonzero pixels in a patch; if None,
            every patch is stored.
        epoch: reference time for the integer second timestamps.
        postproc: optional function applied to each full data field.
        pool: optional function applied to each stored patch.
        min_nonzeros_to_include: threshold below which a patch is
            recorded as a zero patch.

    Returns:
        (patch_data, patch_coords, patch_times,
         zero_patch_coords, zero_patch_times) as NumPy arrays.
    """
    total = sum(len(coords) for coords in patches.values())
    (ph, pw) = patch_shape

    patch_data = []
    patch_coords = []
    patch_times = []
    zero_patch_coords = []
    zero_patch_times = []

    has_phys = hasattr(reader, "phys_values")
    if has_phys:
        saved_phys_values = reader.phys_values

    processed = 0
    try:
        if has_phys:
            reader.phys_values = False  # read raw (encoded) values
        for (t, coords) in patches.items():
            try:
                data = reader.variable_for_time(t, variable)
            except (ValueError, FileNotFoundError, KeyError, OSError):
                continue  # no usable data for this timestep
            if postproc is not None:
                data = postproc(data)
            time_sec = np.int64((t - epoch).total_seconds())
            for (pi, pj) in coords:
                if processed % 100000 == 0:
                    print("{}: {}/{}".format(t, processed, total))
                box = data[pi*ph:(pi+1)*ph, pj*pw:(pj+1)*pw].copy()
                # renamed from the misleading `is_nonzero`: True means
                # the patch is (nearly) empty and stored as a zero patch
                below_threshold = (
                    (nonzero_count_func is not None) and
                    (nonzero_count_func(box) < min_nonzeros_to_include)
                )
                if below_threshold:
                    zero_patch_coords.append((pi, pj))
                    zero_patch_times.append(time_sec)
                else:
                    if pool is not None:
                        box = pool(box)
                    patch_data.append(box)
                    patch_coords.append((pi, pj))
                    patch_times.append(time_sec)
                processed += 1
    finally:
        # always restore the reader's original mode
        if has_phys:
            reader.phys_values = saved_phys_values

    if zero_patch_coords:
        zero_patch_coords = np.stack(zero_patch_coords, axis=0).astype(np.uint16)
        zero_patch_times = np.stack(zero_patch_times, axis=0)
    else:
        zero_patch_coords = np.zeros((0,2), dtype=np.uint16)
        zero_patch_times = np.zeros((0,), dtype=np.int64)
    patch_data = np.stack(patch_data, axis=0)
    patch_coords = np.stack(patch_coords, axis=0).astype(np.uint16)
    patch_times = np.stack(patch_times, axis=0)

    return (patch_data, patch_coords, patch_times,
        zero_patch_coords, zero_patch_times)
def save_patches(patch_data, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times, out_fn, zero_value=0, scale=None):
    """Write extracted patches and zero-patch metadata to a netCDF4 file.

    Parameters:
        patch_data: (N, H, W) array of stored patches.
        patch_coords: (N, 2) array of (i, j) patch coordinates.
        patch_times: (N,) array of integer timestamps.
        zero_patch_coords: (M, 2) coordinates of zero patches.
        zero_patch_times: (M,) timestamps of zero patches.
        out_fn: output file path.
        zero_value: fill value stored as a global attribute.
        scale: optional 1-D scale array to store alongside the patches.
    """
    with netCDF4.Dataset(out_fn, 'w') as ds:
        ds.createDimension("dim_patch", patch_data.shape[0])
        ds.createDimension("dim_zero_patch", zero_patch_coords.shape[0])
        ds.createDimension("dim_coord", 2)
        ds.createDimension("dim_height", patch_data.shape[1])
        ds.createDimension("dim_width", patch_data.shape[2])

        var_args = {"zlib": True, "complevel": 1}
        # chunk along the patch dimension for efficient partial reads
        chunksizes = (
            min(2**10, patch_data.shape[0]),
            patch_data.shape[1], patch_data.shape[2]
        )

        def write_var(name, values, dims, **kwargs):
            # create a compressed variable and fill it in one shot
            var = ds.createVariable(name, values.dtype, dims,
                **kwargs, **var_args)
            var[:] = values

        write_var("patches", patch_data,
            ("dim_patch", "dim_height", "dim_width"), chunksizes=chunksizes)
        write_var("patch_coords", patch_coords, ("dim_patch", "dim_coord"))
        write_var("patch_times", patch_times, ("dim_patch",))
        write_var("zero_patch_coords", zero_patch_coords,
            ("dim_zero_patch", "dim_coord"))
        write_var("zero_patch_times", zero_patch_times, ("dim_zero_patch",))

        ds.zero_value = zero_value
        if scale is not None:
            ds.createDimension("dim_scale", len(scale))
            write_var("scale", scale, ("dim_scale",))
def load_patches(fn, in_memory=True):
    """Load one patch file written by save_patches.

    Parameters:
        fn: path to the netCDF patch file.
        in_memory: if True, read the whole file into memory first and
            let netCDF4 parse the buffer (avoids file-handle churn).

    Returns:
        dict with keys "patches", "patch_coords", "patch_times",
        "zero_patch_coords", "zero_patch_times", "zero_value" and,
        when present in the file, "scale".
    """
    if in_memory:
        with open(fn, 'rb') as f:
            ds_raw = f.read()
        fn = None
    else:
        ds_raw = None
    with netCDF4.Dataset(fn, 'r', memory=ds_raw) as ds:
        keys = ("patches", "patch_coords", "patch_times",
            "zero_patch_coords", "zero_patch_times")
        patch_data = {k: np.array(ds[k]) for k in keys}
        patch_data["zero_value"] = ds.zero_value
        if "scale" in ds.variables:
            patch_data["scale"] = np.array(ds["scale"])
    return patch_data
def load_all_patches(patch_dir, var):
    """Load and concatenate every patch file for one variable.

    Files are matched by the variable component of the
    "patches_<var>_<suffix>.nc" naming scheme and loaded in parallel
    with dask processes.

    Parameters:
        patch_dir: directory containing the patch files.
        var: variable name to match.

    Returns:
        dict in the same format as load_patches, with the array entries
        concatenated over all matching files; "zero_value" (and "scale",
        if present) are taken from the first file.
    """
    jobs = [
        dask.delayed(load_patches)(os.path.join(patch_dir, fn))
        for fn in os.listdir(patch_dir)
        if fn.split("_")[1] == var
    ]
    file_data = dask.compute(jobs, scheduler="processes")[0]
    array_keys = ("patches", "patch_coords", "patch_times",
        "zero_patch_coords", "zero_patch_times")
    patch_data = {
        k: np.concatenate([fd[k] for fd in file_data], axis=0)
        for k in array_keys
    }
    patch_data["zero_value"] = file_data[0]["zero_value"]
    if "scale" in file_data[0]:
        patch_data["scale"] = file_data[0]["scale"]
    return patch_data
def unpack_patches(patch_data):
    """Unpack a load_patches/load_all_patches dict into the positional
    tuple expected by save_patches:
    (patches, patch_coords, patch_times,
     zero_patch_coords, zero_patch_times)."""
    keys = ("patches", "patch_coords", "patch_times",
        "zero_patch_coords", "zero_patch_times")
    return tuple(patch_data[k] for k in keys)