# Source: ldcast_code/ldcast/features/sampling.py
# Uploaded by weatherforecast1024 using huggingface_hub (commit d2f661a, verified).
from bisect import bisect_left
import multiprocessing
import dask
from numba import njit, prange, types
from numba.typed import Dict
import numpy as np
from .patches import unpack_patches
class EqualFrequencySampler:
    """Sample starting coordinates so that every bin is drawn equally often.

    At construction the patches are classified into bins (see
    ``bin_classify_patches_parallel``), and for every bin the list of sample
    starting indices ``(t, i, j)`` that contain at least one patch of that
    bin is computed. Calling the sampler picks a bin uniformly at random for
    each requested sample and then walks through that bin's starting indices
    in shuffled order, reshuffling each time the bin has been exhausted.
    """

    def __init__(
        self, bins, patch_data, patch_index,
        sample_shape, time_range_valid, time_range_sampling=None,
        timestep_secs=5*60,
        random_seed=None, preselected_samples=None
    ):
        """Precompute the per-bin lists of sample starting indices.

        Parameters
        ----------
        bins
            Sorted bin edges passed on to ``bin_classify_patches_parallel``.
        patch_data
            Mapping accepted by ``unpack_patches``; its optional
            ``"zero_value"`` (default 0) and ``"scale"`` entries are
            forwarded to the classifier.
        patch_index
            Object passed to ``indices_with_complete_sample``.
        sample_shape
            Spatial extent of a sample; elements 0 and 1 are used.
        time_range_valid
            Arguments for ``range()`` (in timesteps) over which a sample
            must have complete data.
        time_range_sampling
            Arguments for ``range()`` (in timesteps) used when searching for
            samples containing a patch; defaults to ``time_range_valid``.
        timestep_secs
            Length of one timestep in the units of the time coordinate.
        random_seed
            Seed for the internal ``numpy.random.RandomState``.
        preselected_samples
            Stored on the instance; not used by the methods of this class.
        """
        binned_patches = bin_classify_patches_parallel(
            bins,
            *unpack_patches(patch_data),
            zero_value=patch_data.get("zero_value", 0),
            scale=patch_data.get("scale")
        )
        complete_ind = indices_with_complete_sample(
            patch_index, sample_shape, time_range_valid, timestep_secs
        )
        if time_range_sampling is None:
            time_range_sampling = time_range_valid
        self.starting_ind = [
            starting_indices_for_centers(
                p, complete_ind, sample_shape, time_range_sampling, timestep_secs
            )
            for p in binned_patches
        ]
        self.num_bins = len(self.starting_ind)
        self.rng = np.random.RandomState(seed=random_seed)
        self.preselected_samples = preselected_samples
        # Every cursor starts past the end of its bin so that the first draw
        # from a bin triggers a shuffle.
        self.current_ind = np.array([len(ind) for ind in self.starting_ind])

    def get_bin_sample(self, bin_ind):
        """Return the next starting index ``(t, i, j)`` from bin ``bin_ind``.

        The bin's rows are consumed in order; once all of them have been
        returned the bin is shuffled in place and consumption restarts, so
        each row is returned exactly once per pass.

        Raises
        ------
        IndexError
            If the bin contains no starting indices.
        """
        patches = self.starting_ind[bin_ind]
        num_patches = patches.shape[0]
        if num_patches == 0:
            raise IndexError(
                "bin {} contains no samples".format(bin_ind)
            )

        sample_ind = self.current_ind[bin_ind]
        if sample_ind >= num_patches:
            # Bin exhausted: start a new pass in a fresh random order.
            self.rng.shuffle(patches)
            sample_ind = 0
        # Always advance past the row being returned. Leaving the cursor at
        # 0 after a reshuffle would hand out row 0 twice in a row.
        self.current_ind[bin_ind] = sample_ind + 1
        return patches[sample_ind, :]

    def __call__(self, num):
        """Draw ``num`` starting indices, returned stacked along axis 0."""
        # sample each bin with equal probability
        bins = self.rng.randint(self.num_bins, size=num)
        coords = np.stack(
            [self.get_bin_sample(b) for b in bins],
            axis=0
        )
        return coords
