| from internal import math |
| import numpy as np |
| import torch |
|
|
|
|
def searchsorted(a, v):
    """Locate the pair of indices of sorted `a` that bracket each query in `v`.

    Every query is compared against every reference point (an intermediate of
    shape (..., a.shape[-1], v.shape[-1])), which lets the leading dimensions
    broadcast.

    Args:
        a: tensor, the sorted reference points that we are scanning to see
            where v should lie.
        v: tensor, the query points that we are pretending to insert into a.
            Does not need to be sorted. All but the last dimensions should
            match or expand to those of a, the last dimension can differ.

    Returns:
        (idx_lo, idx_hi), where a[idx_lo] <= v < a[idx_hi], unless v is out of
        the range [a[0], a[-1]) in which case idx_lo and idx_hi are both the
        first or last index of a.
    """
    positions = torch.arange(a.shape[-1], device=a.device)[:, None]
    first_pos = positions[:1]
    last_pos = positions[-1:]
    # at_or_above[..., j, k] is True when query v[..., k] >= reference a[..., j].
    at_or_above = v[..., None, :] >= a[..., :, None]
    # Lower index: the largest j with a[j] <= v (falls back to the first index).
    idx_lo = torch.where(at_or_above, positions, first_pos).max(dim=-2).values
    # Upper index: the smallest j with a[j] > v (falls back to the last index).
    idx_hi = torch.where(~at_or_above, positions, last_pos).min(dim=-2).values
    return idx_lo, idx_hi
|
|
|
|
def query(tq, t, y, outside_value=0):
    """Look up the values of the step function (t, y) at locations tq.

    Args:
        tq: tensor, the query locations.
        t: tensor with shape (..., n+1), the sorted endpoints of the step
            function.
        y: tensor with shape (..., n), the value of the step function on each
            bin.
        outside_value: scalar, returned for queries for which searchsorted()
            reports no enclosing bin, i.e. tq < t[..., 0] or tq >= t[..., -1].

    Returns:
        yq: tensor with the dtype of y, the step function evaluated at tq.
    """
    idx_lo, idx_hi = searchsorted(t, tq)
    # For queries at or beyond the last endpoint, idx_lo == n, which is one past
    # the last valid index of y. Those entries are replaced by `outside_value`
    # below, but the gather itself must stay in bounds, so clamp first.
    idx_safe = idx_lo.clamp_max(y.shape[-1] - 1)
    y_lookup = torch.take_along_dim(y, idx_safe, dim=-1)
    # Build the fill value in y's dtype. Building it from the integer index
    # tensor would truncate (or reject) a non-integer `outside_value`.
    y_outside = torch.full_like(y_lookup, outside_value)
    return torch.where(idx_lo == idx_hi, y_outside, y_lookup)
|
|
|
|
def inner_outer(t0, t1, y1):
    """Construct inner and outer measures on (t1, y1) for t0.

    Args:
        t0: tensor with shape (..., m+1), endpoints of the intervals to measure.
        t1: tensor with shape (..., n+1), sorted endpoints of the step function.
        y1: tensor with shape (..., n), per-bin mass of the step function.

    Returns:
        (y0_inner, y0_outer), each with shape (..., m). For every interval
        [t0[i], t0[i+1]], y0_outer is the total mass of the bins of t1 that the
        interval touches, and y0_inner is the mass of the bins lying entirely
        inside it (0 when there are none).
    """
    # Cumulative mass at each endpoint of t1, starting at 0.
    cy1 = torch.cat([torch.zeros_like(y1[..., :1]),
                     torch.cumsum(y1, dim=-1)],
                    dim=-1)
    idx_lo, idx_hi = searchsorted(t1, t0)

    cy1_lo = torch.take_along_dim(cy1, idx_lo, dim=-1)
    cy1_hi = torch.take_along_dim(cy1, idx_hi, dim=-1)

    # Outer: from the bin containing t0[i] through the bin containing t0[i+1].
    y0_outer = cy1_hi[..., 1:] - cy1_lo[..., :-1]
    # Inner: only bins strictly between the two endpoints. When no whole bin
    # fits, the difference would be negative, so use 0. The zero tensor is
    # built from the float measure (not the int64 indices) so torch.where sees
    # matching dtypes.
    y0_inner = torch.where(idx_hi[..., :-1] <= idx_lo[..., 1:],
                           cy1_lo[..., 1:] - cy1_hi[..., :-1],
                           torch.zeros_like(cy1_lo[..., 1:]))
    return y0_inner, y0_outer
