| import torch |
| import numpy as np |
| import torch.nn.functional as F |
|
|
| try: |
| from scipy.optimize import linear_sum_assignment |
| except ImportError: |
| linear_sum_assignment = None |
|
|
|
|
def ot_pair_noise_to_cls(noise_cls, cls_gt):
    """Reorder `noise_cls` within the minibatch into an (approximately)
    optimal-transport pairing with `cls_gt`.

    Mirrors `sample_plan_with_scipy` from conditional-flow-matching / torchcfm:
    a squared-Euclidean cost matrix is built over the batch and solved with the
    Hungarian algorithm, so the permuted noise row i is matched with target
    row i. Inputs are (N, ...) tensors flattened to (N, D) for the cost
    computation; `cls_gt` is returned unchanged.

    Falls back to the identity pairing when the batch has at most one element
    or scipy is unavailable.
    """
    batch = noise_cls.shape[0]
    # Identity pairing when OT is impossible (tiny batch) or scipy is missing.
    if batch <= 1 or linear_sum_assignment is None:
        return noise_cls, cls_gt
    flat_noise = noise_cls.detach().float().reshape(batch, -1)
    flat_target = cls_gt.detach().float().reshape(batch, -1)
    cost = torch.cdist(flat_noise, flat_target).pow(2)
    _, col_idx = linear_sum_assignment(cost.cpu().numpy())
    perm = torch.as_tensor(col_idx, device=noise_cls.device, dtype=torch.long)
    return noise_cls[perm], cls_gt
|
|
|
|
def mean_flat(x):
    """Average `x` over every dimension except the leading batch dimension."""
    non_batch_dims = list(range(1, x.dim()))
    return torch.mean(x, dim=non_batch_dims)
|
|
def sum_flat(x):
    """
    Take the sum over all non-batch dimensions.
    """
    # NOTE: docstring previously said "mean" (copy-paste from mean_flat); the
    # implementation sums, which is what callers get.
    return torch.sum(x, dim=list(range(1, len(x.size()))))
