import torch
import numpy as np
import torch.nn.functional as F

try:
    from scipy.optimize import linear_sum_assignment
except ImportError:
    linear_sum_assignment = None


def ot_pair_noise_to_cls(noise_cls, cls_gt):
    """Minibatch OT pairing (in the spirit of torchcfm's sample_plan_with_scipy).

    Within the batch, reorder ``noise_cls`` under a squared-Euclidean cost so
    that ``noise_ot[i]`` and ``cls_gt[i]`` form an optimal-transport pair.

    Args:
        noise_cls: (N, ...) noise tensor; flattened to (N, D) for the cost.
        cls_gt: (N, ...) ground-truth semantic tokens, same leading dim.

    Returns:
        Tuple ``(noise_reordered, cls_gt)``; ``cls_gt`` is returned in its
        original order so it stays aligned with the rest of the batch.
    """
    n = noise_cls.shape[0]
    # Degenerate batch or SciPy missing: fall back to the identity pairing.
    if n <= 1 or linear_sum_assignment is None:
        return noise_cls, cls_gt
    x0 = noise_cls.detach().float().reshape(n, -1)
    x1 = cls_gt.detach().float().reshape(n, -1)
    # BUGFIX: rows must index cls_gt and columns index the noise, so the
    # column assignment j maps each target position i to the noise sample
    # paired with it. The previous cdist(x0, x1) orientation indexed the
    # noise with the *forward* permutation while cls_gt stayed fixed, i.e.
    # it applied the inverse of the optimal plan (suboptimal in general).
    M = torch.cdist(x1, x0) ** 2
    _, j = linear_sum_assignment(M.cpu().numpy())
    j = torch.as_tensor(j, device=noise_cls.device, dtype=torch.long)
    return noise_cls[j], cls_gt


def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    return torch.mean(x, dim=list(range(1, len(x.size()))))


def sum_flat(x):
    """
    Take the sum over all non-batch dimensions.
    """
    return torch.sum(x, dim=list(range(1, len(x.size()))))