def bin_classify_patches(
    bins, patches, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times,
    zero_value=0, metric_func=None,
    scale=None,
):
    """Group patch locations into bins according to a per-patch metric.

    Parameters
    ----------
    bins
        Sorted sequence of bin edges; ``len(bins) + 1`` bins are produced
        and a metric value ``v`` goes to bin ``bisect_left(bins, v)``.
    patches
        Array of patches, one per index along axis 0. The default metric
        reduces over axes 1 and 2.
    patch_coords, patch_times
        Per-patch ``(pi, pj)`` coordinates and time ``t``, aligned with
        ``patches``.
    zero_patch_coords, zero_patch_times
        Coordinates and times of patches that are not stored in ``patches``;
        they are all assigned to the bin of ``zero_value``.
    zero_value
        Metric value attributed to the zero patches (looked up in ``scale``
        first when ``scale`` is given).
    metric_func
        Callable mapping ``patches`` to one metric per patch. Defaults to
        the 99th percentile over axes 1 and 2, rounded for integer dtypes
        and cast back to the dtype of ``patches``.
    scale
        Optional lookup table; when given, metrics are replaced by
        ``scale[metric]`` before binning.

    Returns
    -------
    list of numpy.ndarray
        One array per bin whose rows are ``(t, pi, pj)``; empty bins are
        represented by an int64 array of shape ``(0, 3)``.
    """
    def percentile_metric(x):
        p99 = np.percentile(x, 99, axis=(1, 2))
        if np.issubdtype(x.dtype, np.integer):
            p99 = p99.round()
        return p99.astype(x.dtype)

    if metric_func is None:
        metric_func = percentile_metric

    members = [[] for _ in range(len(bins) + 1)]

    # Zero patches all share one metric value and hence one bin.
    zero_metric = zero_value if scale is None else scale[zero_value]
    zero_members = members[bisect_left(bins, zero_metric)]
    for t, (pi, pj) in zip(zero_patch_times, zero_patch_coords):
        zero_members.append((t, pi, pj))

    metrics = metric_func(patches)
    if scale is not None:
        metrics = scale[metrics]
    for metric, t, (pi, pj) in zip(metrics, patch_times, patch_coords):
        members[bisect_left(bins, metric)].append((t, pi, pj))

    # Empty bins still get a correctly shaped array.
    return [
        np.array(group) if group else np.zeros((0, 3), dtype=np.int64)
        for group in members
    ]
def bin_classify_patches_parallel(
    bins, patches, patch_coords, patch_times,
    zero_patch_coords, zero_patch_times,
    zero_value=0, metric_func=None,
    scale=None,
):
    """Run ``bin_classify_patches`` on slices of the input using threads.

    The patches and the zero patches are each split into
    ``multiprocessing.cpu_count()`` contiguous slices; every slice is
    classified as a separate dask task and the per-bin results of all
    tasks are concatenated in task order.

    Takes the same arguments as ``bin_classify_patches`` and returns a list
    with one array per bin.
    """
    patch_total = patches.shape[0]
    zero_total = zero_patch_coords.shape[0]
    worker_count = multiprocessing.cpu_count()

    def share(total, worker):
        # Rounded proportional split; consecutive shares tile [0, total).
        start = int(round(total*worker/worker_count))
        stop = int(round(total*(worker+1)/worker_count))
        return slice(start, stop)

    tasks = []
    for worker in range(worker_count):
        patch_sl = share(patch_total, worker)
        zero_sl = share(zero_total, worker)
        tasks.append(
            dask.delayed(bin_classify_patches)(
                bins,
                patches[patch_sl, ...], patch_coords[patch_sl, ...],
                patch_times[patch_sl],
                zero_patch_coords[zero_sl, ...], zero_patch_times[zero_sl],
                zero_value=zero_value, metric_func=metric_func,
                scale=scale
            )
        )

    per_task = dask.compute(tasks, scheduler="threads")[0]

    # Transpose task-major results to bin-major and join each bin's pieces.
    return [
        np.concatenate(pieces, axis=0)
        for pieces in zip(*per_task)
    ]
