| import torch |
| import numpy as np |
|
|
|
|
def expand_t_like_x(t, x_cur):
    """Reshape a per-sample time vector so it broadcasts against ``x_cur``.

    Args:
        t: [batch_dim,], time vector
        x_cur: [batch_dim, ...], data point whose number of axes is matched

    Returns:
        ``t`` viewed as [batch_dim, 1, ..., 1], with one singleton axis for
        every non-batch axis of ``x_cur``.
    """
    # One singleton axis per trailing (non-batch) axis of x_cur.
    trailing_axes = x_cur.dim() - 1
    return t.view(t.size(0), *([1] * trailing_axes))
|
|
def get_score_from_velocity(vt, xt, t, path_type="linear"):
    """Wrapper function: transform a velocity prediction into a score.

    Args:
        vt: [batch_dim, ...] shaped tensor; velocity model output
        xt: [batch_dim, ...] shaped tensor; x_t data point
        t: [batch_dim,] time tensor
        path_type: interpolation path. "linear" uses alpha_t = 1 - t and
            sigma_t = t; "cosine" uses alpha_t = cos(pi * t / 2) and
            sigma_t = sin(pi * t / 2).

    Returns:
        The score tensor ``(alpha_t / d_alpha_t * vt - xt) / var`` with
        ``var = sigma_t**2 - alpha_t / d_alpha_t * d_sigma_t * sigma_t``.

    Raises:
        NotImplementedError: if ``path_type`` is neither "linear" nor
            "cosine".

    Note:
        There is no guard for t == 0. On the linear path ``var`` reduces to
        ``t``, and on the cosine path ``d_alpha_t`` is zero there, so both
        divide by zero at t == 0.
    """
    t = expand_t_like_x(t, xt)
    if path_type == "linear":
        alpha_t, d_alpha_t = 1 - t, torch.ones_like(xt) * -1
        sigma_t, d_sigma_t = t, torch.ones_like(xt)
    elif path_type == "cosine":
        alpha_t = torch.cos(t * np.pi / 2)
        sigma_t = torch.sin(t * np.pi / 2)
        d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
        d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
    else:
        # Name the offending value so a misconfigured path is easy to spot.
        raise NotImplementedError(
            f"Unsupported path_type {path_type!r}; expected 'linear' or 'cosine'"
        )

    # x_t itself is the "mean" term of the velocity-to-score conversion.
    mean = xt
    reverse_alpha_ratio = alpha_t / d_alpha_t
    var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
    return (reverse_alpha_ratio * vt - mean) / var
|
|
|
|
def compute_diffusion(t_cur):
    """Return the diffusion coefficient at time ``t_cur``, defined as ``2 * t_cur``."""
    coefficient = 2 * t_cur
    return coefficient
|
|
|
|
def euler_maruyama_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
):
    """Sample with Euler-Maruyama steps from t = 1 down to t = 0.

    The time grid is ``num_steps`` points from 1.0 to 0.04 followed by 0.0.
    Every interval except the last is a stochastic step (drift plus
    ``sqrt(diffusion) * noise``); the last interval (down to 0.0) applies
    the drift only, without noise.

    Args:
        model: callable invoked as ``model(x, t, y=..., cls_token=...)``
            returning a 3-tuple whose first and third items are the velocity
            predictions for ``x`` and for the class token.
        latents: initial state; its dtype is the dtype fed to ``model``.
        y: conditioning labels, one per sample.
        num_steps: number of points in the 1.0 -> 0.04 part of the grid.
        heun: accepted but not used by this function.
        cfg_scale: classifier-free guidance scale; guidance is active only
            when this is > 1.0.
        guidance_low: lower bound of the time window in which guidance is
            applied.
        guidance_high: upper bound of the time window in which guidance is
            applied.
        path_type: forwarded to ``get_score_from_velocity``.
        cls_latents: initial class-token state (required).
        args: optional object providing ``cls_cfg_scale``; when ``args`` is
            None or lacks the attribute, 0 is used and the class-token drift
            takes the conditional half only.
        return_mid_state: when True, also return an intermediate state near
            ``t_mid``.
        t_mid: time at which the intermediate state is captured.

    Returns:
        ``mean_x`` when ``return_mid_state`` is False. Otherwise the tuple
        ``(mean_x, z_mid, cls_state)`` where ``cls_state`` is the captured
        class-token mid state, or the final class-token state if nothing was
        captured. ``z_mid`` stays None when ``t_mid`` does not fall inside
        any stochastic step interval (they span 1.0 down to 0.04).

    Raises:
        ValueError: if ``cls_latents`` is None.
    """
    if cls_latents is None:
        raise ValueError("cls_latents must be provided")

    # Label 1000 fills the unconditional half of a guided batch
    # (presumably the null-class index of the model -- confirm).
    y_null = None
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)

    # Same fallback as euler_sampler: tolerate a missing args / attribute.
    cls_cfg_scale = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    _dtype = latents.dtype

    t_steps = torch.linspace(1., 0.04, num_steps, dtype=torch.float64)
    t_steps = torch.cat([t_steps, torch.tensor([0.], dtype=torch.float64)])
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid, cls_mid = None, None
    t_mid = float(t_mid)

    def _drift(x_cur, cls_x_cur, t_cur):
        """Evaluate the model at ``t_cur``; return (drift, cls drift, diffusion)."""
        use_cfg = bool(
            cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low
        )
        if use_cfg:
            # Duplicate the batch: first half conditional, second half null.
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_x_cur] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            cls_model_input = cls_x_cur
            y_cur = y

        time_input = (
            torch.ones(model_input.size(0), device=device, dtype=torch.float64)
            * t_cur
        )
        diffusion = compute_diffusion(t_cur)

        # The model runs in the caller's dtype; the integration is in float64.
        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype),
            time_input.to(dtype=_dtype),
            y=y_cur,
            cls_token=cls_model_input.to(dtype=_dtype),
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        s_cur = get_score_from_velocity(
            v_cur, model_input, time_input, path_type=path_type
        )
        d_cur = v_cur - 0.5 * diffusion * s_cur

        cls_s_cur = get_score_from_velocity(
            cls_v_cur, cls_model_input, time_input, path_type=path_type
        )
        cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

        if use_cfg:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg_scale > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg_scale * (
                    cls_d_cur_cond - cls_d_cur_uncond
                )
            else:
                cls_d_cur = cls_d_cur_cond

        return d_cur, cls_d_cur, diffusion

    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next
            tc, tn = float(t_cur), float(t_next)

            # Capture the state on whichever side of this step is nearer to
            # t_mid: before the step if t_cur is strictly closer, else after.
            mid_in_step = return_mid_state and tn <= t_mid <= tc
            if mid_in_step and z_mid is None and abs(tc - t_mid) < abs(tn - t_mid):
                z_mid = x_cur.clone()
                cls_mid = cls_x_cur.clone()

            # Noise is drawn before the model call (x first, then the class
            # token) so the random stream is consumed in a fixed order.
            noise_scale = torch.sqrt(torch.abs(dt))
            deps = torch.randn_like(x_cur) * noise_scale
            cls_deps = torch.randn_like(cls_x_cur) * noise_scale

            d_cur, cls_d_cur, diffusion = _drift(x_cur, cls_x_cur, t_cur)

            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps

            if mid_in_step and z_mid is None:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

    # Last step: drift only, no noise injected.
    # NOTE(review): this final step runs outside the no_grad block above;
    # confirm that is intended, or wrap the call site in torch.no_grad().
    t_cur, t_next = t_steps[-2], t_steps[-1]
    dt = t_next - t_cur
    x_cur = x_next
    cls_x_cur = cls_x_next

    d_cur, cls_d_cur, _ = _drift(x_cur, cls_x_cur, t_cur)

    mean_x = x_cur + dt * d_cur
    cls_mean_x = cls_x_cur + dt * cls_d_cur

    if return_mid_state:
        # The class-token slot falls back to the final state when no mid
        # state was captured; z_mid has no such fallback and may be None.
        return mean_x, z_mid, cls_mean_x if cls_mid is None else cls_mid
    return mean_x
