| import torch |
| import numpy as np |
|
|
|
|
def expand_t_like_x(t, x_cur):
    """Reshape a time vector so it broadcasts against a data batch.

    Args:
        t: [batch_dim,] time vector.
        x_cur: [batch_dim, ...] data tensor whose trailing shape is broadcast over.

    Returns:
        ``t`` viewed as [batch_dim, 1, ..., 1], with one singleton axis per
        non-batch dimension of ``x_cur``.
    """
    trailing = x_cur.dim() - 1
    return t.view(-1, *([1] * trailing))
|
|
def get_score_from_velocity(vt, xt, t, path_type="linear"):
    """Transform a velocity-model prediction into a score estimate.

    Uses the interpolant x_t = alpha_t * x0 + sigma_t * eps and its time
    derivatives to convert the predicted velocity into the score
    grad_x log p_t(x).

    Args:
        vt: [batch_dim, ...] velocity model output.
        xt: [batch_dim, ...] data point x_t.
        t: [batch_dim,] time tensor.
        path_type: "linear" (alpha=1-t, sigma=t) or "cosine"
            (alpha=cos(pi t/2), sigma=sin(pi t/2)).

    Returns:
        Score tensor with the same shape as ``vt``.

    Raises:
        NotImplementedError: for an unknown ``path_type``.
    """
    t = expand_t_like_x(t, xt)
    if path_type == "linear":
        alpha_t = 1 - t
        sigma_t = t
        d_alpha_t = -torch.ones_like(xt, device=xt.device)
        d_sigma_t = torch.ones_like(xt, device=xt.device)
    elif path_type == "cosine":
        half_pi = np.pi / 2
        alpha_t = torch.cos(half_pi * t)
        sigma_t = torch.sin(half_pi * t)
        d_alpha_t = -half_pi * sigma_t
        d_sigma_t = half_pi * alpha_t
    else:
        raise NotImplementedError

    # score = (alpha_t/d_alpha_t * v - x_t) / (sigma_t^2 - alpha_t/d_alpha_t * d_sigma_t * sigma_t)
    ratio = alpha_t / d_alpha_t
    var = sigma_t**2 - ratio * d_sigma_t * sigma_t
    return (ratio * vt - xt) / var
|
|
|
|
def compute_diffusion(t_cur):
    """Diffusion coefficient schedule g(t) = 2t used by the SDE samplers."""
    return t_cur + t_cur
|
|
|
|
def build_sampling_time_steps(
    num_steps=50,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    t_floor=0.04,
):
    """Build the t=1 -> t=0 time grid used by the SDE samplers.

    The grid keeps ``t_floor`` as the last nonzero node and then appends an
    exact 0.0 (matching the original behavior).

    - Default: uniform ``linspace(1, t_floor, num_steps)``, then append 0.
    - Segmented (``t_c``, ``num_steps_before_tc``, ``num_steps_after_tc`` all
      given): ``num_steps_before_tc`` steps from 1 to ``t_c``, then
      ``num_steps_after_tc`` steps from ``t_c`` to ``t_floor``, then append 0.

    Returns:
        1-D float64 tensor of descending time values ending in 0.0.

    Raises:
        ValueError: on non-positive step counts or an out-of-range ``t_c``.
    """
    t_floor = float(t_floor)
    zero = torch.tensor([0.0], dtype=torch.float64)

    segmented = not (
        t_c is None or num_steps_before_tc is None or num_steps_after_tc is None
    )
    if not segmented:
        ns = int(num_steps)
        if ns < 1:
            raise ValueError("num_steps must be >= 1")
        grid = torch.linspace(1.0, t_floor, ns, dtype=torch.float64)
        return torch.cat([grid, zero])

    t_c = float(t_c)
    nb = int(num_steps_before_tc)
    na = int(num_steps_after_tc)
    if nb < 1 or na < 1:
        raise ValueError("num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c")
    if not (0.0 < t_c < 1.0):
        raise ValueError("t_c must be in (0, 1)")
    if t_c <= t_floor:
        raise ValueError(f"t_c ({t_c}) must be > t_floor ({t_floor})")

    head = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64)
    tail = torch.linspace(t_c, t_floor, na + 1, dtype=torch.float64)[1:]
    return torch.cat([head, tail, zero])
