# jsflow / REG / loss.py
# Uploaded with the upload-large-folder tool (commit b65e56d, verified).
import torch
import numpy as np
import torch.nn.functional as F
try:
from scipy.optimize import linear_sum_assignment
except ImportError:
linear_sum_assignment = None
def ot_pair_noise_to_cls(noise_cls, cls_gt):
    """
    Minibatch OT pairing (in the spirit of sample_plan_with_scipy from
    conditional-flow-matching / torchcfm): permute ``noise_cls`` within the
    batch so that noise_ot[i] and cls_gt[i] form the optimal-transport
    pairing under squared Euclidean cost.

    Args:
        noise_cls, cls_gt: (N, D) tensors, or any shape flattenable to (N, -1).

    Returns:
        (noise_permuted, cls_gt); cls_gt is returned unchanged so it stays
        aligned with the rest of the batch.
    """
    n = noise_cls.shape[0]
    # Degenerate batch, or scipy unavailable: identity pairing.
    if n <= 1 or linear_sum_assignment is None:
        return noise_cls, cls_gt
    x0 = noise_cls.detach().float().reshape(n, -1)
    x1 = cls_gt.detach().float().reshape(n, -1)
    # Cost rows are indexed by the cls targets, so the assignment's column
    # indices give, for each cls_gt[i], its matching noise sample directly.
    # (Indexing noise with the column indices of the transposed cost would
    # apply the permutation instead of its inverse, breaking the pairing for
    # assignment cycles of length >= 3.)
    cost = torch.cdist(x1, x0) ** 2
    _, src = linear_sum_assignment(cost.cpu().numpy())
    src = torch.as_tensor(src, device=noise_cls.device, dtype=torch.long)
    return noise_cls[src], cls_gt
def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    reduce_dims = tuple(range(1, x.dim()))
    return x.mean(dim=reduce_dims)
def sum_flat(x):
    """
    Take the sum over all non-batch dimensions.

    (Docstring fixed: the previous one said "mean", copied from mean_flat.)
    """
    return torch.sum(x, dim=list(range(1, len(x.size()))))
