# jsflow / REG / samplers.py
# (Hugging Face upload-page header preserved as comments so the module parses:
#  uploader "xiangzai", "Add files using upload-large-folder tool",
#  commit b65e56d verified.)
import torch
import numpy as np
def expand_t_like_x(t, x_cur):
    """Reshape a 1-D time vector so it broadcasts against a data batch.

    Args:
        t: [batch,] time vector.
        x_cur: [batch, ...] data tensor whose rank is mirrored.

    Returns:
        t viewed as [batch, 1, ..., 1] with the same number of dims as x_cur.
    """
    trailing = x_cur.dim() - 1
    return t.view(t.size(0), *([1] * trailing))
def get_score_from_velocity(vt, xt, t, path_type="linear"):
    """Convert a velocity-model prediction into the corresponding score.

    Args:
        vt: [batch, ...] velocity model output.
        xt: [batch, ...] data point x_t.
        t: [batch,] time tensor.
        path_type: interpolant schedule, "linear" or "cosine".

    Returns:
        Score tensor with the same shape as xt.

    Raises:
        NotImplementedError: for an unknown path_type.
    """
    # Broadcast t to xt's rank (inline of expand_t_like_x).
    t = t.view(t.size(0), *([1] * (xt.dim() - 1)))
    if path_type == "linear":
        alpha_t = 1 - t
        d_alpha_t = -torch.ones_like(xt)
        sigma_t = t
        d_sigma_t = torch.ones_like(xt)
    elif path_type == "cosine":
        half_pi_t = t * np.pi / 2
        alpha_t = torch.cos(half_pi_t)
        sigma_t = torch.sin(half_pi_t)
        d_alpha_t = -np.pi / 2 * torch.sin(half_pi_t)
        d_sigma_t = np.pi / 2 * torch.cos(half_pi_t)
    else:
        raise NotImplementedError
    # score = (ratio * v - x) / var with ratio = alpha/alpha', the standard
    # velocity->score change of parameterization for interpolant paths.
    ratio = alpha_t / d_alpha_t
    var = sigma_t ** 2 - ratio * d_sigma_t * sigma_t
    return (ratio * vt - xt) / var
def compute_diffusion(t_cur):
    """Linear diffusion-coefficient schedule: g(t) = 2 * t."""
    return t_cur * 2
def build_sampling_time_steps(
    num_steps=50,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    t_floor=0.04,
):
    """Build the t=1 -> t=0 sampling grid (keeps t_floor just before 0).

    - Default: uniform linspace(1, t_floor, num_steps), then append 0.
    - Segmented: t in (t_c, 1] uses num_steps_before_tc steps (linear
      1 -> t_c); t in [0, t_c] uses num_steps_after_tc steps (linear
      t_c -> t_floor), then append 0.
    """
    t_floor = float(t_floor)
    zero = torch.tensor([0.0], dtype=torch.float64)
    segmented = not (
        t_c is None or num_steps_before_tc is None or num_steps_after_tc is None
    )
    if not segmented:
        ns = int(num_steps)
        if ns < 1:
            raise ValueError("num_steps must be >= 1")
        grid = torch.linspace(1.0, t_floor, ns, dtype=torch.float64)
        return torch.cat([grid, zero])
    t_c = float(t_c)
    nb, na = int(num_steps_before_tc), int(num_steps_after_tc)
    if nb < 1 or na < 1:
        raise ValueError("num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c")
    if not (0.0 < t_c < 1.0):
        raise ValueError("t_c must be in (0, 1)")
    if t_c <= t_floor:
        raise ValueError(f"t_c ({t_c}) must be > t_floor ({t_floor})")
    head = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64)
    # Drop the duplicated t_c at the start of the second segment.
    tail = torch.linspace(t_c, t_floor, na + 1, dtype=torch.float64)[1:]
    return torch.cat([head, tail, zero])