|
|
|
|
def lossfun_outer(t, w, t_env, w_env):
    """The proposal weight should be an upper envelope on the nerf weight.

    Args:
        t: tensor (..., n+1), endpoints of the histogram being bounded.
        w: tensor (..., n), weights of that histogram.
        t_env: tensor (..., m+1), sorted endpoints of the envelope histogram.
        w_env: tensor (..., m), weights of the envelope histogram.

    Returns:
        tensor (..., n): for each bin, the squared amount by which w exceeds the
        outer measure of (t_env, w_env) over that bin, divided by (w + eps).
    """
    tiny = torch.finfo(t.dtype).eps
    outer = inner_outer(t, t_env, w_env)[1]
    # Only mass that pokes above the envelope is penalized.
    excess = torch.clamp_min(w - outer, 0)
    return excess ** 2 / (w + tiny)
|
|
|
|
def weight_to_pdf(t, w):
    """Turn a vector of weights that sums to 1 into a PDF that integrates to 1.

    Args:
        t: tensor (..., n+1), bin endpoints.
        w: tensor (..., n), bin weights.

    Returns:
        tensor (..., n): each weight divided by its bin width. Widths are
        clamped to the machine epsilon of t's dtype so that zero-width bins do
        not divide by zero.
    """
    floor = torch.finfo(t.dtype).eps
    widths = torch.diff(t, dim=-1)
    return w / widths.clamp_min(floor)
|
|
|
|
def pdf_to_weight(t, p):
    """Turn a PDF that integrates to 1 into a vector of weights that sums to 1.

    Args:
        t: tensor (..., n+1), bin endpoints.
        p: tensor (..., n), PDF value on each bin.

    Returns:
        tensor (..., n): each PDF value multiplied by its bin width.
    """
    widths = torch.diff(t, dim=-1)
    return p * widths
|
|
|
|
def max_dilate(t, w, dilation, domain=(-torch.inf, torch.inf)):
    """Dilate (via max-pooling) a non-negative step function.

    Each bin [t[k], t[k+1]) is widened by `dilation` on both sides, and the
    dilated function takes the largest weight among all widened bins covering
    a point (0 where none does).

    Args:
        t: tensor (..., n+1), sorted endpoints of the step function.
        w: tensor (..., n), non-negative values of the step function.
        dilation: amount to widen each bin on each side.
        domain: (minval, maxval), the new endpoints are clipped to this range.

    Returns:
        (t_dilate, w_dilate): t_dilate has shape (..., 3n+1), w_dilate has shape
        (..., 3n).
    """
    lo = t[..., :-1] - dilation
    hi = t[..., 1:] + dilation
    # Candidate breakpoints: the original endpoints plus every widened bin edge.
    t_dilate = torch.sort(torch.cat([t, lo, hi], dim=-1), dim=-1).values
    t_dilate = torch.clip(t_dilate, *domain)

    t_q = t_dilate.unsqueeze(-1)
    # covers[..., i, k]: widened bin k, i.e. [lo_k, hi_k), contains breakpoint i.
    covers = (lo.unsqueeze(-2) <= t_q) & (hi.unsqueeze(-2) > t_q)
    w_b = w.unsqueeze(-2)
    pooled = torch.where(covers, w_b, torch.zeros_like(w_b)).max(dim=-1).values
    # Values are evaluated at each interval's left edge, so the last breakpoint
    # does not start an interval.
    return t_dilate, pooled[..., :-1]
|
|
|
|
def max_dilate_weights(t,
                       w,
                       dilation,
                       domain=(-torch.inf, torch.inf),
                       renormalize=False):
    """Dilate (via max-pooling) a set of weights.

    The weights are converted to a PDF, the PDF is max-dilated, and the result
    is converted back into per-bin weights on the new endpoints.

    Args:
        t: tensor (..., n+1), sorted bin endpoints.
        w: tensor (..., n), bin weights.
        dilation: amount to widen each bin on each side.
        domain: (minval, maxval), the new endpoints are clipped to this range.
        renormalize: bool, if True rescale the dilated weights to sum to 1
            (the divisor is clamped to the epsilon of w's dtype).

    Returns:
        (t_dilate, w_dilate), the dilated endpoints and weights.
    """
    tiny = torch.finfo(w.dtype).eps
    t_dilate, p_dilate = max_dilate(t, weight_to_pdf(t, w), dilation, domain=domain)
    w_dilate = pdf_to_weight(t_dilate, p_dilate)
    if not renormalize:
        return t_dilate, w_dilate
    total = w_dilate.sum(dim=-1, keepdim=True).clamp_min(tiny)
    return t_dilate, w_dilate / total
