import torch import numpy as np def expand_t_like_x(t, x_cur): """Function to reshape time t to broadcastable dimension of x Args: t: [batch_dim,], time vector x: [batch_dim,...], data point """ dims = [1] * (len(x_cur.size()) - 1) t = t.view(t.size(0), *dims) return t def get_score_from_velocity(vt, xt, t, path_type="linear"): """Wrapper function: transfrom velocity prediction model to score Args: velocity: [batch_dim, ...] shaped tensor; velocity model output x: [batch_dim, ...] shaped tensor; x_t data point t: [batch_dim,] time tensor """ t = expand_t_like_x(t, xt) if path_type == "linear": alpha_t, d_alpha_t = 1 - t, torch.ones_like(xt, device=xt.device) * -1 sigma_t, d_sigma_t = t, torch.ones_like(xt, device=xt.device) elif path_type == "cosine": alpha_t = torch.cos(t * np.pi / 2) sigma_t = torch.sin(t * np.pi / 2) d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2) d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2) else: raise NotImplementedError mean = xt reverse_alpha_ratio = alpha_t / d_alpha_t var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t score = (reverse_alpha_ratio * vt - mean) / var return score def compute_diffusion(t_cur): return 2 * t_cur def build_sampling_time_steps( num_steps=50, t_c=None, num_steps_before_tc=None, num_steps_after_tc=None, t_floor=0.04, ): """ 构造从 t=1 → t=0 的时间网格(与原先一致:最后一段到 0 前保留 t_floor,再接到 0)。 - 默认:均匀 linspace(1, t_floor, num_steps),再 append 0。 - 分段:t∈(t_c,1] 用 num_steps_before_tc 步(从 1 线性到 t_c); t∈[0,t_c] 用 num_steps_after_tc 步(从 t_c 线性到 t_floor),再 append 0。 """ t_floor = float(t_floor) if t_c is None or num_steps_before_tc is None or num_steps_after_tc is None: ns = int(num_steps) if ns < 1: raise ValueError("num_steps must be >= 1") t_steps = torch.linspace(1.0, t_floor, ns, dtype=torch.float64) return torch.cat([t_steps, torch.tensor([0.0], dtype=torch.float64)]) t_c = float(t_c) nb = int(num_steps_before_tc) na = int(num_steps_after_tc) if nb < 1 or na < 1: raise ValueError("num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c") if not (0.0 < t_c < 1.0): raise ValueError("t_c must be in (0, 1)") if t_c <= t_floor: raise ValueError(f"t_c ({t_c}) must be > t_floor ({t_floor})") p1 = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64) p2 = torch.linspace(t_c, t_floor, na + 1, dtype=torch.float64) t_steps = torch.cat([p1, p2[1:]]) return torch.cat([t_steps, torch.tensor([0.0], dtype=torch.float64)]) def _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc): """仅在 1→t_c→0 分段网格下启用:t∈[0,t_c] 段固定使用到达 t_c 时的 cls。""" return ( t_c is not None and num_steps_before_tc is not None and num_steps_after_tc is not None ) def _cls_effective_and_freeze( cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc ): """ 时间从 1 减到 0:当 t_cur <= t_c 时冻结 cls(取首次进入该段时的 cls_x_cur)。 返回 (用于前向的 cls, 更新后的 cls_frozen)。 """ if not freeze_after_tc or t_c_v is None: return cls_x_cur, cls_frozen if float(t_cur) <= float(t_c_v) + 1e-9: if cls_frozen is None: cls_frozen = cls_x_cur.clone() return cls_frozen, cls_frozen return cls_x_cur, cls_frozen def _build_euler_sampler_time_steps( num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device ): """ euler_sampler / REG ODE 用时间网格:默认 linspace(1,0);分段时为 1→t_c→0 直连,无 t_floor。 """ if t_c is None or num_steps_before_tc is None or num_steps_after_tc is None: ns = int(num_steps) if ns < 1: raise ValueError("num_steps must be >= 1") return torch.linspace(1.0, 0.0, ns + 1, dtype=torch.float64, device=device) t_c = float(t_c) nb = int(num_steps_before_tc) na = int(num_steps_after_tc) if nb < 1 or na < 1: raise ValueError( "num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c" ) if not (0.0 < t_c < 1.0): raise ValueError("t_c must be in (0, 1)") p1 = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64, device=device) p2 = torch.linspace(t_c, 0.0, na + 1, dtype=torch.float64, device=device) return torch.cat([p1, p2[1:]]) def euler_maruyama_sampler( model, latents, y, num_steps=20, heun=False, # not used, just for compatability cfg_scale=1.0, guidance_low=0.0, guidance_high=1.0, path_type="linear", cls_latents=None, args=None, return_mid_state=False, t_mid=0.5, t_c=None, num_steps_before_tc=None, num_steps_after_tc=None, deterministic=False, return_trajectory=False, ): """ Euler–Maruyama:漂移项与 score/velocity 变换与 euler_ode_sampler(euler_sampler)一致; deterministic=True 时关闭扩散噪声项。ODE 使用 euler_sampler 的 linspace(1→0) / t_c 分段网格(无 t_floor), 本函数仍用 build_sampling_time_steps(含 t_floor),与 EM/SDE 对齐。 """ # setup conditioning if cfg_scale > 1.0: y_null = torch.tensor([1000] * y.size(0), device=y.device) #[1000, 1000] _dtype = latents.dtype cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0 t_steps = build_sampling_time_steps( num_steps=num_steps, t_c=t_c, num_steps_before_tc=num_steps_before_tc, num_steps_after_tc=num_steps_after_tc, ) freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc) t_c_v = float(t_c) if freeze_after_tc else None x_next = latents.to(torch.float64) cls_x_next = cls_latents.to(torch.float64) device = x_next.device z_mid = cls_mid = None t_mid = float(t_mid) cls_frozen = None traj = [x_next.clone()] if return_trajectory else None with torch.no_grad(): for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])): dt = t_next - t_cur x_cur = x_next cls_x_cur = cls_x_next cls_model_input, cls_frozen = _cls_effective_and_freeze( cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc ) tc, tn = float(t_cur), float(t_next) if return_mid_state and z_mid is None and tn <= t_mid <= tc: if abs(tc - t_mid) < abs(tn - t_mid): z_mid = x_cur.clone() cls_mid = cls_model_input.clone() if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_cur] * 2, dim=0) cls_model_input = torch.cat([cls_model_input] * 2, dim=0) y_cur = torch.cat([y, y_null], dim=0) else: model_input = x_cur y_cur = y kwargs = dict(y=y_cur) time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur diffusion = compute_diffusion(t_cur) if deterministic: deps = torch.zeros_like(x_cur) cls_deps = torch.zeros_like(cls_model_input[: x_cur.size(0)]) else: eps_i = torch.randn_like(x_cur).to(device) cls_eps_i = torch.randn_like(cls_model_input[: x_cur.size(0)]).to(device) deps = eps_i * torch.sqrt(torch.abs(dt)) cls_deps = cls_eps_i * torch.sqrt(torch.abs(dt)) # compute drift v_cur, _, cls_v_cur = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype) ) v_cur = v_cur.to(torch.float64) cls_v_cur = cls_v_cur.to(torch.float64) s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type) d_cur = v_cur - 0.5 * diffusion * s_cur cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type) cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low: d_cur_cond, d_cur_uncond = d_cur.chunk(2) d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2) if cls_cfg > 0: cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond) else: cls_d_cur = cls_d_cur_cond x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9: cls_x_next = cls_frozen else: cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps if return_trajectory: traj.append(x_next.clone()) if return_mid_state and z_mid is None and tn <= t_mid <= tc: z_mid = x_next.clone() cls_mid = cls_x_next.clone() # last step t_cur, t_next = t_steps[-2], t_steps[-1] dt = t_next - t_cur x_cur = x_next cls_x_cur = cls_x_next cls_model_input, cls_frozen = _cls_effective_and_freeze( cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc ) if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_cur] * 2, dim=0) cls_model_input = torch.cat([cls_model_input] * 2, dim=0) y_cur = torch.cat([y, y_null], dim=0) else: model_input = x_cur y_cur = y kwargs = dict(y=y_cur) time_input = torch.ones(model_input.size(0)).to( device=device, dtype=torch.float64 ) * t_cur # compute drift v_cur, _, cls_v_cur = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype) ) v_cur = v_cur.to(torch.float64) cls_v_cur = cls_v_cur.to(torch.float64) s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type) cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type) diffusion = compute_diffusion(t_cur) d_cur = v_cur - 0.5 * diffusion * s_cur cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur # d_cur [b, 4, 32 ,32] if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low: d_cur_cond, d_cur_uncond = d_cur.chunk(2) d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2) if cls_cfg > 0: cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond) else: cls_d_cur = cls_d_cur_cond mean_x = x_cur + dt * d_cur if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9: cls_mean_x = cls_frozen else: cls_mean_x = cls_x_cur + dt * cls_d_cur if return_trajectory: traj.append(mean_x.clone()) if return_trajectory and return_mid_state: return mean_x, z_mid, cls_mid, traj if return_trajectory: return mean_x, traj if return_mid_state: return mean_x, z_mid, cls_mid return mean_x def euler_maruyama_image_noise_sampler( model, latents, y, num_steps=20, heun=False, # not used, just for compatability cfg_scale=1.0, guidance_low=0.0, guidance_high=1.0, path_type="linear", cls_latents=None, args=None, return_mid_state=False, t_mid=0.5, t_c=None, num_steps_before_tc=None, num_steps_after_tc=None, return_trajectory=False, ): """ EM 采样变体:仅图像 latent 引入随机扩散噪声,cls/token 通道不引入随机项(deterministic)。 """ if cfg_scale > 1.0: y_null = torch.tensor([1000] * y.size(0), device=y.device) _dtype = latents.dtype cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0 t_steps = build_sampling_time_steps( num_steps=num_steps, t_c=t_c, num_steps_before_tc=num_steps_before_tc, num_steps_after_tc=num_steps_after_tc, ) freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc) t_c_v = float(t_c) if freeze_after_tc else None x_next = latents.to(torch.float64) cls_x_next = cls_latents.to(torch.float64) device = x_next.device z_mid = cls_mid = None t_mid = float(t_mid) cls_frozen = None traj = [x_next.clone()] if return_trajectory else None with torch.no_grad(): for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]): dt = t_next - t_cur x_cur = x_next cls_x_cur = cls_x_next cls_model_input, cls_frozen = _cls_effective_and_freeze( cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc ) tc, tn = float(t_cur), float(t_next) if return_mid_state and z_mid is None and tn <= t_mid <= tc: if abs(tc - t_mid) < abs(tn - t_mid): z_mid = x_cur.clone() cls_mid = cls_model_input.clone() if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_cur] * 2, dim=0) cls_model_input = torch.cat([cls_model_input] * 2, dim=0) y_cur = torch.cat([y, y_null], dim=0) else: model_input = x_cur y_cur = y kwargs = dict(y=y_cur) time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur diffusion = compute_diffusion(t_cur) eps_i = torch.randn_like(x_cur).to(device) deps = eps_i * torch.sqrt(torch.abs(dt)) v_cur, _, cls_v_cur = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype) ) v_cur = v_cur.to(torch.float64) cls_v_cur = cls_v_cur.to(torch.float64) if add_img_noise: s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type) d_cur = v_cur - 0.5 * diffusion * s_cur cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type) cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur else: # t<=t_c 去随机阶段:与当前 ODE 逻辑一致,直接 d=v。 d_cur = v_cur cls_d_cur = cls_v_cur if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low: d_cur_cond, d_cur_uncond = d_cur.chunk(2) d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2) if cls_cfg > 0: cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond) else: cls_d_cur = cls_d_cur_cond # 图像 latent 有随机扩散噪声;cls/token 仅走漂移(不加随机项) x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9: cls_x_next = cls_frozen else: cls_x_next = cls_x_cur + cls_d_cur * dt if return_trajectory: traj.append(x_next.clone()) if return_mid_state and z_mid is None and tn <= t_mid <= tc: z_mid = x_next.clone() cls_mid = cls_x_next.clone() t_cur, t_next = t_steps[-2], t_steps[-1] dt = t_next - t_cur x_cur = x_next cls_x_cur = cls_x_next cls_model_input, cls_frozen = _cls_effective_and_freeze( cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc ) if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_cur] * 2, dim=0) cls_model_input = torch.cat([cls_model_input] * 2, dim=0) y_cur = torch.cat([y, y_null], dim=0) else: model_input = x_cur y_cur = y kwargs = dict(y=y_cur) time_input = torch.ones(model_input.size(0)).to( device=device, dtype=torch.float64 ) * t_cur v_cur, _, cls_v_cur = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype) ) v_cur = v_cur.to(torch.float64) cls_v_cur = cls_v_cur.to(torch.float64) # 最后一步本身无随机项,也与 ODE 对齐使用 velocity 漂移。 d_cur = v_cur cls_d_cur = cls_v_cur if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low: d_cur_cond, d_cur_uncond = d_cur.chunk(2) d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2) if cls_cfg > 0: cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond) else: cls_d_cur = cls_d_cur_cond mean_x = x_cur + dt * d_cur if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9: cls_mean_x = cls_frozen else: cls_mean_x = cls_x_cur + dt * cls_d_cur if return_trajectory and return_mid_state: return mean_x, z_mid, cls_mid, traj if return_trajectory: return mean_x, traj if return_mid_state: return mean_x, z_mid, cls_mid return mean_x def euler_maruyama_image_noise_before_tc_sampler( model, latents, y, num_steps=20, heun=False, # not used, just for compatability cfg_scale=1.0, guidance_low=0.0, guidance_high=1.0, path_type="linear", cls_latents=None, args=None, return_mid_state=False, t_mid=0.5, t_c=None, num_steps_before_tc=None, num_steps_after_tc=None, return_cls_final=False, return_trajectory=False, ): """ EM 采样变体: - 图像 latent 在 t > t_c 区间引入随机扩散噪声; - 图像 latent 在 t <= t_c 区间不引入随机项(仅漂移); - cls/token 通道全程不引入随机项。 """ if cfg_scale > 1.0: y_null = torch.tensor([1000] * y.size(0), device=y.device) _dtype = latents.dtype cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0 t_steps = build_sampling_time_steps( num_steps=num_steps, t_c=t_c, num_steps_before_tc=num_steps_before_tc, num_steps_after_tc=num_steps_after_tc, ) freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc) t_c_freeze = float(t_c) if freeze_after_tc else None x_next = latents.to(torch.float64) cls_x_next = cls_latents.to(torch.float64) device = x_next.device z_mid = cls_mid = None t_mid = float(t_mid) t_c_v = None if t_c is None else float(t_c) cls_frozen = None traj = [x_next.clone()] if return_trajectory else None with torch.no_grad(): for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]): dt = t_next - t_cur x_cur = x_next cls_x_cur = cls_x_next cls_model_input, cls_frozen = _cls_effective_and_freeze( cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc ) tc, tn = float(t_cur), float(t_next) if return_mid_state and z_mid is None and tn <= t_mid <= tc: if abs(tc - t_mid) < abs(tn - t_mid): z_mid = x_cur.clone() cls_mid = cls_model_input.clone() if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_cur] * 2, dim=0) cls_model_input = torch.cat([cls_model_input] * 2, dim=0) y_cur = torch.cat([y, y_null], dim=0) else: model_input = x_cur y_cur = y kwargs = dict(y=y_cur) time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur diffusion = compute_diffusion(t_cur) # 跨过/进入 t_c 后关闭图像随机性;t>t_c 区间保留图像噪声 add_img_noise = True if t_c_v is not None and float(t_next) <= t_c_v: add_img_noise = False eps_i = torch.randn_like(x_cur).to(device) if add_img_noise else torch.zeros_like(x_cur) deps = eps_i * torch.sqrt(torch.abs(dt)) v_cur, _, cls_v_cur = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype) ) v_cur = v_cur.to(torch.float64) cls_v_cur = cls_v_cur.to(torch.float64) if add_img_noise: s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type) d_cur = v_cur - 0.5 * diffusion * s_cur cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type) cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur else: # t<=t_c 去随机段:使用显式欧拉 + velocity 漂移(不使用修正漂移项) d_cur = v_cur cls_d_cur = cls_v_cur if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low: d_cur_cond, d_cur_uncond = d_cur.chunk(2) d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2) if cls_cfg > 0: cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond) else: cls_d_cur = cls_d_cur_cond x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9: cls_x_next = cls_frozen else: cls_x_next = cls_x_cur + cls_d_cur * dt if return_trajectory: traj.append(x_next.clone()) if return_mid_state and z_mid is None and tn <= t_mid <= tc: z_mid = x_next.clone() cls_mid = cls_x_next.clone() t_cur, t_next = t_steps[-2], t_steps[-1] dt = t_next - t_cur x_cur = x_next cls_x_cur = cls_x_next cls_model_input, cls_frozen = _cls_effective_and_freeze( cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc ) if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_cur] * 2, dim=0) cls_model_input = torch.cat([cls_model_input] * 2, dim=0) y_cur = torch.cat([y, y_null], dim=0) else: model_input = x_cur y_cur = y kwargs = dict(y=y_cur) time_input = torch.ones(model_input.size(0)).to( device=device, dtype=torch.float64 ) * t_cur v_cur, _, cls_v_cur = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype) ) v_cur = v_cur.to(torch.float64) cls_v_cur = cls_v_cur.to(torch.float64) # 最后一步无随机项,保持与 ODE 一致使用 d=v。 d_cur = v_cur cls_d_cur = cls_v_cur if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low: d_cur_cond, d_cur_uncond = d_cur.chunk(2) d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2) if cls_cfg > 0: cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond) else: cls_d_cur = cls_d_cur_cond mean_x = x_cur + dt * d_cur if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9: cls_mean_x = cls_frozen else: cls_mean_x = cls_x_cur + dt * cls_d_cur if return_trajectory and return_mid_state and return_cls_final: return mean_x, z_mid, cls_mid, cls_mean_x, traj if return_trajectory and return_mid_state: return mean_x, z_mid, cls_mid, traj if return_trajectory and return_cls_final: return mean_x, cls_mean_x, traj if return_trajectory: return mean_x, traj if return_mid_state and return_cls_final: return mean_x, z_mid, cls_mid, cls_mean_x if return_mid_state: return mean_x, z_mid, cls_mid if return_cls_final: return mean_x, cls_mean_x return mean_x def euler_ode_sampler( model, latents, y, num_steps=20, cfg_scale=1.0, guidance_low=0.0, guidance_high=1.0, path_type="linear", cls_latents=None, args=None, return_mid_state=False, t_mid=0.5, t_c=None, num_steps_before_tc=None, num_steps_after_tc=None, return_trajectory=False, ): """ REG 的 ODE 入口:与 SDE 采样器解耦,直接委托 euler_sampler(linspace 1→0 或 t_c 分段,无 t_floor)。 """ return euler_sampler( model, latents, y, num_steps=num_steps, heun=False, cfg_scale=cfg_scale, guidance_low=guidance_low, guidance_high=guidance_high, path_type=path_type, cls_latents=cls_latents, args=args, return_mid_state=return_mid_state, t_mid=t_mid, t_c=t_c, num_steps_before_tc=num_steps_before_tc, num_steps_after_tc=num_steps_after_tc, return_trajectory=return_trajectory, ) def euler_sampler( model, latents, y, num_steps=20, heun=False, cfg_scale=1.0, guidance_low=0.0, guidance_high=1.0, path_type="linear", cls_latents=None, args=None, return_mid_state=False, t_mid=0.5, t_c=None, num_steps_before_tc=None, num_steps_after_tc=None, return_trajectory=False, ): """ 轻量确定性漂移采样(与 glflow 同名同参的前缀兼容:model, latents, y, num_steps, heun, cfg, guidance, path_type, cls_latents, args)。 - 默认:linspace(1, 0, num_steps+1),无 t_floor(与原先独立 ODE 一致)。 - 可选:同时传入 t_c、num_steps_before_tc、num_steps_after_tc 时,网格为 1→t_c→0;并与 EM 一致在 t≤t_c 段冻结 cls。 - 可选:return_mid_state / return_trajectory 供 train.py 与 sample_from_checkpoint 使用。 REG 的 SiT 需要 cls_token;cls_latents 不可为 None。heun 占位未使用。 """ if cls_latents is None: raise ValueError( "euler_sampler: 本仓库 REG SiT 需要 cls_token,请传入 cls_latents(例如高斯噪声或训练中的 cls 初值)。" ) if cfg_scale > 1.0: y_null = torch.full((y.size(0),), 1000, device=y.device, dtype=y.dtype) else: y_null = None _dtype = latents.dtype cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0 device = latents.device t_steps = _build_euler_sampler_time_steps( num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device ) freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc) t_c_v = float(t_c) if freeze_after_tc else None x_next = latents.to(torch.float64) cls_x_next = cls_latents.to(torch.float64) z_mid = cls_mid = None t_mid = float(t_mid) cls_frozen = None traj = [x_next.clone()] if return_trajectory else None with torch.no_grad(): for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]): dt = t_next - t_cur x_cur = x_next cls_x_cur = cls_x_next cls_model_input, cls_frozen = _cls_effective_and_freeze( cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc ) tc, tn = float(t_cur), float(t_next) if return_mid_state and z_mid is None and tn <= t_mid <= tc: if abs(tc - t_mid) < abs(tn - t_mid): z_mid = x_cur.clone() cls_mid = cls_model_input.clone() if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: model_input = torch.cat([x_cur] * 2, dim=0) cls_model_input = torch.cat([cls_model_input] * 2, dim=0) y_cur = torch.cat([y, y_null], dim=0) else: model_input = x_cur y_cur = y time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur v_cur, _, cls_v_cur = model( model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), y_cur, cls_token=cls_model_input.to(dtype=_dtype), ) v_cur = v_cur.to(torch.float64) cls_v_cur = cls_v_cur.to(torch.float64) # ODE: follow velocity parameterization directly (d/dt x_t = v_t). # This aligns with velocity training target and avoids extra v->score->drift conversion. d_cur = v_cur cls_d_cur = cls_v_cur if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low: d_cur_cond, d_cur_uncond = d_cur.chunk(2) d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond) cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2) if cls_cfg > 0: cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond) else: cls_d_cur = cls_d_cur_cond x_next = x_cur + dt * d_cur if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9: cls_x_next = cls_frozen else: cls_x_next = cls_x_cur + dt * cls_d_cur if return_trajectory: traj.append(x_next.clone()) if return_mid_state and z_mid is None and tn <= t_mid <= tc: z_mid = x_next.clone() cls_mid = cls_x_next.clone() if return_trajectory and return_mid_state: return x_next, z_mid, cls_mid, traj if return_trajectory: return x_next, traj if return_mid_state: return x_next, z_mid, cls_mid return x_next