import numpy as np
import torch
|
|
|
|
@torch.jit.script
def erf(x):
    """Elementary-function approximation of erf(), accurate to within ~7e-3."""
    return torch.sign(x) * torch.sqrt(1 - torch.exp(-4 / torch.pi * x ** 2))
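
# Illustrative sanity check (not part of the original module): the
# approximation above tracks torch.special.erf to within a few 1e-3.
# x = torch.linspace(-3.0, 3.0, 101)
# assert (erf(x) - torch.special.erf(x)).abs().max() < 7e-3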
|
|
|
|
def matmul(a, b):
    """Batched matrix multiply on the last two axes (equivalent to a @ b)."""
    return (a[..., None] * b[..., None, :, :]).sum(dim=-2)
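
# Illustrative check (shapes here are assumptions, not from the original):
# the broadcasted form matches the built-in operator.
# a, b = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
# assert torch.allclose(matmul(a, b), a @ b, atol=1e-5)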
|
|
|
|
def safe_trig_helper(x, fn, t=100 * torch.pi):
    """Helper function used by safe_cos/safe_sin: mods x before sin()/cos()."""
    return fn(torch.where(torch.abs(x) < t, x, x % t))
|
|
|
|
def safe_cos(x):
    """torch.cos() that stays accurate for inputs of very large magnitude."""
    return safe_trig_helper(x, torch.cos)
|
|
|
|
def safe_sin(x):
    """torch.sin() that stays accurate for inputs of very large magnitude."""
    return safe_trig_helper(x, torch.sin)
|
|
|
|
def safe_exp(x):
    """torch.exp() clamped to avoid float32 overflow (exp(88) is near finfo max)."""
    return torch.exp(x.clamp_max(88.))
|
|
|
|
def safe_exp_jvp(primals, tangents):
    """Override safe_exp()'s gradient so that it's large when inputs are large.

    JAX-style jvp rule: the tangent uses exp(x) itself instead of the zero
    gradient that the clamp in safe_exp() would otherwise produce for x > 88.
    """
    x, = primals
    x_dot, = tangents
    exp_x = safe_exp(x)
    exp_x_dot = exp_x * x_dot
    return exp_x, exp_x_dot
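
# The jvp above uses JAX's (primals, tangents) calling convention and is not
# wired into torch autograd by itself. A minimal sketch (hypothetical, not
# part of the original API) of the same override as a torch.autograd.Function:
class _SafeExpSketch(torch.autograd.Function):
    """safe_exp() whose gradient stays exp(x) even where the clamp is active."""

    @staticmethod
    def forward(ctx, x):
        exp_x = safe_exp(x)
        ctx.save_for_backward(exp_x)
        return exp_x

    @staticmethod
    def backward(ctx, grad_output):
        exp_x, = ctx.saved_tensors
        # Default autograd would zero the gradient where x > 88; keep exp(x).
        return grad_output * exp_x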
|
|
|
|
def log_lerp(t, v0, v1):
    """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
    if v0 <= 0 or v1 <= 0:
        raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
    lv0 = np.log(v0)
    lv1 = np.log(v1)
    return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)
|
|
|
|
def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
    """Continuous learning rate decay function.

    The returned rate is lr_init when step=0 and lr_final when step=max_steps,
    and is log-linearly interpolated elsewhere (equivalent to exponential
    decay). If lr_delay_steps > 0 then the learning rate will be scaled by
    some smooth function of lr_delay_mult, such that the initial learning rate
    is lr_init * lr_delay_mult at the beginning of optimization but is eased
    back to the normal learning rate when step > lr_delay_steps.

    Args:
      step: int, the current optimization step.
      lr_init: float, the initial learning rate.
      lr_final: float, the final learning rate.
      max_steps: int, the number of steps during optimization.
      lr_delay_steps: int, the number of steps to delay the full learning rate.
      lr_delay_mult: float, the multiplier on the rate when delaying it.

    Returns:
      lr: the learning rate for the current step.
    """
    if lr_delay_steps > 0:
        # A kind of reverse cosine decay.
        delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
            0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
    else:
        delay_rate = 1.
    return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)
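
# Example schedule (hyperparameter values are illustrative, not from the
# original): warm up from 1% of lr_init over 512 steps, then decay
# log-linearly from 2e-3 to 2e-5 across 25k steps.
# lrs = [learning_rate_decay(s, 2e-3, 2e-5, 25000,
#                            lr_delay_steps=512, lr_delay_mult=0.01)
#        for s in range(25000)]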
|
|
|
|
def sorted_interp(x, xp, fp):
    """A TPU-friendly version of interp(), where xp and fp must be sorted."""

    # Identify the location in `xp` that corresponds to each `x`.
    # The final `True` index in `mask` is the start of the matching interval.
    mask = x[..., None, :] >= xp[..., :, None]

    def find_interval(x):
        # Grab the value where `mask` switches from True to False, and vice versa.
        x0 = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2).values
        x1 = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2).values
        return x0, x1

    fp0, fp1 = find_interval(fp)
    xp0, xp1 = find_interval(xp)

    offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
    ret = fp0 + offset * (fp1 - fp0)
    return ret
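
# Illustrative check against np.interp (1D case; values are assumptions):
# xp = torch.linspace(0.0, 1.0, 8)
# fp = xp ** 2
# x = torch.rand(16)
# ref = np.interp(x.numpy(), xp.numpy(), fp.numpy())
# assert np.allclose(sorted_interp(x, xp, fp).numpy(), ref, atol=1e-5)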
|
|
|
|
def sorted_interp_quad(x, xp, fpdf, fcdf):
    """Interpolate a piecewise-quadratic CDF.

    `fpdf` holds PDF values at the sorted knots `xp`, and `fcdf` the matching
    CDF values, so the PDF is linear and the CDF quadratic between knots.
    """

    # Identify the location in `xp` that corresponds to each `x`.
    # The final `True` index in `mask` is the start of the matching interval.
    mask = x[..., None, :] >= xp[..., :, None]

    def find_interval(x, return_idx=False):
        # Grab the value where `mask` switches from True to False, and vice versa.
        x0, x0_idx = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2)
        x1, x1_idx = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2)
        if return_idx:
            return x0, x1, x0_idx, x1_idx
        return x0, x1

    fcdf0, fcdf1, fcdf0_idx, fcdf1_idx = find_interval(fcdf, return_idx=True)
    fpdf0 = fpdf.take_along_dim(fcdf0_idx, dim=-1)
    fpdf1 = fpdf.take_along_dim(fcdf1_idx, dim=-1)
    xp0, xp1 = find_interval(xp)

    offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
    # Trapezoid rule: integrate the linearly-interpolated PDF from xp0 to x.
    ret = fcdf0 + (x - xp0) * (fpdf0 + fpdf1 * offset + fpdf0 * (1 - offset)) / 2
    return ret
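
# Illustrative sanity check (values assumed): with a constant PDF of 1 on
# [0, 1], the interpolated CDF is the identity.
# xp = torch.linspace(0.0, 1.0, 5)
# fpdf = torch.ones(5)
# fcdf = xp.clone()
# x = torch.tensor([0.1, 0.5, 0.9])
# sorted_interp_quad(x, xp, fpdf, fcdf)  # ~tensor([0.1, 0.5, 0.9])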
|
|