| """ |
| PerplexitySubsampler class: define and execute subsampling on a dataset, |
| weighted by perplexity values |
| """ |
|
|
| from collections import namedtuple |
| import numpy as np |
| import scipy as sp |
| from numpy.random import default_rng |
| from scipy.stats import norm, uniform |
|
|
| from typing import List, Tuple, Iterable |
|
|
# Histogram representation: bin counts, bin edges (one more than the counts)
# and bin centers. The typename matches the bound name so instances can be
# pickled (pickle looks the class up by its typename in this module).
Histo = namedtuple("Histo", "counts edges centers")


# Module-level random generator, used by PerplexitySubsampler.retain()
rng = default_rng()
|
|
|
|
def histo_quantile(hcounts: np.ndarray, hedges: np.ndarray,
                   perc_values: Iterable[float]) -> List[float]:
    """
    Compute quantile values by using a histogram

    Values are assumed to be uniformly spread inside each bin, i.e. the
    cumulative distribution is interpolated linearly between bin edges.

    :param hcounts: histogram counts, one per bin
    :param hedges: histogram bin edges (one more than the number of bins)
    :param perc_values: quantile levels to compute, as fractions in [0, 1];
       levels outside that range are clamped to the histogram range
    :return: the estimated value at each requested quantile level
    :raises ValueError: if the histogram holds no counts
    """
    counts = np.asarray(hcounts, dtype=float)
    edges = np.asarray(hedges, dtype=float)
    total = counts.sum()
    if total <= 0:
        raise ValueError("histogram has no counts")

    # cdf[k] is the fraction of mass lying below edges[k]; the leading 0 lets
    # quantiles falling in the first bin interpolate from zero
    cdf = np.concatenate(([0.0], np.cumsum(counts) / total))
    last = len(cdf) - 1

    out = []
    for p in perc_values:
        # First edge whose cumulative mass reaches p, clamped so that p <= 0
        # and p >= 1 (including rounding in the cumsum) stay in range
        idx = min(max(int(np.searchsorted(cdf, p)), 1), last)
        lo, hi = cdf[idx-1], cdf[idx]
        frac = (p - lo) / (hi - lo) if hi > lo else 0.0
        frac = min(max(frac, 0.0), 1.0)
        out.append(float(edges[idx-1] + (edges[idx] - edges[idx-1]) * frac))
    return out
|
|
|
|
| def _histo_inv_quantile(hedges: np.ndarray, hcounts: np.ndarray, |
| perp_value: float) -> float: |
| """ |
| Using an histogram of values, estimate the quantile occupied |
| by a given value |
| It is therefore the inverse function of quantile() |
| """ |
| v = np.searchsorted(hedges, perp_value, side="right") |
| frac = (perp_value - hedges[v-1]) / (hedges[v] - hedges[v-1]) |
| return hcounts[:v-1].sum() + hcounts[v-1]*frac |
|
|
|
|
def subsample_frac(data: np.ndarray, frac: float) -> np.ndarray:
    """
    Subsample an array to a given fraction

    Each element is kept independently: a uniform draw is made per element
    and the element survives when its draw is below `frac`.

    :param data: array to subsample (boolean-mask indexable)
    :param frac: retention probability per element
    :return: the retained elements, in their original order
    """
    draws = uniform.rvs(size=len(data))
    keep_mask = draws < frac
    return data[keep_mask]