class SILoss:
    """Stochastic-interpolant (flow-matching) loss with a semantic cls channel."""

    def __init__(
            self,
            prediction='v',
            path_type="linear",
            weighting="uniform",
            encoders=None,
            accelerator=None,
            latents_scale=None,
            latents_bias=None,
            t_c=0.5,
            ot_cls=True,
    ):
        """
        Args:
            prediction: target parameterization; only 'v' is implemented.
            path_type: interpolant schedule, "linear" or "cosine".
            weighting: timestep sampling, "uniform" or "lognormal".
            encoders: optional list of feature encoders (default: empty list;
                a None sentinel avoids the shared-mutable-default pitfall).
            accelerator: optional accelerator handle (stored, unused here).
            latents_scale / latents_bias: optional latent normalization stats.
            t_c: changeover time for the semantic channel, clamped to
                (1e-4, 1 - 1e-4).
            ot_cls: if True, OT-pair the cls noise with the targets.
        """
        self.prediction = prediction
        self.weighting = weighting
        self.path_type = path_type
        self.encoders = encoders if encoders is not None else []
        self.accelerator = accelerator
        self.latents_scale = latents_scale
        self.latents_bias = latents_bias
        # Time convention matches train.py / JsFlow: t=0 is the clean latent,
        # t=1 is pure noise.
        # t in (t_c, 1]: the semantic cls evolves from noise toward cls_gt
        #   along the OT-paired path (generative semantic channel);
        # t in [0, t_c]: cls is held at the ground truth and the target
        #   velocity is zero (the channel no longer interpolates).
        tc = float(t_c)
        self.t_c = min(max(tc, 1e-4), 1.0 - 1e-4)
        self.ot_cls = bool(ot_cls)

    def interpolant(self, t):
        """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) at time t.

        x_t = alpha_t * x_clean + sigma_t * noise; the d_* terms are time
        derivatives (Python scalars for the linear path, tensors for cosine).
        """
        if self.path_type == "linear":
            alpha_t = 1 - t
            sigma_t = t
            d_alpha_t = -1
            d_sigma_t = 1
        elif self.path_type == "cosine":
            alpha_t = torch.cos(t * np.pi / 2)
            sigma_t = torch.sin(t * np.pi / 2)
            d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
            d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
        else:
            raise NotImplementedError()
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def __call__(self, model, images, model_kwargs=None, zs=None,
                 cls_token=None, time_input=None, noises=None):
        """Compute the flow-matching, projection, and cls-channel losses.

        Args:
            model: callable ``(x_t, t, cls_token=..., **model_kwargs)`` ->
                (velocity prediction, projected feature list, cls velocity).
            images: (N, C, H, W) clean latents — assumed 4-D; TODO confirm.
            model_kwargs: extra kwargs forwarded to the model.
            zs: list of ground-truth encoder features for the REPA-style
                projection loss.
            cls_token: (N, D) ground-truth semantic tokens — presumably 2-D;
                verify against callers.
            time_input: optional pre-sampled timesteps, shape (N, 1, 1, 1).
            noises: optional noise — a tensor for the image channel, or a
                (image_noise, cls_noise) pair.

        Returns:
            (denoising_loss, proj_loss, time_input, noises,
             denoising_loss_cls)
        """
        if model_kwargs is None:
            model_kwargs = {}

        # --- sample timesteps -------------------------------------------
        if time_input is None:
            if self.weighting == "uniform":
                time_input = torch.rand((images.shape[0], 1, 1, 1))
            elif self.weighting == "lognormal":
                # Sample timesteps from a log-normal distribution of sigmas
                # (EDM-style), mapped through the chosen path.
                rnd_normal = torch.randn((images.shape[0], 1, 1, 1))
                sigma = rnd_normal.exp()
                if self.path_type == "linear":
                    time_input = sigma / (1 + sigma)
                elif self.path_type == "cosine":
                    time_input = 2 / np.pi * torch.atan(sigma)
                else:
                    raise NotImplementedError(self.path_type)
            else:
                # Previously an unknown weighting left time_input unbound and
                # crashed later with UnboundLocalError; fail fast instead.
                raise NotImplementedError(self.weighting)

        time_input = time_input.to(device=images.device, dtype=torch.float32)
        cls_token = cls_token.to(device=images.device, dtype=torch.float32)

        # --- sample noises ----------------------------------------------
        if noises is None:
            noises = torch.randn_like(images)
            noises_cls = torch.randn_like(cls_token)
        elif isinstance(noises, (tuple, list)) and len(noises) == 2:
            noises, noises_cls = noises
        else:
            noises_cls = torch.randn_like(cls_token)

        # --- image channel: interpolate and build velocity target -------
        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)
        model_input = alpha_t * images + sigma_t * noises
        if self.prediction == 'v':
            model_target = d_alpha_t * images + d_sigma_t * noises
        else:
            raise NotImplementedError()

        # --- semantic cls channel ---------------------------------------
        N = images.shape[0]
        t_flat = time_input.view(-1).float()
        high_noise_mask = (t_flat > self.t_c).float().view(
            N, *([1] * (cls_token.dim() - 1)))
        low_noise_mask = 1.0 - high_noise_mask

        noise_cls_raw = noises_cls
        if self.ot_cls:
            noise_cls_paired, cls_gt_paired = ot_pair_noise_to_cls(
                noise_cls_raw, cls_token)
        else:
            noise_cls_paired, cls_gt_paired = noise_cls_raw, cls_token

        # Rescale t in (t_c, 1] to tau in (0, 1] for the semantic path.
        tau_shape = (N,) + (1,) * max(0, cls_token.dim() - 1)
        tau = (time_input.reshape(tau_shape) - self.t_c) / (1.0 - self.t_c + 1e-8)
        tau = torch.clamp(tau, 0.0, 1.0)
        alpha_sem = 1.0 - tau
        sigma_sem = tau
        cls_t_high = alpha_sem * cls_gt_paired + sigma_sem * noise_cls_paired
        cls_t = high_noise_mask * cls_t_high + low_noise_mask * cls_token
        cls_t = torch.nan_to_num(cls_t, nan=0.0, posinf=1e4, neginf=-1e4)
        cls_t = torch.clamp(cls_t, -1e4, 1e4)
        # In the low-noise regime the cls input carries no gradient.
        cls_for_model = cls_t * high_noise_mask + cls_t.detach() * low_noise_mask

        # Target velocity for the semantic channel: constant along the
        # (rescaled) linear path in the high-noise regime, zero otherwise.
        inv_scale = 1.0 / (1.0 - self.t_c + 1e-8)
        v_cls_high = (noise_cls_paired - cls_gt_paired) * inv_scale
        v_cls_target = high_noise_mask * v_cls_high

        model_output, zs_tilde, cls_output = model(
            model_input, time_input.flatten(), **model_kwargs,
            cls_token=cls_for_model)

        # --- losses ------------------------------------------------------
        denoising_loss = mean_flat((model_output - model_target) ** 2)
        denoising_loss_cls = mean_flat((cls_output - v_cls_target) ** 2)

        # REPA-style projection loss: negative cosine similarity between
        # projected model features and encoder features, averaged over
        # encoders and batch.
        proj_loss = 0.
        bsz = zs[0].shape[0]
        for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
            for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
                z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
                z_j = torch.nn.functional.normalize(z_j, dim=-1)
                proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
        proj_loss /= (len(zs) * bsz)

        return denoising_loss, proj_loss, time_input, noises, denoising_loss_cls