|
|
|
|
| def _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc): |
| """仅在 1→t_c→0 分段网格下启用:t∈[0,t_c] 段固定使用到达 t_c 时的 cls。""" |
| return ( |
| t_c is not None |
| and num_steps_before_tc is not None |
| and num_steps_after_tc is not None |
| ) |
|
|
|
|
| def _cls_effective_and_freeze( |
| cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc |
| ): |
| """ |
| 时间从 1 减到 0:当 t_cur <= t_c 时冻结 cls(取首次进入该段时的 cls_x_cur)。 |
| 返回 (用于前向的 cls, 更新后的 cls_frozen)。 |
| """ |
| if not freeze_after_tc or t_c_v is None: |
| return cls_x_cur, cls_frozen |
| if float(t_cur) <= float(t_c_v) + 1e-9: |
| if cls_frozen is None: |
| cls_frozen = cls_x_cur.clone() |
| return cls_frozen, cls_frozen |
| return cls_x_cur, cls_frozen |
|
|
|
|
| def _build_euler_sampler_time_steps( |
| num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device |
| ): |
| """ |
| euler_sampler / REG ODE 用时间网格:默认 linspace(1,0);分段时为 1→t_c→0 直连,无 t_floor。 |
| """ |
| if t_c is None or num_steps_before_tc is None or num_steps_after_tc is None: |
| ns = int(num_steps) |
| if ns < 1: |
| raise ValueError("num_steps must be >= 1") |
| return torch.linspace(1.0, 0.0, ns + 1, dtype=torch.float64, device=device) |
| t_c = float(t_c) |
| nb = int(num_steps_before_tc) |
| na = int(num_steps_after_tc) |
| if nb < 1 or na < 1: |
| raise ValueError( |
| "num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c" |
| ) |
| if not (0.0 < t_c < 1.0): |
| raise ValueError("t_c must be in (0, 1)") |
| p1 = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64, device=device) |
| p2 = torch.linspace(t_c, 0.0, na + 1, dtype=torch.float64, device=device) |
| return torch.cat([p1, p2[1:]]) |
|
|
|
|
def euler_maruyama_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    deterministic=False,
    return_trajectory=False,
):
    """Euler-Maruyama SDE sampler (image latent and cls token jointly).

    The drift term and the velocity -> score transform match
    euler_ode_sampler / euler_sampler; with ``deterministic=True`` the
    stochastic diffusion (noise) term is disabled. The ODE samplers use
    euler_sampler's linspace(1 -> 0) / t_c segmented grid (no t_floor),
    whereas this function uses build_sampling_time_steps (with t_floor),
    consistent with the EM/SDE family.

    Args:
        model: callable returning ``(velocity, _, cls_velocity)``.
        latents: [B, ...] initial image latent at t=1.
        y: [B,] class labels; 1000 is used as the null class for CFG.
        num_steps: number of uniform steps when no t_c segmentation is given.
        heun: unused placeholder (kept for signature compatibility).
        cfg_scale: CFG scale; active when > 1 and guidance_low <= t <= guidance_high.
        path_type: interpolation path, "linear" or "cosine".
        cls_latents: initial cls token latent; must not be None (converted below).
        args: optional namespace; only ``cls_cfg_scale`` is read.
        return_mid_state / t_mid: optionally capture the state nearest ``t_mid``.
        t_c / num_steps_before_tc / num_steps_after_tc: optional segmented grid;
            when all are given, the cls token is frozen on the t <= t_c segment.
        deterministic: if True, zero out the Brownian increments (drift only).
        return_trajectory: also return the list of intermediate image states.

    Returns:
        Final image latent ``mean_x``, optionally with (z_mid, cls_mid)
        and/or the trajectory list, depending on the flags.
    """

    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)

    _dtype = latents.dtype
    # cls-channel guidance scale; 0 means "use the conditional branch as-is".
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    # 1 -> t_floor grid (optionally segmented at t_c) plus a final exact 0.
    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    # Integrate in float64 for numerical stability; cast to model dtype per call.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        # Main SDE loop over all steps except the final (t_floor -> 0) one.
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Substitute the frozen cls token once t <= t_c (segmented grid only).
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            # Capture the pre-step state if t_cur is the closest grid point to t_mid.
            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # Double the batch (cond + null) when CFG is active in this t-window.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Brownian increments (zeroed when deterministic).
            if deterministic:
                deps = torch.zeros_like(x_cur)
                cls_deps = torch.zeros_like(cls_model_input[: x_cur.size(0)])
            else:
                eps_i = torch.randn_like(x_cur).to(device)
                cls_eps_i = torch.randn_like(cls_model_input[: x_cur.size(0)]).to(device)
                deps = eps_i * torch.sqrt(torch.abs(dt))
                cls_deps = cls_eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # Reverse-SDE drift: v - 0.5 * g(t) * score.
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            # Classifier-free guidance on the doubled batch.
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond
            # Euler-Maruyama update; cls stays frozen on the t <= t_c segment.
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps

            if return_trajectory:
                traj.append(x_next.clone())

            # Otherwise t_next is the closest grid point: capture post-step state.
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # Final (t_floor -> 0) step.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
        cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)

        diffusion = compute_diffusion(t_cur)
        d_cur = v_cur - 0.5 * diffusion * s_cur
        cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        # Last step is drift-only (mean), no noise injection.
        mean_x = x_cur + dt * d_cur
        # NOTE(review): cls_mean_x is computed but not returned by any branch
        # below — apparently kept for parity with the other samplers; confirm.
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

        if return_trajectory:
            traj.append(mean_x.clone())

    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