|
|
|
|
|
|
| |
|
|
|
|
class PerplexitySubsampler:
    """
    Subsample a dataset weighted by perplexity values.

    An item with perplexity x is retained with probability proportional to a
    Gaussian density N(x; mean, sdev). `set()` fits mean/sdev so that a target
    overall sampling ratio and target probability masses in the first and
    fourth quartiles of the original perplexity distribution are reached.
    """

    def __init__(self, perp_data: np.ndarray = None,
                 perp_histogram: Tuple[np.ndarray, np.ndarray] = None,
                 hbins: int = 1000):
        """
        :param perp_data: a dataset of perplexity values
        :param perp_histogram: a histogram computed over a dataset of
           perplexity values, passed as a tuple (counts, edges)
        :param hbins: number of bins to use for the histogram approximation
           (only used if `perp_data` is passed)
        :raises ValueError: if neither `perp_data` nor `perp_histogram` is
           passed

        Either `perp_data` or `perp_histogram` must be passed; if both are,
        `perp_data` takes precedence.
        """
        if perp_data is not None:
            perp_data = np.asarray(perp_data)
            # 1st and 3rd quartiles of the original perplexity distribution
            self.qr = np.quantile(perp_data, [0.25, 0.75])
            # Histogram range is capped at 10x the 3rd quartile; values beyond
            # it are folded into the last bin so no mass is lost
            range_max = self.qr[1]*10
            counts, edges = np.histogram(perp_data, bins=hbins,
                                         range=[0, range_max])
            counts[-1] += np.count_nonzero(perp_data > range_max)
            self.histo = Histo(counts, edges, (edges[:-1] + edges[1:])/2)

        elif perp_histogram is not None:
            # Accept any array-like: later code relies on ndarray methods
            # (e.g. counts.sum() in _estimate)
            counts = np.asarray(perp_histogram[0])
            edges = np.asarray(perp_histogram[1], dtype=float)
            self.histo = Histo(counts, edges, (edges[:-1] + edges[1:])/2)
            self.qr = histo_quantile(self.histo.counts, self.histo.edges,
                                     [0.25, 0.75])

        else:
            raise ValueError("Neither sample nor histogram provided")

    def _estimate(self, m: float, s: float,
                  ratio: float) -> Tuple[float, float]:
        """
        Estimate, for a Gaussian weighting with mean `m` and sdev `s`, the
        fraction of the resulting subsample lying below the original 1st
        quartile and above the original 3rd quartile.

        :return: (mass below qr[0], mass above qr[1]) as fractions of the
           subsample
        """
        gauss_weights = norm.pdf(self.histo.centers, loc=m, scale=s)
        hcounts = self.histo.counts
        adjusted_norm = (hcounts*gauss_weights).sum()/hcounts.sum()/ratio

        # Expected per-bin counts after weighted subsampling
        hcounts_sub = hcounts*gauss_weights/adjusted_norm
        sub_size = hcounts_sub.sum()

        ra = _histo_inv_quantile(self.histo.edges, hcounts_sub,
                                 self.qr[0])/sub_size
        rb = _histo_inv_quantile(self.histo.edges, hcounts_sub,
                                 self.qr[1])/sub_size
        return ra, 1-rb

    def _error(self, point: np.ndarray, ratio: float,
               pa: float, pb: float) -> float:
        """
        Objective for the optimizer: L1 distance between the target masses
        (pa, pb) and those produced by Gaussian parameters point = (m, s)
        """
        actual_pa, actual_pb = self._estimate(point[0], point[1], ratio)
        return abs(pa-actual_pa) + abs(pb-actual_pb)

    def set(self, ratio: float, pa: float, pb: float):
        """
        Compute the parameters needed to achieve a desired sampling ratio &
        probability distribution

        :param ratio: the desired sampling ratio
        :param pa: the probability mass to be left in the first original
           perplexity quartile
        :param pb: the probability mass to be left in the fourth original
           perplexity quartile

        Sets self.mean, self.sdev, self.norm and self.ratio.
        """
        # Initial guess: choose the Gaussian so that, on its own, it puts
        # mass pa below qr[0] and mass pb above qr[1]
        sdev = (self.qr[0] - self.qr[1]) / (norm.ppf(pa) - norm.ppf(1-pb))
        mean = self.qr[0] - norm.ppf(pa)*sdev

        initial = np.array([mean, sdev])
        result = sp.optimize.minimize(self._error, initial,
                                      args=(ratio, pa, pb),
                                      method='nelder-mead',
                                      options={'xatol': 1e-8, 'disp': False})
        self.mean, self.sdev = result.x
        # Kept so that subsample_piecewise() can reuse it
        self.ratio = ratio

        # Normalization so that the average retention probability over the
        # original histogram equals `ratio`
        gauss_weights = norm.pdf(self.histo.centers, loc=self.mean,
                                 scale=self.sdev)
        counts = self.histo.counts
        self.norm = (counts*gauss_weights).sum()/counts.sum()/ratio

    def subsample(self, data: np.ndarray) -> np.ndarray:
        """
        Subsample a dataset according to the defined conditions
        Note: set() must have been called previously
        """
        # Retention probability per item; values > 1 mean "always keep"
        p = norm.pdf(data, loc=self.mean, scale=self.sdev)/self.norm
        return data[uniform.rvs(size=len(p)) < p]

    def retain(self, perp: float) -> bool:
        """
        Decide if a sample is to be retained based on its perplexity value
        Note: set() must have been called previously
        """
        p = norm.pdf(perp, loc=self.mean, scale=self.sdev)/self.norm
        return rng.uniform() < p

    def subsample_piecewise(self, data: np.ndarray, pa: float, pb: float,
                            ratio: float = None) -> np.ndarray:
        """
        Create a subsample by directly subsampling each region (below the
        1st quartile, interquartile, above the 3rd quartile) with a constant
        per-region retention fraction.

        :param data: perplexity values to subsample
        :param pa: target mass below the original 1st quartile
        :param pb: target mass above the original 3rd quartile
        :param ratio: overall sampling ratio; defaults to the one passed to
           the last set() call (raises AttributeError if set() was never
           called and no ratio is given)
        :return: the retained values, grouped by region
        """
        if ratio is None:
            ratio = self.ratio
        data = np.asarray(data)
        lo, hi = self.qr
        # Each region's original share is 0.25 / 0.5 / 0.25 of the data
        data1 = subsample_frac(data[data < lo], pa/0.25*ratio)
        data2 = subsample_frac(data[(data >= lo) & (data <= hi)],
                               (1-pa-pb)/0.5*ratio)
        data3 = subsample_frac(data[data > hi], pb/0.25*ratio)
        return np.hstack([data1, data2, data3])

    def verify(self, data: np.ndarray, data_sub: np.ndarray) -> Tuple:
        """
        Check the statistics of a sample

        :return: (len(data_sub)/len(data), fraction of data_sub below qr[0],
           fraction of data_sub above qr[1])
        """
        data_sub = np.asarray(data_sub)
        n_sub = len(data_sub)
        ratio = n_sub/len(data)
        ra = np.count_nonzero(data_sub < self.qr[0]) / n_sub
        rb = np.count_nonzero(data_sub > self.qr[1]) / n_sub
        return ratio, ra, rb
|
|
|
|
|
|
def check_results(s: PerplexitySubsampler,
                  data_full: np.ndarray, data_sub: np.ndarray):
    """
    Print the statistics of a subsample relative to its full dataset, as
    reported by `s.verify()`: the sampling ratio, then the mass below and
    above the original interquartile range.
    """
    sampling_ratio, mass_below, mass_above = s.verify(data_full, data_sub)
    report = (
        ("Sampling ratio:", sampling_ratio),
        ("Probability mass below Pa:", mass_below),
        ("Probability mass above Pb:", mass_above),
    )
    for label, value in report:
        print(label, value)
|
|