def _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc):
"""仅在 1→t_c→0 分段网格下启用:t∈[0,t_c] 段固定使用到达 t_c 时的 cls。"""
return (
t_c is not None
and num_steps_before_tc is not None
and num_steps_after_tc is not None
)
def _cls_effective_and_freeze(
cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
):
"""
时间从 1 减到 0:当 t_cur <= t_c 时冻结 cls(取首次进入该段时的 cls_x_cur)。
返回 (用于前向的 cls, 更新后的 cls_frozen)。
"""
if not freeze_after_tc or t_c_v is None:
return cls_x_cur, cls_frozen
if float(t_cur) <= float(t_c_v) + 1e-9:
if cls_frozen is None:
cls_frozen = cls_x_cur.clone()
return cls_frozen, cls_frozen
return cls_x_cur, cls_frozen
def _build_euler_sampler_time_steps(
num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
):
"""
euler_sampler / REG ODE 用时间网格:默认 linspace(1,0);分段时为 1→t_c→0 直连,无 t_floor。
"""
if t_c is None or num_steps_before_tc is None or num_steps_after_tc is None:
ns = int(num_steps)
if ns < 1:
raise ValueError("num_steps must be >= 1")
return torch.linspace(1.0, 0.0, ns + 1, dtype=torch.float64, device=device)
t_c = float(t_c)
nb = int(num_steps_before_tc)
na = int(num_steps_after_tc)
if nb < 1 or na < 1:
raise ValueError(
"num_steps_before_tc and num_steps_after_tc must be >= 1 when using t_c"
)
if not (0.0 < t_c < 1.0):
raise ValueError("t_c must be in (0, 1)")
p1 = torch.linspace(1.0, t_c, nb + 1, dtype=torch.float64, device=device)
p2 = torch.linspace(t_c, 0.0, na + 1, dtype=torch.float64, device=device)
return torch.cat([p1, p2[1:]])
def euler_maruyama_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False, # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    deterministic=False,
    return_trajectory=False,
):
    """Euler-Maruyama SDE sampler for REG (image latent + cls token).

    The drift term and the velocity->score conversion match
    euler_ode_sampler / euler_sampler; deterministic=True disables the
    stochastic diffusion-noise term. Note the ODE path uses euler_sampler's
    linspace(1->0) / t_c segmented grid (no t_floor), whereas this function
    uses build_sampling_time_steps (with t_floor), aligned with the EM/SDE
    variants.

    Args:
        model: callable(x, t, y=..., cls_token=...) -> (v, _, cls_v).
        latents: [B, ...] initial image latent at t=1.
        y: [B,] class labels.
        num_steps: steps for the uniform grid.
        heun: unused, kept for signature compatibility.
        cfg_scale / guidance_low / guidance_high: classifier-free-guidance
            scale and the time window [guidance_low, guidance_high] where it
            applies.
        path_type: interpolant type passed to get_score_from_velocity.
        cls_latents: [B, ...] initial cls-token latent (must not be None).
        args: optional namespace; args.cls_cfg_scale > 0 enables separate
            CFG for the cls channel.
        return_mid_state: also return the state closest to t_mid.
        t_c / num_steps_before_tc / num_steps_after_tc: segmented grid; when
            all three are set, the cls token is frozen for t <= t_c.
        deterministic: drop the random noise increments (mean update only).
        return_trajectory: also return the list of intermediate x states.

    Returns:
        mean_x, optionally with (z_mid, cls_mid) and/or the trajectory,
        depending on the return_* flags.
    """
    # setup conditioning
    if cfg_scale > 1.0:
        # Null label for the unconditional CFG branch — presumably the
        # model's null-class id (TODO confirm against the model's embedder).
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
        # e.g. [1000, 1000]
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0
    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    # Integrate in float64 for numerical stability; cast back for the model.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None
    with torch.no_grad():
        # Interior steps: every grid interval except the final one
        # (the final, noise-free step is handled after the loop).
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )
            tc, tn = float(t_cur), float(t_next)
            # Capture the mid state when t_mid falls in this interval and
            # t_cur is the closer endpoint (otherwise captured post-update).
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                # Duplicate the batch for conditional + unconditional passes.
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y
            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)
            if deterministic:
                # No stochastic term: zero increments for both channels.
                deps = torch.zeros_like(x_cur)
                cls_deps = torch.zeros_like(cls_model_input[: x_cur.size(0)])
            else:
                # Brownian increments scaled by sqrt(|dt|); the [:B] slice
                # keeps the noise single-batch even when CFG doubled inputs.
                eps_i = torch.randn_like(x_cur).to(device)
                cls_eps_i = torch.randn_like(cls_model_input[: x_cur.size(0)]).to(device)
                deps = eps_i * torch.sqrt(torch.abs(dt))
                cls_deps = cls_eps_i * torch.sqrt(torch.abs(dt))
            # compute drift
            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)
            # SDE drift d = v - 0.5 * g(t) * score, for both channels.
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur
            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep the conditional cls drift only.
                    cls_d_cur = cls_d_cur_cond
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                # Frozen segment: cls stays at its value from t = t_c.
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt + torch.sqrt(diffusion) * cls_deps
            if return_trajectory:
                traj.append(x_next.clone())
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()
        # last step
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next
        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )
        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur
        # compute drift
        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)
        s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
        cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
        diffusion = compute_diffusion(t_cur)
        d_cur = v_cur - 0.5 * diffusion * s_cur
        # Original author's shape note: d_cur e.g. [b, 4, 32, 32] — depends
        # on the latent shape; TODO confirm against the caller.
        cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond
        # Final step: deterministic mean update, no noise term.
        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            # NOTE(review): cls_mean_x is computed but never returned here.
            cls_mean_x = cls_x_cur + dt * cls_d_cur
        if return_trajectory:
            traj.append(mean_x.clone())
    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