def indices_with_complete_sample(
    patch_index, sample_shape, time_range, timestep_secs
):
    """Check which locations will give a sample without missing data.

    A location ``(t0, i0, j0)`` taken from the keys of
    ``patch_index.patch_index`` is considered complete when every
    ``(t, i, j)`` with ``t = t0 + ts*timestep_secs`` for
    ``ts in range(*time_range)``, ``i = i0 + di`` for
    ``di in range(sample_shape[0])`` and ``j = j0 + dj`` for
    ``dj in range(sample_shape[1])`` is also a key of that index.

    Parameters
    ----------
    patch_index
        Object whose ``patch_index`` attribute is a mapping keyed by
        ``(t, i, j)`` tuples. It is passed into a numba-compiled function,
        so presumably a numba typed Dict -- confirm against the caller.
    sample_shape
        Sequence whose elements 0 and 1 give the extent of a sample along
        the ``i`` and ``j`` axes.
    time_range
        Arguments unpacked into ``range()``; offsets in units of timesteps.
    timestep_secs
        Multiplier converting a timestep offset to the units of ``t``.

    Returns
    -------
    numba.typed.Dict
        Used as a set: the keys are the complete ``(t0, i0, j0)`` int64
        tuples, the values are a dummy ``uint8`` zero.
    """
    # NOTE(review): with an empty index this array is 1-D and the column
    # slicing below would raise -- confirm callers never pass an empty index.
    ind = np.array(list(patch_index.patch_index.keys()))
    t0 = ind[:,0]
    i0 = ind[:,1]
    j0 = ind[:,2]
    n = ind.shape[0]
    # Every location is assumed complete until a missing entry is found.
    complete = np.ones(n, dtype=bool)
    # we use this dict like a set - numba doesn't support typed sets
    complete_ind = Dict.empty(
        key_type=types.UniTuple(types.int64, 3),
        value_type=types.uint8
    )
    # n, t0, i0, j0, time_range, sample_shape and timestep_secs are taken
    # from the enclosing scope rather than passed as arguments.
    @njit(parallel=True) # many nested loops, numba optimization needed
    def check_complete(index, complete, complete_ind):
        for k in prange(n):
            for ts in range(*time_range):
                t = t0[k] + ts*timestep_secs
                for di in range(sample_shape[0]):
                    i = i0[k] + di
                    for dj in range(sample_shape[1]):
                        j = j0[k] + dj
                        # A single missing entry makes location k incomplete.
                        if (t,i,j) not in index:
                            complete[k] = False
        for k in range(n): # no prange: can't set dict items in parallel
            if complete[k]:
                complete_ind[(t0[k],i0[k],j0[k])] = np.uint8(0)
    # Fills ``complete`` and ``complete_ind`` in place.
    check_complete(patch_index.patch_index, complete, complete_ind)
    return complete_ind
def starting_indices_for_centers(
    centers, complete_ind, sample_shape, time_range, timestep_secs
):
    """Determine a complete list of sample indices that
    contain one or more of the centerpoints.

    For every center ``(t0, i0, j0)`` the candidate starting indices are
    ``(t0 - ts*timestep_secs, i0 - di, j0 - dj)`` with
    ``ts in range(*time_range)``, ``di in range(sample_shape[0])`` and
    ``dj in range(sample_shape[1])``; a candidate is kept when it is a key
    of ``complete_ind``. The centers are processed in
    ``multiprocessing.cpu_count()`` chunks on dask threads.

    Parameters
    ----------
    centers
        Array with one ``(t, i, j)`` row per center point.
    complete_ind
        Mapping used as a set of valid starting indices (as returned by
        ``indices_with_complete_sample``).
    sample_shape
        Sequence whose elements 0 and 1 give the sample extent in ``i``
        and ``j``.
    time_range
        Arguments unpacked into ``range()``; offsets in units of timesteps.
    timestep_secs
        Multiplier converting a timestep offset to the units of ``t``.

    Returns
    -------
    numpy.ndarray
        Array with one row per distinct starting index ``(t, i, j)``, in
        order of first discovery. Each index appears once regardless of how
        many chunks found it, so the result does not depend on the number
        of CPUs.

    Raises
    ------
    ValueError
        If no starting index is found for any of the centers.
    """
    @njit
    def find_indices(centers, starting_ind, complete_ind):
        for k in range(centers.shape[0]):
            t0 = centers[k,0]
            i0 = centers[k,1]
            j0 = centers[k,2]
            for ts in range(*time_range):
                t = t0 - ts*timestep_secs # note minus signs in (t,i,j)
                for di in range(sample_shape[0]):
                    i = i0 - di
                    for dj in range(sample_shape[1]):
                        j = j0 - dj
                        if (t,i,j) in complete_ind:
                            starting_ind[(t,i,j)] = np.uint8(0)

    num_chunks = multiprocessing.cpu_count()

    @dask.delayed
    def chunk(i):
        # Each task fills its own dict (used as a set), so no dict is
        # written to from more than one thread.
        starting_ind = Dict.empty(
            key_type=types.UniTuple(types.int64, 3),
            value_type=types.uint8
        )
        k0 = int(round(centers.shape[0] * (i / num_chunks)))
        k1 = int(round(centers.shape[0] * ((i+1) / num_chunks)))
        find_indices(centers[k0:k1,...], starting_ind, complete_ind)
        return starting_ind

    jobs = [chunk(i) for i in range(num_chunks)]
    chunk_results = dask.compute(jobs, scheduler='threads')[0]

    # Each chunk only deduplicates within itself; centers in different
    # chunks can reach the same starting index. Merging through a dict
    # removes those repeats while keeping first-seen order.
    unique_ind = dict.fromkeys(
        key for chunk_ind in chunk_results for key in chunk_ind.keys()
    )
    if not unique_ind:
        raise ValueError(
            "no complete sample contains any of the given center points"
        )
    return np.array(list(unique_ind))