from .base_pipeline import BasePipeline
import torch
import copy


def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
    # Sample a training timestep uniformly within the configured boundary fractions.
    max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
    min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
    timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
    timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)

    # Noise only the first half of the frames; the second half is kept clean as context.
    noise = torch.randn_like(inputs["input_latents"])
    origin_latents = copy.deepcopy(inputs["input_latents"])
    noisy_latents = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
    tgt_latent_len = noisy_latents.shape[2] // 2
    noisy_latents[:, :, tgt_latent_len:, ...] = origin_latents[:, :, tgt_latent_len:, ...]
    inputs["latents"] = noisy_latents
    if "first_frame_latents" in inputs:
        inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
    training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)

    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)

    # diff: [B, C, T-1, H, W] (frame 0 is excluded; it may hold the first-frame conditioning)
    diff = (noise_pred[:, :, 1:tgt_latent_len] - training_target[:, :, 1:tgt_latent_len]) ** 2
    gamma = 0.01
    T = tgt_latent_len
    i = torch.arange(1, T, device=diff.device).float()
    d = torch.abs(2 * i / (T - 1) - 1.0)  # distance of each frame from the temporal midpoint, in [0, 1]
    w_f = 1.0 + gamma * d ** 2            # per-frame weight, shape [T-1]
    w_f = w_f.view(1, 1, T - 1, 1, 1)
    loss = (diff * w_f).mean()
    loss = loss * pipe.scheduler.training_weight(timestep)
    return loss
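

# A minimal, self-contained sketch (not used by the pipeline code above) of the frame-position
# weighting applied in FlowMatchSFTLoss, assuming latents shaped [B, C, T, H, W]: frames near
# the ends of the predicted segment receive slightly larger weight than frames in the middle.
def _frame_weight_example(num_frames: int, gamma: float = 0.01) -> torch.Tensor:
    """Return weights of shape [1, 1, num_frames - 1, 1, 1] for frames 1..num_frames-1 (num_frames >= 2)."""
    i = torch.arange(1, num_frames).float()
    d = torch.abs(2 * i / (num_frames - 1) - 1.0)  # normalized distance from the temporal midpoint
    return (1.0 + gamma * d ** 2).view(1, 1, num_frames - 1, 1, 1)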


def DirectDistillLoss(pipe: BasePipeline, **inputs):
    # Run the full (short) sampling loop, then match the final latents to the ground truth.
    pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
    pipe.scheduler.training = True
    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
    for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
        timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
        noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep, progress_id=progress_id)
        inputs["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs)
    loss = torch.nn.functional.mse_loss(inputs["latents"].float(), inputs["input_latents"].float())
    return loss
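

# Hypothetical usage sketch for DirectDistillLoss (the surrounding training loop and the extra
# conditioning keys are assumptions, not defined in this file): the student runs a short sampling
# loop from `latents` and the final result is regressed onto the ground-truth `input_latents`.
#
#     loss = DirectDistillLoss(
#         pipe,
#         latents=torch.randn_like(clean_latents),  # starting noise
#         input_latents=clean_latents,              # ground-truth latents from the dataset
#         num_inference_steps=4,
#         **conditioning_inputs,                    # e.g. encoded prompts (assumed)
#     )
#     loss.backward()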


class TrajectoryImitationLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, device):
        import lpips  # TODO: remove it
        self.loss_fn = lpips.LPIPS(net='alex').to(device)
        self.initialized = True

    def fetch_trajectory(self, pipe: BasePipeline, timesteps_student, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        # Sample with the given pipeline (the teacher in forward()) and record the latents after every step.
        trajectory = [inputs_shared["latents"].clone()]
        pipe.scheduler.set_timesteps(num_inference_steps, target_timesteps=timesteps_student)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
            trajectory.append(inputs_shared["latents"].clone())
        return pipe.scheduler.timesteps, trajectory

    def align_trajectory(self, pipe: BasePipeline, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        loss = 0
        pipe.scheduler.set_timesteps(num_inference_steps, training=True)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            # Start each student step from the teacher latents at the closest teacher timestep.
            progress_id_teacher = torch.argmin((timesteps_teacher - timestep).abs())
            inputs_shared["latents"] = trajectory_teacher[progress_id_teacher]
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # Regression target: the finite-difference velocity between consecutive teacher latents,
            # (latents_next - latents) / (sigma_next - sigma).
            sigma = pipe.scheduler.sigmas[progress_id]
            sigma_ = 0 if progress_id + 1 >= len(pipe.scheduler.timesteps) else pipe.scheduler.sigmas[progress_id + 1]
            if progress_id + 1 >= len(pipe.scheduler.timesteps):
                latents_ = trajectory_teacher[-1]
            else:
                progress_id_teacher = torch.argmin((timesteps_teacher - pipe.scheduler.timesteps[progress_id + 1]).abs())
                latents_ = trajectory_teacher[progress_id_teacher]
            target = (latents_ - inputs_shared["latents"]) / (sigma_ - sigma)
            loss = loss + torch.nn.functional.mse_loss(noise_pred.float(), target.float()) * pipe.scheduler.training_weight(timestep)
        return loss

    def compute_regularization(self, pipe: BasePipeline, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, num_inference_steps, cfg_scale):
        # Run the student sampler from the teacher's starting latents, then compare the decoded
        # images against the teacher's final result with LPIPS.
        inputs_shared["latents"] = trajectory_teacher[0]
        pipe.scheduler.set_timesteps(num_inference_steps)
        models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
        for progress_id, timestep in enumerate(pipe.scheduler.timesteps):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device=pipe.device)
            noise_pred = pipe.cfg_guided_model_fn(
                pipe.model_fn, cfg_scale,
                inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            inputs_shared["latents"] = pipe.step(pipe.scheduler, progress_id=progress_id, noise_pred=noise_pred.detach(), **inputs_shared)
        image_pred = pipe.vae_decoder(inputs_shared["latents"])
        image_real = pipe.vae_decoder(trajectory_teacher[-1])
        loss = self.loss_fn(image_pred.float(), image_real.float())
        return loss

    def forward(self, pipe: BasePipeline, inputs_shared, inputs_posi, inputs_nega):
        if not self.initialized:
            self.initialize(pipe.device)
        # The teacher trajectory is generated without gradients; only the student losses below
        # are differentiated.
        with torch.no_grad():
            pipe.scheduler.set_timesteps(8)
            timesteps_teacher, trajectory_teacher = self.fetch_trajectory(inputs_shared["teacher"], pipe.scheduler.timesteps, inputs_shared, inputs_posi, inputs_nega, 50, 2)
            timesteps_teacher = timesteps_teacher.to(dtype=pipe.torch_dtype, device=pipe.device)
        loss_1 = self.align_trajectory(pipe, timesteps_teacher, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss_2 = self.compute_regularization(pipe, trajectory_teacher, inputs_shared, inputs_posi, inputs_nega, 8, 1)
        loss = loss_1 + loss_2
        return loss
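

# Hypothetical usage sketch (the pipeline construction and the contents of `inputs_shared`,
# `inputs_posi`, and `inputs_nega` are assumptions, not defined in this file). The module is
# stateful only for the lazily created LPIPS metric, so a single instance can be reused:
#
#     trajectory_loss = TrajectoryImitationLoss()
#     inputs_shared["teacher"] = teacher_pipe  # frozen teacher pipeline (assumed)
#     loss = trajectory_loss(student_pipe, inputs_shared, inputs_posi, inputs_nega)
#     loss.backward()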