def euler_maruyama_image_noise_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """EM sampler variant: only the image latent receives stochastic diffusion
    noise; the cls/token channel is integrated deterministically (drift only).

    Bug fix: the original loop body branched on ``add_img_noise``, a name that
    is never defined in this function (it belongs to the ``_before_tc``
    variant), raising NameError on the first step. Per this variant's
    contract the image latent is stochastic on every interior step, so the
    score-corrected SDE drift (d = v - 0.5 * g(t) * score) is now used
    unconditionally there; the final step stays noise-free with d = v,
    matching the ODE samplers.

    Args:
        model: callable(x, t, y=..., cls_token=...) -> (v, _, cls_v).
        latents: [B, ...] initial image latent at t=1.
        y: [B,] class labels.
        num_steps: steps for the uniform grid (with t_floor, then 0).
        heun: unused, kept for signature compatibility.
        cfg_scale / guidance_low / guidance_high: CFG scale and the time
            window where it applies.
        path_type: interpolant type for get_score_from_velocity.
        cls_latents: [B, ...] initial cls-token latent (must not be None).
        args: optional namespace; args.cls_cfg_scale > 0 enables cls CFG.
        return_mid_state: also return the state closest to t_mid.
        t_c / num_steps_before_tc / num_steps_after_tc: segmented grid; when
            all three are set, the cls token is frozen for t <= t_c.
        return_trajectory: also return the list of intermediate x states.

    Returns:
        mean_x, optionally with (z_mid, cls_mid) and/or the trajectory.
    """
    if cfg_scale > 1.0:
        # Null label for the unconditional CFG branch.
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0
    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    # Integrate in float64; cast back to the model's dtype per forward pass.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None
    with torch.no_grad():
        # Interior steps; the final (noise-free) step is handled below.
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )
            tc, tn = float(t_cur), float(t_next)
            # Capture the mid state when t_cur is the closer endpoint
            # (otherwise captured after the update at the bottom of the loop).
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                # Duplicate batch for conditional + unconditional passes.
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y
            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)
            # Brownian increment for the image latent only.
            eps_i = torch.randn_like(x_cur).to(device)
            deps = eps_i * torch.sqrt(torch.abs(dt))
            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)
            # Score-corrected SDE drift on every interior step (fix: the
            # original gated this on the undefined `add_img_noise`).
            s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
            d_cur = v_cur - 0.5 * diffusion * s_cur
            cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
            cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep the conditional cls drift only.
                    cls_d_cur = cls_d_cur_cond
            # Image latent: stochastic update; cls/token: drift only.
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()
        # Final step: no noise term; use d = v, consistent with the ODE.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next
        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
        )
        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur
        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)
        d_cur = v_cur
        cls_d_cur = cls_v_cur
        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond
        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            # NOTE(review): cls_mean_x is computed but never returned here.
            cls_mean_x = cls_x_cur + dt * cls_d_cur
    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    return mean_x