|
|
|
|
def euler_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None
):
    """
    ODE sampler for REG: deterministic (no diffusion noise is injected).

    The ODE is integrated directly in the REG/SiT velocity parameterization:
        d/dt x_t = v_t
    so the velocity does not need to be converted to a score and then to a
    drift.

    Args:
        model: callable invoked as ``model(x, t, y=..., cls_token=...)``
            returning a 3-tuple whose first and third items are the velocity
            predictions for ``x`` and for the class token.
        latents: initial state; its dtype is the dtype fed to ``model``.
        y: conditioning labels, one per sample.
        num_steps: number of Euler steps on a uniform grid from 1.0 to 0.0.
        heun: accepted but not used by this function.
        cfg_scale: classifier-free guidance scale; active only when > 1.0.
        guidance_low: lower bound of the time window in which guidance is
            applied.
        guidance_high: upper bound of the time window in which guidance is
            applied.
        path_type: accepted but not used by this function.
        cls_latents: initial class-token state.
        args: optional object providing ``cls_cfg_scale`` (0 when absent).

    Returns:
        The final state as a float64 tensor.
    """
    if cfg_scale > 1.0:
        # Label 1000 fills the unconditional half of a guided batch.
        null_labels = torch.tensor([1000] * y.size(0), device=y.device)
    model_dtype = latents.dtype

    cls_guidance = 0 if args is None else getattr(args, "cls_cfg_scale", 0)

    time_grid = torch.linspace(
        1.0, 0.0, int(num_steps) + 1, dtype=torch.float64, device=latents.device
    )

    state = latents.to(torch.float64)
    cls_state = cls_latents.to(torch.float64)
    device = state.device

    with torch.no_grad():
        for idx in range(time_grid.numel() - 1):
            t_cur, t_next = time_grid[idx], time_grid[idx + 1]
            step = t_next - t_cur

            # Guidance applies only inside [guidance_low, guidance_high].
            guided = cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low
            if guided:
                # First half conditional, second half unconditional.
                net_x = torch.cat([state, state], dim=0)
                net_cls = torch.cat([cls_state, cls_state], dim=0)
                labels = torch.cat([y, null_labels], dim=0)
            else:
                net_x, net_cls, labels = state, cls_state, y

            t_batch = torch.ones(net_x.size(0), device=device, dtype=torch.float64) * t_cur

            velocity, _, cls_velocity = model(
                net_x.to(dtype=model_dtype),
                t_batch.to(dtype=model_dtype),
                y=labels,
                cls_token=net_cls.to(dtype=model_dtype),
            )

            # The velocity is used as the ODE drift, integrated in float64.
            drift = velocity.to(torch.float64)
            cls_drift = cls_velocity.to(torch.float64)

            if guided:
                cond, uncond = drift.chunk(2)
                drift = uncond + cfg_scale * (cond - uncond)

                cls_cond, cls_uncond = cls_drift.chunk(2)
                if cls_guidance > 0:
                    cls_drift = cls_uncond + cls_guidance * (cls_cond - cls_uncond)
                else:
                    cls_drift = cls_cond

            state = state + step * drift
            cls_state = cls_state + step * cls_drift

    return state
|
|