File size: 3,966 Bytes
ffe929e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""MLX port of FlashFlowMatchEulerDiscreteScheduler from HiDream-O1.

Reference: HiDream-ai/HiDream-O1-Image @ models/flash_scheduler.py.
Trimmed to the path the Dev recipe actually uses:
  - num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False
  - timesteps overridden by DEFAULT_TIMESTEPS after construction
  - karras/exponential/beta sigmas not used
  - step() with s_churn/s_tmin/s_tmax stripped (always defaults)

The math is verbatim from upstream — only the framework swap (torch -> mlx).
"""
from __future__ import annotations

import mlx.core as mx
import numpy as np


# Fixed sampling schedule, copied verbatim from HiDream-O1 models/pipeline.py.
# 28 integer timesteps in strictly decreasing order (999 down to 8).
DEFAULT_TIMESTEPS = [
    999, 987, 974, 960, 945, 929, 913,
    895, 877, 857, 836, 814, 790, 764,
    737, 707, 675, 640, 602, 560, 515,
    464, 409, 347, 278, 199, 110, 8,
]


class FlashFlowMatchScheduler:
    """Euler scheduler for flow matching, with optional noise injection.

    Usage: construct, call :meth:`set_timesteps` once per sampling run, then
    call :meth:`step` once per entry of ``timesteps``, in order.

    The schedule is stored as float32 NumPy arrays (``timesteps_np`` and
    ``sigmas_np``); the ``timesteps`` / ``sigmas`` properties expose them as
    ``mx.array``. After :meth:`set_timesteps`, ``sigmas_np`` has one more
    entry than ``timesteps_np``: a trailing ``0.0`` used as the "next sigma"
    of the final step.
    """

    def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0):
        """Build the training-resolution schedule.

        Args:
            num_train_timesteps: Number of training timesteps; also the factor
                that converts a sigma into a timestep value.
            shift: Schedule shift; sigmas are remapped as
                ``shift * s / (1 + (shift - 1) * s)`` (identity for 1.0).
        """
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift

        sigmas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps, dtype=np.float32)
        sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
        self.sigmas_np = sigmas
        self.timesteps_np = sigmas * num_train_timesteps

        # Both stay None until set_timesteps() / the first step() call.
        self.num_inference_steps: int | None = None
        self._step_index: int | None = None

    def set_timesteps(self, num_inference_steps: int, custom_timesteps: list[int] | None = None):
        """Install the inference schedule and reset the step counter.

        Args:
            num_inference_steps: Number of steps to generate. Ignored when
                ``custom_timesteps`` is given.
            custom_timesteps: Explicit timestep values. They are used as-is:
                sigmas are ``t / num_train_timesteps`` and ``shift`` is NOT
                applied.

        After the call ``num_inference_steps`` equals the number of timesteps
        actually installed.
        """
        n_train = self.num_train_timesteps

        if custom_timesteps is not None:
            timesteps = np.asarray(custom_timesteps, dtype=np.float32)
            sigmas = (timesteps / n_train).astype(np.float32)
        else:
            grid = np.linspace(n_train, 1.0, num_inference_steps, dtype=np.float32)
            sigmas = (grid / n_train).astype(np.float32)
            sigmas = self.shift * sigmas / (1.0 + (self.shift - 1.0) * sigmas)
            sigmas = sigmas.astype(np.float32)
            # Derive the timesteps from the *shifted* sigmas (as __init__
            # does) so that the two arrays stay consistent when shift != 1.
            timesteps = (sigmas * n_train).astype(np.float32)

        # Trailing 0.0: the sigma the final step moves to.
        sigmas = np.append(sigmas, 0.0).astype(np.float32)

        self.num_inference_steps = len(timesteps)
        self.timesteps_np = timesteps
        self.sigmas_np = sigmas
        self._step_index = None

    @property
    def timesteps(self) -> mx.array:
        """Current timesteps as a freshly built ``mx.array``."""
        return mx.array(self.timesteps_np)

    @property
    def sigmas(self) -> mx.array:
        """Current sigmas as a freshly built ``mx.array``."""
        return mx.array(self.sigmas_np)

    def _init_step_index(self, timestep_value: float):
        """Set ``_step_index`` to the position of ``timestep_value``.

        Matching uses ``np.isclose`` with ``atol=1e-3``. If the value occurs
        more than once, the second occurrence is chosen (presumably mirroring
        the upstream scheduler's convention -- confirm against upstream).

        Raises:
            ValueError: If no timestep matches.
        """
        ts = self.timesteps_np
        matches = np.where(np.isclose(ts, timestep_value, atol=1e-3))[0]
        if len(matches) == 0:
            raise ValueError(f"timestep {timestep_value!r} not in scheduler.timesteps")
        self._step_index = int(matches[1] if len(matches) > 1 else matches[0])

    def step(self, model_output, timestep, sample,
             s_noise=1.0, noise_clip_std=0.0, seed=None):
        """Advance ``sample`` by one step of the schedule.

        Computes, in float32::

            denoised = sample - model_output * sigma
            result   = sigma_next * noise * s_noise + (1 - sigma_next) * denoised

        and advances the internal step counter by one.

        Args:
            model_output: Model prediction; must support ``.astype`` and
                ``.shape`` (presumably an ``mx.array``).
            timestep: Current timestep. Only consulted on the first call after
                :meth:`set_timesteps`, to locate the starting index; later
                calls simply advance the counter.
            sample: Current sample; the result is cast back to its dtype.
            s_noise: Scale applied to the injected noise.
            noise_clip_std: If > 0, the noise is clipped to
                ``+/- noise_clip_std * std(noise)``, where the std is measured
                on the drawn noise itself.
            seed: If given, noise is drawn from ``mx.random.key(seed + idx)``;
                otherwise the global MLX random state is used.

        Returns:
            The updated sample, with the dtype of ``sample``.

        Raises:
            RuntimeError: If :meth:`set_timesteps` has not been called.
            ValueError: If ``timestep`` is not in the schedule (first call).
            IndexError: If called more times than there are timesteps.
        """
        if self.num_inference_steps is None:
            raise RuntimeError("set_timesteps() must be called before step()")

        if self._step_index is None:
            self._init_step_index(float(timestep))
        idx = self._step_index

        # Checked before sigmas_np is indexed so that over-stepping reports a
        # clear error instead of an opaque out-of-bounds failure.
        if idx >= self.num_inference_steps:
            raise IndexError(
                f"step index {idx} is past the end of the schedule "
                f"({self.num_inference_steps} steps); call set_timesteps() to restart"
            )

        sigma = float(self.sigmas_np[idx])
        sigma_next = float(self.sigmas_np[idx + 1])

        sample_f = sample.astype(mx.float32)
        model_output_f = model_output.astype(mx.float32)

        denoised = sample_f - model_output_f * sigma

        if seed is not None:
            # A per-step key keeps a fixed seed reproducible without reusing
            # the same noise at every step.
            key = mx.random.key(seed + idx)
            noise = mx.random.normal(model_output_f.shape, key=key)
        else:
            noise = mx.random.normal(model_output_f.shape)

        if noise_clip_std > 0:
            std = float(mx.std(noise))
            clip = noise_clip_std * std
            noise = mx.clip(noise, -clip, clip)

        # On the final step sigma_next is 0.0, so this reduces to `denoised`.
        new_sample = sigma_next * noise * s_noise + (1.0 - sigma_next) * denoised

        self._step_index += 1
        return new_sample.astype(sample.dtype)