def euler_maruyama_image_noise_before_tc_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False, # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_cls_final=False,
    return_trajectory=False,
):
    """EM sampler variant:
    - the image latent receives stochastic diffusion noise only for t > t_c;
    - for t <= t_c the image latent is updated without a stochastic term
      (drift only);
    - the cls/token channel never receives a stochastic term.

    Args:
        model: callable(x, t, y=..., cls_token=...) -> (v, _, cls_v).
        latents / y / cls_latents: image latent, labels and cls latent
            (same conventions as euler_maruyama_sampler).
        num_steps: steps for the uniform grid (with t_floor, then 0).
        heun: unused, signature compatibility only.
        cfg_scale / guidance_low / guidance_high: CFG scale and time window.
        path_type: interpolant type for get_score_from_velocity.
        args: optional namespace; args.cls_cfg_scale > 0 enables cls CFG.
        return_mid_state: also return the state closest to t_mid.
        t_c / num_steps_before_tc / num_steps_after_tc: segmented grid; the
            cls token is frozen for t <= t_c only when all three are set,
            while image-noise gating uses t_c alone (see t_c_v below).
        return_cls_final: additionally return the final cls state.
        return_trajectory: also return the list of intermediate x states.

    Returns:
        mean_x, optionally combined with (z_mid, cls_mid), cls_mean_x and/or
        the trajectory, depending on the return_* flags.
    """
    if cfg_scale > 1.0:
        # Null label for the unconditional CFG branch.
        y_null = torch.tensor([1000] * y.size(0), device=y.device)
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0
    t_steps = build_sampling_time_steps(
        num_steps=num_steps,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    # t_c_freeze gates cls freezing (requires the full segmented grid);
    # t_c_v below gates image noise and needs only t_c itself.
    t_c_freeze = float(t_c) if freeze_after_tc else None
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    device = x_next.device
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    t_c_v = None if t_c is None else float(t_c)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None
    with torch.no_grad():
        # Interior steps; the final (noise-free) step is handled below.
        for t_cur, t_next in zip(t_steps[:-2], t_steps[1:-1]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
            )
            tc, tn = float(t_cur), float(t_next)
            # Capture the mid state when t_cur is the closer endpoint
            # (otherwise captured after the update below).
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                # Duplicate batch for conditional + unconditional passes.
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y
            kwargs = dict(y=y_cur)
            time_input = torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
            diffusion = compute_diffusion(t_cur)
            # Turn off image stochasticity once the step crosses into
            # t <= t_c; keep image noise while t > t_c.
            add_img_noise = True
            if t_c_v is not None and float(t_next) <= t_c_v:
                add_img_noise = False
            eps_i = torch.randn_like(x_cur).to(device) if add_img_noise else torch.zeros_like(x_cur)
            deps = eps_i * torch.sqrt(torch.abs(dt))
            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)
            if add_img_noise:
                # Stochastic segment: score-corrected SDE drift.
                s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
                d_cur = v_cur - 0.5 * diffusion * s_cur
                cls_s_cur = get_score_from_velocity(cls_v_cur, cls_model_input, time_input, path_type=path_type)
                cls_d_cur = cls_v_cur - 0.5 * diffusion * cls_s_cur
            else:
                # De-randomized segment (t <= t_c): explicit Euler with the
                # plain velocity drift (no score correction term).
                d_cur = v_cur
                cls_d_cur = cls_v_cur
            if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep the conditional cls drift only.
                    cls_d_cur = cls_d_cur_cond
            # Image latent may carry diffusion noise; cls/token is drift-only.
            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps
            if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + cls_d_cur * dt
            if return_trajectory:
                traj.append(x_next.clone())
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()
        # Final step: no noise term; d = v to stay consistent with the ODE.
        t_cur, t_next = t_steps[-2], t_steps[-1]
        dt = t_next - t_cur
        x_cur = x_next
        cls_x_cur = cls_x_next
        cls_model_input, cls_frozen = _cls_effective_and_freeze(
            cls_x_cur, cls_frozen, t_cur, t_c_freeze, freeze_after_tc
        )
        if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
            model_input = torch.cat([x_cur] * 2, dim=0)
            cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
            y_cur = torch.cat([y, y_null], dim=0)
        else:
            model_input = x_cur
            y_cur = y
        kwargs = dict(y=y_cur)
        time_input = torch.ones(model_input.size(0)).to(
            device=device, dtype=torch.float64
        ) * t_cur
        v_cur, _, cls_v_cur = model(
            model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs, cls_token=cls_model_input.to(dtype=_dtype)
        )
        v_cur = v_cur.to(torch.float64)
        cls_v_cur = cls_v_cur.to(torch.float64)
        d_cur = v_cur
        cls_d_cur = cls_v_cur
        if cfg_scale > 1. and t_cur <= guidance_high and t_cur >= guidance_low:
            d_cur_cond, d_cur_uncond = d_cur.chunk(2)
            d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
            cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
            if cls_cfg > 0:
                cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
            else:
                cls_d_cur = cls_d_cur_cond
        mean_x = x_cur + dt * d_cur
        if freeze_after_tc and t_c_freeze is not None and float(t_cur) <= float(t_c_freeze) + 1e-9:
            cls_mean_x = cls_frozen
        else:
            cls_mean_x = cls_x_cur + dt * cls_d_cur
    if return_trajectory and return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x, traj
    if return_trajectory and return_mid_state:
        return mean_x, z_mid, cls_mid, traj
    if return_trajectory and return_cls_final:
        return mean_x, cls_mean_x, traj
    if return_trajectory:
        return mean_x, traj
    if return_mid_state and return_cls_final:
        return mean_x, z_mid, cls_mid, cls_mean_x
    if return_mid_state:
        return mean_x, z_mid, cls_mid
    if return_cls_final:
        return mean_x, cls_mean_x
    return mean_x
