| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput |
| from typing import Union, Optional, Tuple |
| import torch |
|
|
class AdditFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
    """Flow-matching Euler scheduler whose `step` also computes a clean-sample estimate.

    `step` returns the usual next sample and, when called with `return_dict=False`,
    additionally returns `x_0 = sample - sigma * model_output`.
    """

    @staticmethod
    def _is_integer_timestep(timestep) -> bool:
        """Return True if `timestep` looks like an integer index rather than a schedule value.

        Python ints (including bools) and integer/bool tensors on any device are treated
        as indices; floating-point numbers and floating-point tensors are accepted.
        """
        if isinstance(timestep, int):
            return True
        if isinstance(timestep, torch.Tensor):
            # Checking the dtype instead of the legacy `torch.IntTensor`/`torch.LongTensor`
            # types also catches CUDA tensors and other integer widths.
            return not (timestep.is_floating_point() or timestep.is_complex())
        return False

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
        step_index: Optional[int] = None,
    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
        """
        Advance `sample` by one Euler step along the sigma schedule and estimate the clean sample.

        The update is ``prev_sample = sample + (sigma_next - sigma) * model_output`` with
        ``sigma = self.sigmas[i]`` and ``sigma_next = self.sigmas[i + 1]`` for the current step
        index ``i``. The clean-sample estimate is ``x_0 = sample - sigma * model_output``.
        NOTE(review): this estimate assumes `model_output` is a flow-matching velocity
        prediction; confirm against the model in use.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output of the learned model for this step.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep. It must be one of the values in `scheduler.timesteps`,
                not an integer index. It is only used to initialise the step index when none
                is set yet.
            sample (`torch.FloatTensor`):
                The current sample in the diffusion process.
            s_churn, s_tmin, s_tmax, s_noise (`float`):
                Accepted but unused by this implementation (presumably kept for signature
                compatibility with other Euler schedulers).
            generator (`torch.Generator`, *optional*):
                Accepted but unused; this step draws no random numbers.
            return_dict (`bool`, defaults to `True`):
                Whether to return a `FlowMatchEulerDiscreteSchedulerOutput` or a tuple.
            step_index (`int`, *optional*):
                If given, overrides the scheduler's internal step index before stepping.

        Returns:
            `FlowMatchEulerDiscreteSchedulerOutput` or `tuple`:
                If `return_dict` is `True`, a `FlowMatchEulerDiscreteSchedulerOutput` holding only
                `prev_sample`. Otherwise the tuple `(prev_sample, x_0)`. Both tensors are cast to
                `model_output.dtype`.

        Raises:
            ValueError: If `timestep` is a Python int or an integer/bool tensor.
        """
        if self._is_integer_timestep(timestep):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    f" `{type(self).__name__}.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        # An explicit step index takes precedence over the scheduler's own counter.
        if step_index is not None:
            self._step_index = step_index

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast so the update is computed in float32 regardless of the model dtype.
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]
        sigma_next = self.sigmas[self.step_index + 1]

        # Euler step from sigma to sigma_next.
        prev_sample = sample + (sigma_next - sigma) * model_output

        # Clean-sample estimate at the current sigma (see NOTE in the docstring).
        x_0 = sample - sigma * model_output

        # Cast results back to the model's dtype.
        prev_sample = prev_sample.to(model_output.dtype)
        x_0 = x_0.to(model_output.dtype)

        # Move the internal counter forward for the next call.
        self._step_index += 1

        if not return_dict:
            return (prev_sample, x_0)

        # The output dataclass only carries `prev_sample`; `x_0` is available via the tuple form.
        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)