|
|
|
|
def integrate_weights(w):
    """Compute the cumulative sum of w, assuming all weight vectors sum to 1.

    The output's size on the last dimension is one greater than that of the
    input, because we're computing the integral corresponding to the endpoints
    of a step function, not the integral of the interior/bin values.

    Args:
        w: Tensor, which will be integrated along the last axis. This is
            assumed to sum to 1 along the last axis, and this function will
            (silently) break if that is not the case.

    Returns:
        cw0: Tensor with the dtype and device of w, where cw0[..., 0] = 0 and
        cw0[..., -1] = 1.
    """
    # The final value is pinned to exactly 1 rather than accumulated. Interior
    # values are clamped so accumulated round-off cannot push them above 1.
    cw = torch.cumsum(w[..., :-1], dim=-1).clamp_max(1)
    shape = cw.shape[:-1] + (1,)
    # Build the caps in cw's dtype. Default-dtype (float32) caps would make
    # torch.cat mix dtypes, changing the output dtype for non-float32 weights.
    # Use an explicit shape rather than zeros_like(cw[..., :1]), because cw is
    # empty along the last axis when w has a single bin.
    zeros = torch.zeros(shape, dtype=cw.dtype, device=cw.device)
    ones = torch.ones(shape, dtype=cw.dtype, device=cw.device)
    return torch.cat([zeros, cw, ones], dim=-1)
|
|
|
|
def integrate_weights_np(w):
    """Compute the cumulative sum of w, assuming all weight vectors sum to 1.

    NumPy counterpart of integrate_weights(). The output's size on the last
    dimension is one greater than that of the input, because we're computing
    the integral corresponding to the endpoints of a step function, not the
    integral of the interior/bin values.

    Args:
        w: ndarray, which will be integrated along the last axis. This is
            assumed to sum to 1 along the last axis, and this function will
            (silently) break if that is not the case.

    Returns:
        cw0: ndarray with the dtype of the cumulative sum of w, where
        cw0[..., 0] = 0 and cw0[..., -1] = 1.
    """
    # Pin the final value to 1 and clamp interior values against round-off.
    cw = np.minimum(1, np.cumsum(w[..., :-1], axis=-1))
    shape = cw.shape[:-1] + (1,)
    # Match the caps' dtype to cw. Default float64 caps would silently upcast
    # float32 weights.
    return np.concatenate([np.zeros(shape, dtype=cw.dtype), cw,
                           np.ones(shape, dtype=cw.dtype)], axis=-1)