def euler_ode_sampler(
    model,
    latents,
    y,
    num_steps=20,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """REG ODE entry point.

    Decoupled from the SDE samplers: delegates directly to euler_sampler
    (linspace 1 -> 0, or the t_c segmented grid; no t_floor).
    """
    forwarded = dict(
        num_steps=num_steps,
        heun=False,
        cfg_scale=cfg_scale,
        guidance_low=guidance_low,
        guidance_high=guidance_high,
        path_type=path_type,
        cls_latents=cls_latents,
        args=args,
        return_mid_state=return_mid_state,
        t_mid=t_mid,
        t_c=t_c,
        num_steps_before_tc=num_steps_before_tc,
        num_steps_after_tc=num_steps_after_tc,
        return_trajectory=return_trajectory,
    )
    return euler_sampler(model, latents, y, **forwarded)
def euler_sampler(
    model,
    latents,
    y,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
    cls_latents=None,
    args=None,
    return_mid_state=False,
    t_mid=0.5,
    t_c=None,
    num_steps_before_tc=None,
    num_steps_after_tc=None,
    return_trajectory=False,
):
    """Lightweight deterministic drift (Euler ODE) sampler.

    Prefix-compatible with the glflow sampler of the same name
    (model, latents, y, num_steps, heun, cfg, guidance, path_type,
    cls_latents, args).

    - Default: linspace(1, 0, num_steps + 1); no t_floor (matches the
      original standalone ODE).
    - Optional: when t_c, num_steps_before_tc and num_steps_after_tc are all
      given, the grid is 1 -> t_c -> 0 and, consistent with the EM samplers,
      the cls token is frozen for t <= t_c.
    - Optional: return_mid_state / return_trajectory for use by train.py and
      sample_from_checkpoint.

    REG's SiT requires a cls_token, so cls_latents must not be None. heun is
    a placeholder and unused.
    """
    if cls_latents is None:
        raise ValueError(
            "euler_sampler: 本仓库 REG SiT 需要 cls_token,请传入 cls_latents(例如高斯噪声或训练中的 cls 初值)。"
        )
    if cfg_scale > 1.0:
        # Null label for the unconditional CFG branch — presumably the
        # model's null-class id (TODO confirm against the model's embedder).
        y_null = torch.full((y.size(0),), 1000, device=y.device, dtype=y.dtype)
    else:
        y_null = None
    _dtype = latents.dtype
    cls_cfg = getattr(args, "cls_cfg_scale", 0) if args is not None else 0
    device = latents.device
    t_steps = _build_euler_sampler_time_steps(
        num_steps, t_c, num_steps_before_tc, num_steps_after_tc, device
    )
    freeze_after_tc = _tc_segmented_freeze_cls(t_c, num_steps_before_tc, num_steps_after_tc)
    t_c_v = float(t_c) if freeze_after_tc else None
    # Integrate in float64; cast back to the model's dtype per forward pass.
    x_next = latents.to(torch.float64)
    cls_x_next = cls_latents.to(torch.float64)
    z_mid = cls_mid = None
    t_mid = float(t_mid)
    cls_frozen = None
    traj = [x_next.clone()] if return_trajectory else None
    with torch.no_grad():
        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
            dt = t_next - t_cur
            x_cur = x_next
            cls_x_cur = cls_x_next
            cls_model_input, cls_frozen = _cls_effective_and_freeze(
                cls_x_cur, cls_frozen, t_cur, t_c_v, freeze_after_tc
            )
            tc, tn = float(t_cur), float(t_next)
            # Capture the state nearest t_mid (before the step if t_cur is
            # closer, otherwise after the update at the bottom of the loop).
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                if abs(tc - t_mid) < abs(tn - t_mid):
                    z_mid = x_cur.clone()
                    cls_mid = cls_model_input.clone()
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                # Duplicate batch for conditional + unconditional passes.
                model_input = torch.cat([x_cur] * 2, dim=0)
                cls_model_input = torch.cat([cls_model_input] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
            else:
                model_input = x_cur
                y_cur = y
            time_input = torch.ones(model_input.size(0), device=device, dtype=torch.float64) * t_cur
            v_cur, _, cls_v_cur = model(
                model_input.to(dtype=_dtype),
                time_input.to(dtype=_dtype),
                y_cur,
                cls_token=cls_model_input.to(dtype=_dtype),
            )
            v_cur = v_cur.to(torch.float64)
            cls_v_cur = cls_v_cur.to(torch.float64)
            # ODE: follow velocity parameterization directly (d/dt x_t = v_t).
            # This aligns with velocity training target and avoids extra v->score->drift conversion.
            d_cur = v_cur
            cls_d_cur = cls_v_cur
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
                cls_d_cur_cond, cls_d_cur_uncond = cls_d_cur.chunk(2)
                if cls_cfg > 0:
                    cls_d_cur = cls_d_cur_uncond + cls_cfg * (cls_d_cur_cond - cls_d_cur_uncond)
                else:
                    # cls CFG disabled: keep the conditional cls drift only.
                    cls_d_cur = cls_d_cur_cond
            x_next = x_cur + dt * d_cur
            if freeze_after_tc and t_c_v is not None and float(t_cur) <= float(t_c_v) + 1e-9:
                # Frozen segment: cls stays at its value from t = t_c.
                cls_x_next = cls_frozen
            else:
                cls_x_next = cls_x_cur + dt * cls_d_cur
            if return_trajectory:
                traj.append(x_next.clone())
            if return_mid_state and z_mid is None and tn <= t_mid <= tc:
                z_mid = x_next.clone()
                cls_mid = cls_x_next.clone()
    if return_trajectory and return_mid_state:
        return x_next, z_mid, cls_mid, traj
    if return_trajectory:
        return x_next, traj
    if return_mid_state:
        return x_next, z_mid, cls_mid
    return x_next