class SILoss:
    """Stochastic-interpolant loss with an extra semantic (cls) velocity branch.

    Time convention (consistent with train.py / JsFlow): t=0 is the clean
    latent, t=1 is pure noise.
      * t in (t_c, 1]: the semantic cls channel evolves from (optionally
        OT-paired) noise toward the ground-truth cls token.
      * t in [0, t_c]: the cls channel is held at the ground-truth cls token
        and its target velocity is zero (the channel is no longer
        interpolated).
    """

    def __init__(
        self,
        prediction='v',
        path_type="linear",
        weighting="uniform",
        encoders=None,
        accelerator=None,
        latents_scale=None,
        latents_bias=None,
        t_c=0.5,
        ot_cls=True,
    ):
        """
        Args:
            prediction: target parameterization; only 'v' (velocity) is
                implemented.
            path_type: interpolant schedule, "linear" or "cosine".
            weighting: timestep sampling scheme, "uniform" or "lognormal".
            encoders: optional list of feature encoders (stored for callers;
                defaults to an empty list). A None sentinel replaces the
                previous mutable default argument.
            accelerator: optional accelerator handle (stored, unused here).
            latents_scale, latents_bias: optional latent normalization stats
                (stored, unused here).
            t_c: changeover time between the frozen-cls and generative-cls
                regimes; clamped strictly inside (0, 1).
            ot_cls: if True, OT-pair the cls noise with the cls targets.
        """
        self.prediction = prediction
        self.weighting = weighting
        self.path_type = path_type
        self.encoders = [] if encoders is None else encoders
        self.accelerator = accelerator
        self.latents_scale = latents_scale
        self.latents_bias = latents_bias
        # Clamp t_c into (0, 1) so 1 / (1 - t_c) used in the cls branch stays
        # finite for any user-provided value.
        self.t_c = min(max(float(t_c), 1e-4), 1.0 - 1e-4)
        self.ot_cls = bool(ot_cls)

    def interpolant(self, t):
        """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) at time t."""
        if self.path_type == "linear":
            alpha_t = 1 - t
            sigma_t = t
            d_alpha_t = -1
            d_sigma_t = 1
        elif self.path_type == "cosine":
            alpha_t = torch.cos(t * np.pi / 2)
            sigma_t = torch.sin(t * np.pi / 2)
            d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
            d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
        else:
            raise NotImplementedError()
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def __call__(self, model, images, model_kwargs=None, zs=None, cls_token=None,
                 time_input=None, noises=None,):
        """Compute the training losses for one minibatch.

        Args:
            model: callable ``model(x_t, t, **model_kwargs, cls_token=...)``
                returning (velocity prediction, projected features list,
                cls-velocity prediction).
            images: (N, C, H, W) clean latents.
            model_kwargs: extra keyword args forwarded to ``model``.
            zs: list of encoder feature targets for the projection loss
                (required — indexed unconditionally below).
            cls_token: ground-truth semantic tokens, batch-first (required).
            time_input: optional (N, 1, 1, 1) timesteps; sampled if None.
            noises: optional noise — either an image-noise tensor or a
                (image_noise, cls_noise) pair; sampled if None.

        Returns:
            (denoising_loss, proj_loss, time_input, noises,
             denoising_loss_cls)
        """
        if model_kwargs is None:
            model_kwargs = {}
        # Sample timesteps unless the caller fixed them.
        if time_input is None:
            if self.weighting == "uniform":
                time_input = torch.rand((images.shape[0], 1, 1, 1))
            elif self.weighting == "lognormal":
                # Sample timesteps according to a log-normal distribution of
                # sigmas, following EDM.
                rnd_normal = torch.randn((images.shape[0], 1, 1, 1))
                sigma = rnd_normal.exp()
                if self.path_type == "linear":
                    time_input = sigma / (1 + sigma)
                elif self.path_type == "cosine":
                    time_input = 2 / np.pi * torch.atan(sigma)
        time_input = time_input.to(device=images.device, dtype=torch.float32)
        cls_token = cls_token.to(device=images.device, dtype=torch.float32)
        if noises is None:
            noises = torch.randn_like(images)
            noises_cls = torch.randn_like(cls_token)
        else:
            if isinstance(noises, (tuple, list)) and len(noises) == 2:
                noises, noises_cls = noises
            else:
                noises_cls = torch.randn_like(cls_token)
        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)
        model_input = alpha_t * images + sigma_t * noises
        if self.prediction == 'v':
            model_target = d_alpha_t * images + d_sigma_t * noises
        else:
            raise NotImplementedError()
        N = images.shape[0]
        t_flat = time_input.view(-1).float()
        # Per-sample regime masks, broadcastable to cls_token's shape.
        high_noise_mask = (t_flat > self.t_c).float().view(N, *([1] * (cls_token.dim() - 1)))
        low_noise_mask = 1.0 - high_noise_mask
        noise_cls_raw = noises_cls
        if self.ot_cls:
            noise_cls_paired, cls_gt_paired = ot_pair_noise_to_cls(noise_cls_raw, cls_token)
        else:
            noise_cls_paired, cls_gt_paired = noise_cls_raw, cls_token
        # Rescale t in (t_c, 1] to tau in (0, 1] for the semantic path.
        tau_shape = (N,) + (1,) * max(0, cls_token.dim() - 1)
        tau = (time_input.reshape(tau_shape) - self.t_c) / (1.0 - self.t_c + 1e-8)
        tau = torch.clamp(tau, 0.0, 1.0)
        alpha_sem = 1.0 - tau
        sigma_sem = tau
        cls_t_high = alpha_sem * cls_gt_paired + sigma_sem * noise_cls_paired
        cls_t = high_noise_mask * cls_t_high + low_noise_mask * cls_token
        # Sanitize the interpolated cls against non-finite values.
        cls_t = torch.nan_to_num(cls_t, nan=0.0, posinf=1e4, neginf=-1e4)
        cls_t = torch.clamp(cls_t, -1e4, 1e4)
        # In the low-noise regime the cls input carries no gradient.
        cls_for_model = cls_t * high_noise_mask + cls_t.detach() * low_noise_mask
        # Target cls velocity: constant along the rescaled semantic path in
        # the high-noise regime, zero in the frozen regime.
        inv_scale = 1.0 / (1.0 - self.t_c + 1e-8)
        v_cls_high = (noise_cls_paired - cls_gt_paired) * inv_scale
        v_cls_target = high_noise_mask * v_cls_high
        model_output, zs_tilde, cls_output = model(
            model_input, time_input.flatten(), **model_kwargs, cls_token=cls_for_model
        )
        # Velocity-matching losses for the image and cls branches.
        denoising_loss = mean_flat((model_output - model_target) ** 2)
        denoising_loss_cls = mean_flat((cls_output - v_cls_target) ** 2)
        # Projection loss: negative cosine similarity between the model's
        # projected features and the encoder targets, averaged over
        # encoders and batch samples.
        proj_loss = 0.
        bsz = zs[0].shape[0]
        for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
            for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
                z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
                z_j = torch.nn.functional.normalize(z_j, dim=-1)
                proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
        proj_loss /= (len(zs) * bsz)
        return denoising_loss, proj_loss, time_input, noises, denoising_loss_cls

    def tc_velocity_loss(self, model, images, model_kwargs=None, cls_token=None, noises=None):
        """
        Extra constraint: directly supervise the image velocity field at
        t = t_c to stabilize the single-step (t_c -> 0) jump.
        Only affects the image branch; the main cls/projection losses are
        unchanged.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if cls_token is None:
            raise ValueError("tc_velocity_loss requires cls_token")
        if noises is None:
            noises = torch.randn_like(images)
        bsz = images.shape[0]
        time_input = torch.full(
            (bsz, 1, 1, 1), float(self.t_c), device=images.device, dtype=torch.float32
        )
        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)
        model_input = alpha_t * images + sigma_t * noises
        model_target = d_alpha_t * images + d_sigma_t * noises
        model_output, _, _ = model(
            model_input, time_input.flatten(), **model_kwargs, cls_token=cls_token
        )
        return mean_flat((model_output - model_target) ** 2)