|
|
class SILoss:
    """Stochastic-interpolant / flow-matching training loss with an auxiliary
    class-token (semantic) branch.

    The image branch is standard velocity-matching: sample t, form
    x_t = alpha_t * x + sigma_t * eps and regress d_alpha_t * x + d_sigma_t * eps.
    The class-token branch is only noised for samples with t > t_c; below t_c
    the ground-truth token is fed through unchanged (and detached). Optionally,
    the class-token noise is re-paired with targets via minibatch OT.
    """

    def __init__(
        self,
        prediction='v',
        path_type="linear",
        weighting="uniform",
        encoders=None,
        accelerator=None,
        latents_scale=None,
        latents_bias=None,
        t_c=0.5,
        ot_cls=True,
    ):
        """
        Args:
            prediction: target parameterization; only 'v' (velocity) is implemented.
            path_type: interpolant schedule, "linear" or "cosine".
            weighting: time-sampling scheme, "uniform" or "lognormal".
            encoders: optional list of feature encoders; defaults to an empty list.
            accelerator: optional accelerator handle (stored, not used in this file).
            latents_scale / latents_bias: optional latent stats (stored, not used here).
            t_c: regime switch point; clamped into (1e-4, 1 - 1e-4) so the
                1 / (1 - t_c) rescaling below never divides by ~0.
            ot_cls: if True, re-pair class-token noise with targets via minibatch OT.
        """
        self.prediction = prediction
        self.weighting = weighting
        self.path_type = path_type
        # `None` sentinel avoids the shared-mutable-default pitfall while
        # preserving the original behavior (attribute defaults to []).
        self.encoders = [] if encoders is None else encoders
        self.accelerator = accelerator
        self.latents_scale = latents_scale
        self.latents_bias = latents_bias

        tc = float(t_c)
        self.t_c = min(max(tc, 1e-4), 1.0 - 1e-4)
        self.ot_cls = bool(ot_cls)

    def interpolant(self, t):
        """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for the interpolant
        x_t = alpha_t * x0 + sigma_t * noise under the configured path type.

        Raises:
            NotImplementedError: for unknown `path_type`.
        """
        if self.path_type == "linear":
            alpha_t = 1 - t
            sigma_t = t
            d_alpha_t = -1
            d_sigma_t = 1
        elif self.path_type == "cosine":
            alpha_t = torch.cos(t * np.pi / 2)
            sigma_t = torch.sin(t * np.pi / 2)
            d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
            d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
        else:
            raise NotImplementedError()

        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def __call__(self, model, images, model_kwargs=None, zs=None, cls_token=None,
                 time_input=None, noises=None,):
        """Compute the training losses for one minibatch.

        Args:
            model: callable taking (x_t, t, **model_kwargs, cls_token=...) and
                returning (model_output, zs_tilde, cls_output).
            images: (N, ...) input batch.
            zs: list of encoder feature tensors for the projection loss.
            cls_token: per-sample semantic token tensor with leading dim N
                (exact shape not fixed here).
            time_input: optional pre-sampled times of shape (N, 1, 1, 1).
            noises: optional noise; either an image-shaped tensor, or a
                (image_noise, cls_noise) pair.

        Returns:
            (denoising_loss, proj_loss, time_input, noises, denoising_loss_cls)
        """
        if model_kwargs is None:
            model_kwargs = {}

        # Sample per-example times as (N, 1, 1, 1) unless supplied by caller.
        if time_input is None:
            if self.weighting == "uniform":
                time_input = torch.rand((images.shape[0], 1, 1, 1))
            elif self.weighting == "lognormal":
                # Sample sigma log-normally, then map it to time on the chosen path.
                rnd_normal = torch.randn((images.shape[0], 1, 1, 1))
                sigma = rnd_normal.exp()
                if self.path_type == "linear":
                    time_input = sigma / (1 + sigma)
                elif self.path_type == "cosine":
                    time_input = 2 / np.pi * torch.atan(sigma)

        time_input = time_input.to(device=images.device, dtype=torch.float32)
        cls_token = cls_token.to(device=images.device, dtype=torch.float32)

        # Image-branch noise; class-token noise may arrive alongside it.
        if noises is None:
            noises = torch.randn_like(images)
            noises_cls = torch.randn_like(cls_token)
        else:
            if isinstance(noises, (tuple, list)) and len(noises) == 2:
                noises, noises_cls = noises
            else:
                noises_cls = torch.randn_like(cls_token)

        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)

        model_input = alpha_t * images + sigma_t * noises
        if self.prediction == 'v':
            model_target = d_alpha_t * images + d_sigma_t * noises
        else:
            raise NotImplementedError()

        # Split the batch by regime: high-noise samples (t > t_c) get a noised
        # class token; low-noise samples see the clean token.
        N = images.shape[0]
        t_flat = time_input.view(-1).float()
        high_noise_mask = (t_flat > self.t_c).float().view(N, *([1] * (cls_token.dim() - 1)))
        low_noise_mask = 1.0 - high_noise_mask

        # Optionally re-pair class-token noise with targets via minibatch OT.
        noise_cls_raw = noises_cls
        if self.ot_cls:
            noise_cls_paired, cls_gt_paired = ot_pair_noise_to_cls(noise_cls_raw, cls_token)
        else:
            noise_cls_paired, cls_gt_paired = noise_cls_raw, cls_token

        # Re-normalize time to tau in [0, 1] over the high-noise segment [t_c, 1].
        tau_shape = (N,) + (1,) * max(0, cls_token.dim() - 1)
        tau = (time_input.reshape(tau_shape) - self.t_c) / (1.0 - self.t_c + 1e-8)
        tau = torch.clamp(tau, 0.0, 1.0)
        alpha_sem = 1.0 - tau
        sigma_sem = tau

        # Linear semantic interpolant for high-noise samples, with numeric guards.
        cls_t_high = alpha_sem * cls_gt_paired + sigma_sem * noise_cls_paired
        cls_t = high_noise_mask * cls_t_high + low_noise_mask * cls_token
        cls_t = torch.nan_to_num(cls_t, nan=0.0, posinf=1e4, neginf=-1e4)
        cls_t = torch.clamp(cls_t, -1e4, 1e4)

        # Detach the clean-token path so gradients only flow through noised tokens.
        cls_for_model = cls_t * high_noise_mask + cls_t.detach() * low_noise_mask

        # Target velocity of the semantic interpolant; zero on low-noise samples.
        inv_scale = 1.0 / (1.0 - self.t_c + 1e-8)
        v_cls_high = (noise_cls_paired - cls_gt_paired) * inv_scale
        v_cls_target = high_noise_mask * v_cls_high

        model_output, zs_tilde, cls_output = model(
            model_input, time_input.flatten(), **model_kwargs, cls_token=cls_for_model
        )

        denoising_loss = mean_flat((model_output - model_target) ** 2)
        denoising_loss_cls = mean_flat((cls_output - v_cls_target) ** 2)

        # Negative-cosine alignment between encoder features and predictions,
        # averaged over encoders and batch items.
        proj_loss = 0.
        bsz = zs[0].shape[0]
        for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
            for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
                z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
                z_j = torch.nn.functional.normalize(z_j, dim=-1)
                proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
        proj_loss /= (len(zs) * bsz)

        return denoising_loss, proj_loss, time_input, noises, denoising_loss_cls

    def tc_velocity_loss(self, model, images, model_kwargs=None, cls_token=None, noises=None):
        """Extra regularizer: supervise the image velocity field exactly at
        t = t_c, to improve single-step (t_c -> 0) stability.

        Acts only on the image branch; it does not alter the cls / projection
        losses defined in `__call__`.
        """
        if model_kwargs is None:
            model_kwargs = {}
        if cls_token is None:
            raise ValueError("tc_velocity_loss requires cls_token")
        if noises is None:
            noises = torch.randn_like(images)

        bsz = images.shape[0]
        time_input = torch.full(
            (bsz, 1, 1, 1), float(self.t_c), device=images.device, dtype=torch.float32
        )
        alpha_t, sigma_t, d_alpha_t, d_sigma_t = self.interpolant(time_input)
        model_input = alpha_t * images + sigma_t * noises
        model_target = d_alpha_t * images + d_sigma_t * noises

        model_output, _, _ = model(
            model_input, time_input.flatten(), **model_kwargs, cls_token=cls_token
        )
        return mean_flat((model_output - model_target) ** 2)
|
|