|
|
|
|
def euler_maruyama_image_noise_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """Euler-Maruyama variant: stochastic noise on the image latent only.

    The image latent is integrated as an SDE (score-corrected drift plus a
    Brownian increment on every step); the cls/token channel is integrated
    deterministically (drift only, never a random term).

    Args:
        model: callable returning ``(velocity, _, cls_velocity)``.
        latents: [B, ...] initial image latent at t=1.
        y: [B,] class labels; 1000 is used as the null class for CFG.
        num_steps: number of uniform steps when no t_c segmentation is given.
        heun: unused placeholder (kept for signature compatibility).
        cfg_scale: CFG scale; active when > 1 and guidance_low <= t <= guidance_high.
        path_type: interpolation path, "linear" or "cosine".
        cls_latents: initial cls token latent; must not be None (converted below).
        args: optional namespace; only ``cls_cfg_scale`` is read.
        return_mid_state / t_mid: optionally capture the state nearest ``t_mid``.
        t_c / num_steps_before_tc / num_steps_after_tc: optional segmented grid;
            when all are given, the cls token is frozen on the t <= t_c segment.
        return_trajectory: also return the list of intermediate image states.

    Returns:
        Final image latent ``mean_x``, optionally with (z_mid, cls_mid)
        and/or the trajectory list, depending on the flags.
    """
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    # cls-channel guidance scale; 0 means "use the conditional branch as-is".
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    # 1 -> t_floor grid (optionally segmented at t_c) plus a final exact 0.
    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    # Integrate in float64 for numerical stability; cast to model dtype per call.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        # Main loop over all steps except the final (t_floor -> 0) one.
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Substitute the frozen cls token once t <= t_c (segmented grid only).
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            # Capture the pre-step state if t_cur is the closest grid point to t_mid.
            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # Double the batch (cond + null) when CFG is active in this t-window.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Brownian increment for the image latent only (never for cls).
            eps_i = torch.randn_like(x_cur).to(device)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # BUG FIX: this branch previously tested `add_img_noise`, a name
            # never defined in this function, raising NameError on the first
            # step. In this variant image noise is injected on every step
            # (see `deps` above and the x_next update below), so the
            # score-corrected SDE drift is always the active branch —
            # mirroring the add_img_noise=True path of
            # euler_maruyama_image_noise_before_tc_sampler.
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur

            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur

            # Classifier-free guidance on the doubled batch.
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond

            # Image: SDE update with noise; cls: deterministic drift
            # (held at the frozen snapshot on the t <= t_c segment).
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())

            # Otherwise t_next is the closest grid point: capture post-step state.
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # Final (t_floor -> 0) step: plain velocity drift, no noise and no
        # score correction (the velocity->score transform degenerates as t -> 0).
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        d_cur = v_cur
        cls_d_cur = cls_v_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        # NOTE(review): cls_mean_x is computed for parity with the other EM
        # samplers but is not returned by any branch below.
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
|
|
|
|
def euler_maruyama_image_noise_before_tc_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_cls_final=False,
    return_trajectory=False,
):
    """EM sampler variant with an image-noise cutoff at t_c:

    - the image latent receives stochastic diffusion noise while the step
      stays above t_c;
    - once the step lands at t <= t_c the image latent gets no random term
      (drift only);
    - the cls/token channel never receives a random term.

    ``return_cls_final`` additionally returns the final cls state.
    ``heun`` is an unused placeholder. ``cls_latents`` must not be None
    (it is converted unconditionally below).
    """
    if cfg_scale > 1.0:
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    # cls-channel guidance scale; 0 means "use the conditional branch as-is".
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0

    # 1 -> t_floor grid (optionally segmented at t_c) plus a final exact 0.
    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    # t_c_freeze gates cls freezing (only on the fully segmented grid);
    # t_c_v below gates the image-noise cutoff (whenever t_c is supplied).
    t_c_freeze = float(t_c) if freeze_after_tc else None
    # Integrate in float64 for numerical stability; cast to model dtype per call.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    t_c_v = None if t_c is None else float(t_c)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        # Main loop over all steps except the final (t_floor -> 0) one.
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Substitute the frozen cls token once t <= t_c (segmented grid only).
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
            )

            # Capture the pre-step state if t_cur is the closest grid point to t_mid.
            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # Double the batch (cond + null) when CFG is active in this t-window.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)

            # Image noise is switched off once this step LANDS at or below
            # t_c (gated on t_next, not t_cur).
            add_img_noise = True
            if t_c_v is not None and float(t_next) <= t_c_v:
                add_img_noise = False

            eps_i = torch.randn_like(x_cur).to(device) if add_img_noise else torch.zeros_like(x_cur)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # Noisy regime: SDE (score-corrected) drift; otherwise plain velocity.
            if add_img_noise:
                s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
                d_cur = v_cur - 0.5 * diffusion * s_cur

                cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
                cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
            else:
                d_cur = v_cur
                cls_d_cur = cls_v_cur

            # Classifier-free guidance on the doubled batch.
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond

            # Image update (deps is zero past the cutoff); cls: drift only,
            # held at the frozen snapshot on the t <= t_c segment.
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())

            # Otherwise t_next is the closest grid point: capture post-step state.
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

        # Final (t_floor -> 0) step: plain velocity drift, no noise.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next

        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
        )

        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur

        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)

        d_cur = v_cur
        cls_d_cur = cls_v_cur

        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond

        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur

    # Return shape depends on the flag combination.
    if return_trajectory and return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x, traj
    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory and return_cls_final:
        return mean_x, cls_mean_x, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    if return_cls_final:
        return mean_x, cls_mean_x
    return mean_x
