Instructions to use mlx-community/HiDream-O1-Image-Dev-mlx-bf16 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use mlx-community/HiDream-O1-Image-Dev-mlx-bf16 with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir HiDream-O1-Image-Dev-mlx-bf16 mlx-community/HiDream-O1-Image-Dev-mlx-bf16
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
| """MLX port of FlashFlowMatchEulerDiscreteScheduler from HiDream-O1. | |
| Reference: HiDream-ai/HiDream-O1-Image @ models/flash_scheduler.py. | |
| Trimmed to the path the Dev recipe actually uses: | |
| - num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False | |
| - timesteps overridden by DEFAULT_TIMESTEPS after construction | |
| - karras/exponential/beta sigmas not used | |
| - step() with s_churn/s_tmin/s_tmax stripped (always defaults) | |
| The math is verbatim from upstream — only the framework swap (torch -> mlx). | |
| """ | |
| from __future__ import annotations | |
| import mlx.core as mx | |
| import numpy as np | |
# Fixed descending timestep schedule, copied verbatim from HiDream-O1
# models/pipeline.py. Do not edit the values by hand.
DEFAULT_TIMESTEPS = [
    999, 987, 974, 960, 945, 929, 913,
    895, 877, 857, 836, 814, 790, 764,
    737, 707, 675, 640, 602, 560, 515,
    464, 409, 347, 278, 199, 110, 8,
]
class FlashFlowMatchScheduler:
    """Euler scheduler for flow matching, with optional noise injection.

    Usage: construct, call :meth:`set_timesteps` once, then call :meth:`step`
    once per scheduled timestep, in order.

    Attributes:
        num_train_timesteps: Number of training timesteps; timesteps and
            sigmas are related by ``sigma = timestep / num_train_timesteps``.
        shift: Shift factor applied to sigmas as
            ``shift * s / (1 + (shift - 1) * s)``.
        sigmas_np: float32 NumPy array of sigmas. After :meth:`set_timesteps`
            it holds one more entry than ``timesteps_np`` (a trailing ``0.0``).
        timesteps_np: float32 NumPy array of timesteps.
        num_inference_steps: Length of the active schedule, or ``None`` until
            :meth:`set_timesteps` has been called.
    """

    def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift

        # Full training table: sigmas run from 1.0 down to 1/num_train_timesteps.
        sigmas = np.linspace(
            1.0, 1.0 / num_train_timesteps, num_train_timesteps, dtype=np.float32
        )
        sigmas = self._shift_sigmas(sigmas)
        self.sigmas_np = sigmas
        self.timesteps_np = sigmas * num_train_timesteps

        self.num_inference_steps: int | None = None
        self._step_index: int | None = None

    def _shift_sigmas(self, sigmas):
        """Return ``sigmas`` with the configured shift applied."""
        return self.shift * sigmas / (1.0 + (self.shift - 1.0) * sigmas)

    def set_timesteps(self, num_inference_steps: int, custom_timesteps: list[int] | None = None):
        """Install the inference schedule and reset the step counter.

        Args:
            num_inference_steps: Number of evenly spaced steps to generate.
                Ignored when ``custom_timesteps`` is given.
            custom_timesteps: Explicit timesteps to use as-is; their sigmas are
                ``timestep / num_train_timesteps`` with no shift applied.

        Raises:
            ValueError: If the resulting schedule would be empty.
        """
        if custom_timesteps is not None:
            timesteps = np.asarray(custom_timesteps, dtype=np.float32)
            if timesteps.size == 0:
                raise ValueError("custom_timesteps must not be empty")
            sigmas = (timesteps / self.num_train_timesteps).astype(np.float32)
        else:
            if num_inference_steps < 1:
                raise ValueError(
                    f"num_inference_steps must be >= 1, got {num_inference_steps!r}"
                )
            timesteps = np.linspace(
                self.num_train_timesteps, 1.0, num_inference_steps, dtype=np.float32
            )
            sigmas = (timesteps / self.num_train_timesteps).astype(np.float32)
            sigmas = self._shift_sigmas(sigmas)
            if self.shift != 1.0:
                # Keep timesteps consistent with the shifted sigmas, as
                # __init__ does. Skipped for shift == 1.0 so that path keeps
                # the exact linspace values.
                timesteps = (sigmas * self.num_train_timesteps).astype(np.float32)

        self.num_inference_steps = len(timesteps)
        self.timesteps_np = timesteps
        # Trailing 0.0 is the "sigma_next" of the final step.
        self.sigmas_np = np.append(sigmas, 0.0).astype(np.float32)
        self._step_index = None

    def timesteps(self) -> mx.array:
        """Return the current timesteps as an MLX array."""
        return mx.array(self.timesteps_np)

    def sigmas(self) -> mx.array:
        """Return the current sigmas as an MLX array."""
        return mx.array(self.sigmas_np)

    def _init_step_index(self, timestep_value: float):
        """Set the step counter to the schedule position of ``timestep_value``.

        Raises:
            ValueError: If no scheduled timestep is close to ``timestep_value``.
        """
        matches = np.where(np.isclose(self.timesteps_np, timestep_value, atol=1e-3))[0]
        if len(matches) == 0:
            raise ValueError(f"timestep {timestep_value!r} not in scheduler.timesteps")
        # When the value occurs more than once, the second occurrence is used.
        # NOTE(review): presumably mirrors the upstream scheduler - confirm.
        self._step_index = int(matches[1] if len(matches) > 1 else matches[0])

    def step(self, model_output, timestep, sample,
             s_noise=1.0, noise_clip_std=0.0, seed=None):
        """Advance ``sample`` by one Euler step and return the new sample.

        Args:
            model_output: Model prediction for the current step.
            timestep: Current timestep; only used on the first call after
                :meth:`set_timesteps`, to locate the starting position.
            sample: Current sample.
            s_noise: Multiplier applied to the injected noise.
            noise_clip_std: If > 0, noise is clipped to
                ``+/- noise_clip_std * std(noise)``.
            seed: If given, this step's noise is drawn from the key
                ``seed + step_index``; otherwise the global MLX RNG is used.

        Returns:
            The updated sample, cast back to ``sample.dtype``.

        Raises:
            RuntimeError: If :meth:`set_timesteps` has not been called.
            IndexError: If called more times than the schedule has steps.
            ValueError: If ``timestep`` is not in the schedule (first call only).
        """
        if self.num_inference_steps is None:
            raise RuntimeError("set_timesteps() must be called before step()")
        if self._step_index is None:
            self._init_step_index(float(timestep))

        idx = self._step_index
        if idx >= self.num_inference_steps:
            raise IndexError(
                f"step index {idx} is past the end of the "
                f"{self.num_inference_steps}-step schedule; call set_timesteps() to restart"
            )

        sigma = float(self.sigmas_np[idx])
        sigma_next = float(self.sigmas_np[idx + 1])

        # Do the arithmetic in float32 regardless of the input dtype.
        sample_f = sample.astype(mx.float32)
        model_output_f = model_output.astype(mx.float32)
        denoised = sample_f - model_output_f * sigma

        if sigma_next != 0.0:
            if seed is not None:
                key = mx.random.key(seed + idx)
                noise = mx.random.normal(model_output_f.shape, key=key)
            else:
                noise = mx.random.normal(model_output_f.shape)
            if noise_clip_std > 0:
                std = float(mx.std(noise))
                clip = noise_clip_std * std
                noise = mx.clip(noise, -clip, clip)
            new_sample = sigma_next * noise * s_noise + (1.0 - sigma_next) * denoised
        else:
            # sigma_next == 0 (always true on the final step): the noise term
            # vanishes, so the result is the denoised estimate and no noise
            # needs to be sampled.
            new_sample = denoised

        self._step_index += 1
        return new_sample.astype(sample.dtype)