|
|
|
|
def invert_cdf(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).

    Args:
        u: tensor, the CDF values to invert.
        t: tensor (..., n+1), sorted bin endpoints.
        w_logits: tensor (..., n), unnormalized log-weights of the bins.

    Returns:
        The t-space locations at which the CDF reaches u, obtained by
        interpolating u against (CDF, t) with math.sorted_interp.
    """
    # Softmax turns the logits into bin weights that sum to 1. Integrating them
    # gives the CDF at each endpoint of t.
    cdf = integrate_weights(torch.softmax(w_logits, dim=-1))
    return math.sorted_interp(u, cdf, t)
|
|
|
|
def invert_cdf_np(u, t, w_logits):
    """Invert the CDF defined by (t, w) at the points specified by u in [0, 1).

    NumPy counterpart of invert_cdf().

    Args:
        u: array, the CDF values to invert. For batched t/w_logits it is
            (..., k) and its leading dims must broadcast against theirs.
        t: array (..., n+1), sorted bin endpoints.
        w_logits: array (..., n), unnormalized log-weights of the bins.

    Returns:
        array, the t-space locations at which the CDF reaches u.
    """
    # Softmax over the last axis. Subtracting the max first keeps np.exp from
    # overflowing for large logits and does not change the normalized result.
    w_exp = np.exp(w_logits - np.max(w_logits, axis=-1, keepdims=True))
    w = w_exp / w_exp.sum(axis=-1, keepdims=True)
    cw = integrate_weights_np(w)

    if cw.ndim == 1 and np.ndim(t) == 1:
        # Unbatched: np.interp handles any shape of u directly.
        return np.interp(u, cw, t)
    # np.interp requires 1-D sample points, so map it over the batch dims.
    # sample_np() produces batched u of shape t.shape[:-1] + (num_samples,).
    interp_fn = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')
    return interp_fn(u, cw, t)
|
|
|
|
def sample(rand,
           t,
           w_logits,
           num_samples,
           single_jitter=False,
           deterministic_center=False):
    """Piecewise-Constant PDF sampling from a step function.

    Args:
        rand: truthy to draw jittered samples with torch.rand (global RNG);
            falsy (e.g. None) for `linspace` sampling.
        t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
        w_logits: [..., num_bins], logits corresponding to bin weights
        num_samples: int, the number of samples.
        single_jitter: bool, if True, jitter every sample along each ray by the
            same amount in the inverse CDF. Otherwise, jitter each sample
            independently.
        deterministic_center: bool, if False, when `rand` is falsy return
            samples that linspace the entire PDF. If True, skip the front and
            back of the linspace so that the centers of each PDF interval are
            returned.

    Returns:
        t_samples: [..., num_samples].
    """
    eps = torch.finfo(t.dtype).eps
    device = t.device

    # Draw uniform samples u in [0, 1) in CDF space.
    if not rand:
        if deterministic_center:
            pad = 1 / (2 * num_samples)
            u = torch.linspace(pad, 1. - pad - eps, num_samples, device=device)
        else:
            u = torch.linspace(0, 1. - eps, num_samples, device=device)
        u = torch.broadcast_to(u, t.shape[:-1] + (num_samples,))
    else:
        # `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u_max = eps + (1 - eps) / num_samples
        if num_samples > 1:
            max_jitter = (1 - u_max) / (num_samples - 1) - eps
        else:
            # For a single sample the expression above is 0 / 0, which raises
            # ZeroDivisionError on Python floats. Algebraically it equals
            # (1 - eps) / num_samples - eps, i.e. 1 - 2 * eps here.
            max_jitter = 1 - 2 * eps
        d = 1 if single_jitter else num_samples
        u = torch.linspace(0, 1 - u_max, num_samples, device=device) + \
            torch.rand(t.shape[:-1] + (d,), device=device) * max_jitter

    return invert_cdf(u, t, w_logits)
|
|
|
|
def sample_np(rand,
              t,
              w_logits,
              num_samples,
              single_jitter=False,
              deterministic_center=False):
    """NumPy version of sample().

    Arguments match sample(). Jitter is drawn from NumPy's global random state
    (np.random.rand), and epsilon is always float32's.

    Returns:
        t_samples: [..., num_samples].
    """
    eps = np.finfo(np.float32).eps

    # Draw uniform samples u in [0, 1) in CDF space.
    if not rand:
        if deterministic_center:
            pad = 1 / (2 * num_samples)
            u = np.linspace(pad, 1. - pad - eps, num_samples)
        else:
            u = np.linspace(0, 1. - eps, num_samples)
        u = np.broadcast_to(u, t.shape[:-1] + (num_samples,))
    else:
        # `u` is in [0, 1) --- it can be zero, but it can never be 1.
        u_max = eps + (1 - eps) / num_samples
        if num_samples > 1:
            max_jitter = (1 - u_max) / (num_samples - 1) - eps
        else:
            # For a single sample the expression above is 0 / 0, which yields
            # NaN in float32. Algebraically it equals (1 - eps) / num_samples
            # - eps, i.e. 1 - 2 * eps here.
            max_jitter = 1 - 2 * eps
        d = 1 if single_jitter else num_samples
        u = np.linspace(0, 1 - u_max, num_samples) + \
            np.random.rand(*t.shape[:-1], d) * max_jitter

    return invert_cdf_np(u, t, w_logits)
|
|
|
|
def sample_intervals(rand,
                     t,
                     w_logits,
                     num_samples,
                     single_jitter=False,
                     domain=(-torch.inf, torch.inf)):
    """Sample *intervals* (rather than points) from a step function.

    Args:
        rand: random number generator (or None for `linspace` sampling).
        t: [..., num_bins + 1], bin endpoint coordinates (must be sorted)
        w_logits: [..., num_bins], logits corresponding to bin weights
        num_samples: int, the number of intervals to sample.
        single_jitter: bool, if True, jitter every sample along each ray by the
            same amount in the inverse CDF. Otherwise, jitter each sample
            independently.
        domain: (minval, maxval), the range of valid values for `t`.

    Returns:
        t_samples: [..., num_samples + 1], the endpoints of the sampled
        intervals.

    Raises:
        ValueError: if num_samples <= 1.
    """
    if num_samples <= 1:
        raise ValueError(f'num_samples must be > 1, is {num_samples}.')
    lo, hi = domain

    # One point per interval; with falsy `rand` these are the centers of
    # equal-probability slices of the PDF.
    c = sample(rand, t, w_logits, num_samples, single_jitter=single_jitter,
               deterministic_center=True)

    # Interior boundaries sit halfway between neighbouring points.
    boundaries = (c[..., 1:] + c[..., :-1]) / 2
    # Outer boundaries mirror the nearest interior boundary across the first
    # and last points, and are then clamped into `domain`.
    head = torch.clamp_min(2 * c[..., :1] - boundaries[..., :1], lo)
    tail = torch.clamp_max(2 * c[..., -1:] - boundaries[..., -1:], hi)
    return torch.cat((head, boundaries, tail), dim=-1)
|
|
|
|
def lossfun_distortion(t, w):
    """Compute iint w[i] w[j] |t[i] - t[j]| di dj.

    Args:
        t: tensor (..., n+1), sorted bin endpoints.
        w: tensor (..., n), bin weights.

    Returns:
        tensor (...), the distortion loss reduced over the last axis.
    """
    # Cross-bin term: every pair of bins, each represented by its midpoint.
    midpoints = (t[..., 1:] + t[..., :-1]) / 2
    pairwise = torch.abs(midpoints.unsqueeze(-1) - midpoints.unsqueeze(-2))
    inter = (w * (w.unsqueeze(-2) * pairwise).sum(dim=-1)).sum(dim=-1)

    # Within-bin term: a bin of width d contributes w^2 * d / 3.
    widths = t[..., 1:] - t[..., :-1]
    intra = (w ** 2 * widths).sum(dim=-1) / 3

    return inter + intra
|
|
|
|
def interval_distortion(t0_lo, t0_hi, t1_lo, t1_hi):
    """Compute mean(abs(x-y); x in [t0_lo, t0_hi], y in [t1_lo, t1_hi]).

    Both the disjoint-case and overlap-case closed forms are evaluated
    elementwise, and the applicable one is selected per element.
    """
    # The intervals are disjoint when one starts after the other ends.
    disjoint = (t0_lo > t1_hi) | (t1_lo > t0_hi)

    # Disjoint case: the mean distance is the distance between the midpoints.
    mid0 = (t0_lo + t0_hi) / 2
    mid1 = (t1_lo + t1_hi) / 2
    dist_disjoint = torch.abs(mid1 - mid0)

    # Overlapping case: closed-form expectation over the two uniform intervals.
    cube_term = (torch.minimum(t0_hi, t1_hi) ** 3 -
                 torch.maximum(t0_lo, t1_lo) ** 3)
    cross_term = (t1_hi * t0_hi * torch.abs(t1_hi - t0_hi) +
                  t1_lo * t0_lo * torch.abs(t1_lo - t0_lo) +
                  t1_hi * t0_lo * (t0_lo - t1_hi) +
                  t1_lo * t0_hi * (t1_lo - t0_hi))
    normalizer = 6 * (t0_hi - t0_lo) * (t1_hi - t1_lo)
    dist_overlap = (2 * cube_term + 3 * cross_term) / normalizer

    return torch.where(disjoint, dist_disjoint, dist_overlap)
|
|
|
|
def weighted_percentile(t, w, ps):
    """Compute the weighted percentiles of a step function. w's must sum to 1.

    Args:
        t: tensor (..., n+1), sorted bin endpoints.
        w: tensor (..., n), bin weights summing to 1.
        ps: sequence of percentiles in [0, 100].

    Returns:
        tensor (..., len(ps)), the t-values at the requested percentiles.
    """
    cw = integrate_weights(w)
    quantiles = torch.tensor(ps, device=t.device) / 100
    # Collapse all batch dimensions so the interpolation sees 2-D inputs.
    cw_flat = cw.reshape([-1, cw.shape[-1]])
    t_flat = t.reshape([-1, t.shape[-1]])
    flat = math.sorted_interp(quantiles, cw_flat, t_flat)
    return flat.reshape(cw.shape[:-1] + (len(ps),))
|
|
|
|
def resample(t, tp, vp, use_avg=False):
    """Resample a step function defined by (tp, vp) into intervals t.

    Args:
        t: tensor with shape (..., n+1), the endpoints to resample into.
        tp: tensor with shape (..., m+1), the endpoints of the step function
            being resampled.
        vp: tensor with shape (..., m), the values of the step function being
            resampled.
        use_avg: bool, if False, return the sum of the step function for each
            interval in `t`. If True, return the average, weighted by the width
            of each interval in `t`. The division is guarded by the machine
            epsilon of t's dtype.

    Returns:
        v: tensor with shape (..., n), the values of the resampled step
        function.
    """
    eps = torch.finfo(t.dtype).eps

    if use_avg:
        wp = torch.diff(tp, dim=-1)
        v_numer = resample(t, tp, vp * wp, use_avg=False)
        v_denom = resample(t, tp, wp, use_avg=False)
        # Intervals of t that do not overlap tp have zero coverage.
        return v_numer / v_denom.clamp_min(eps)

    # Integrate the step function at its endpoints, sample that integral at t,
    # and difference it to get the mass that falls inside each interval of t.
    acc = torch.cumsum(vp, dim=-1)
    # Use acc's dtype for the leading zero so the result keeps vp's dtype
    # rather than mixing in the default float32.
    zeros = torch.zeros(acc.shape[:-1] + (1,), dtype=acc.dtype, device=acc.device)
    acc0 = torch.cat([zeros, acc], dim=-1)
    acc0_resampled = math.sorted_interp(t, tp, acc0)
    return torch.diff(acc0_resampled, dim=-1)
|
|
|
|
def resample_np(t, tp, vp, use_avg=False):
    """NumPy version of resample().

    Args:
        t: array (..., n+1), the endpoints to resample into.
        tp: array (..., m+1), endpoints of the step function being resampled.
        vp: array (..., m), values of the step function being resampled.
        use_avg: bool, if False return the sum of the step function over each
            interval of t. If True return its width-weighted average, with the
            divisor floored at the epsilon of t's dtype.

    Returns:
        array (..., n), the resampled values.
    """
    eps = np.finfo(t.dtype).eps
    if not use_avg:
        # Integrate at the endpoints, sample the integral at t (np.interp mapped
        # over any leading dims), then difference back to per-interval mass.
        running = np.cumsum(vp, axis=-1)
        integral = np.concatenate([np.zeros(running.shape[:-1] + (1,)), running],
                                  axis=-1)
        batched_interp = np.vectorize(np.interp, signature='(n),(m),(m)->(n)')
        return np.diff(batched_interp(t, tp, integral), axis=-1)

    widths = np.diff(tp, axis=-1)
    weighted_sum = resample_np(t, tp, vp * widths, use_avg=False)
    covered = resample_np(t, tp, widths, use_avg=False)
    return weighted_sum / np.maximum(eps, covered)
|
|
|
|
def blur_stepfun(x, y, r):
    """Blur the step function (x, y) with a normalized box filter of half-width r.

    The result is piecewise linear: its slope changes by +jump/(2r) at
    x[i] - r and by -jump/(2r) at x[i] + r, where jump is the change in y at
    x[i].

    Args:
        x: tensor (..., n+1), sorted step-function endpoints.
        y: tensor (..., n), step-function values.
        r: the box filter's half-width.

    Returns:
        (xr, yr): knots (..., 2n+2) and the blurred function's values at those
        knots (..., 2n+2). The values are clamped to be non-negative.
    """
    zero = torch.zeros_like(y[..., :1])
    # Size of the jump at every endpoint (y is treated as 0 outside x), scaled
    # by the filter height 1 / (2r).
    jumps = (torch.cat([y, zero], dim=-1) - torch.cat([zero, y], dim=-1)) / (2 * r)
    xr, order = torch.sort(torch.cat([x - r, x + r], dim=-1))
    # Slope changes, reordered to follow the sorted knots.
    slope_deltas = torch.cat([jumps, -jumps], dim=-1).take_along_dim(
        order[..., :-1], dim=-1)
    # Integrate twice: slope changes -> slopes -> values at each knot.
    increments = (xr[..., 1:] - xr[..., :-1]) * torch.cumsum(slope_deltas, dim=-1)
    yr = torch.cumsum(increments, dim=-1).clamp_min(0)
    return xr, torch.cat([torch.zeros_like(yr[..., :1]), yr], dim=-1)
|
|