|
|
|
|
def euler_ode_sampler(
    model,
    latents,
    y,
    num_steps=20,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """ODE entry point for REG.

    Decoupled from the SDE samplers: simply delegates to ``euler_sampler``
    (linspace 1 -> 0 or the t_c segmented grid, with no t_floor).
    """
    forwarded = dict(
        num_steps=num_steps,
        heun=False,
        cfg_scale=cfg_scale,
        guidance_low=guidance_low,
        guidance_high=guidance_high,
        path_type=path_type,
        cls_latents=cls_latents,
        args=args,
        return_mid_state=return_mid_state,
        t_mid=t_mid,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
        return_trajectory=return_trajectory,
    )
    return euler_sampler(model, latents, y, **forwarded)
|
|
|
|
def euler_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """Lightweight deterministic drift-only (Euler ODE) sampler.

    Signature-prefix compatible with glflow's sampler of the same name
    (model, latents, y, num_steps, heun, cfg, guidance, path_type,
    cls_latents, args).

    - Default: linspace(1, 0, num_steps + 1), no t_floor (matching the
      original standalone ODE).
    - Optional: when t_c, num_steps_before_tc and num_steps_after_tc are all
      given, the grid becomes 1 -> t_c -> 0, and — consistent with the EM
      samplers — the cls token is frozen on the t <= t_c segment.
    - Optional: return_mid_state / return_trajectory for use by train.py and
      sample_from_checkpoint.

    REG's SiT requires cls_token; ``cls_latents`` must not be None.
    ``heun`` is a placeholder and unused. ``path_type`` is accepted for
    signature compatibility but not used here (no score transform in the
    pure-velocity ODE).
    """
    if cls_latents is None:
        raise ValueError(
            "euler_sampler: 本仓库 REG SiT 需要 cls_token,请传入 cls_latents(例如高斯噪声或训练中的 cls 初值)。"
        )
    if cfg_scale > 1.0:
        y_null = torch.full((y.size(0),), 1000, device=y.device, dtype=y.dtype)
    else:
        y_null = None
    _dtype = latents.dtype
    # cls-channel guidance scale; 0 means "use the conditional branch as-is".
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0
    device = latents.device

    # linspace(1, 0) grid, or 1 -> t_c -> 0 when segmented (no t_floor).
    t_steps = _build_euler_sampler_time_steps(
        num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None

    # Integrate in float64 for numerical stability; cast to model dtype per call.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None

    with torch.no_grad():
        # Plain Euler steps over the whole grid (no special final step).
        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next

            # Substitute the frozen cls token once t <= t_c (segmented grid only).
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )

            # Capture the pre-step state if t_cur is the closest grid point to t_mid.
            tc, tn = float(t_cur), float(t_next)
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()

            # Double the batch (cond + null) when CFG is active in this t-window.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y

            time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur

            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype),
                time_input.to(dtype=_dtype),
                y_cur,
                cls_token=cls_model_input.to(dtype=_dtype),
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)

            # ODE drift is the raw velocity (no score correction, no noise).
            d_cur = v_cur
            cls_d_cur = cls_v_cur

            # Classifier-free guidance on the doubled batch.
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    cls_d_cur = cls_d_cur_cond

            # Euler update; cls stays frozen on the t <= t_c segment.
            x_next = x_cur + dt * d_cur
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + dt * cls_d_cur

            if return_trajectory:
                traj.append(x_next.clone())

            # Otherwise t_next is the closest grid point: capture post-step state.
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()

    if return_trajectory and return_mid_state:
        return x_next, z_mid, cls_mid, traj
    if return_trajectory:
        return x_next, traj
    if return_mid_state:
        return x_next, z_mid, cls_mid
    return x_next
|
|