Instructions to use mlx-community/HiDream-O1-Image-Dev-mlx-bf16 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use mlx-community/HiDream-O1-Image-Dev-mlx-bf16 with MLX:
# Download the model from the Hub pip install huggingface_hub[hf_xet] huggingface-cli download --local-dir HiDream-O1-Image-Dev-mlx-bf16 mlx-community/HiDream-O1-Image-Dev-mlx-bf16
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
| """Standalone MLX model wrapper for HiDream-O1-Image.""" | |
| from __future__ import annotations | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| import numpy as np | |
| class TimestepEmbedder(nn.Module): | |
| def __init__(self, hidden_size: int, frequency_embedding_size: int = 256): | |
| super().__init__() | |
| self.frequency_embedding_size = frequency_embedding_size | |
| self.fc1 = nn.Linear(frequency_embedding_size, hidden_size, bias=True) | |
| self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True) | |
| def timestep_embedding(t: mx.array, dim: int, max_period: float = 10000.0) -> mx.array: | |
| half = dim // 2 | |
| freqs = mx.exp(-math.log(max_period) * mx.arange(0, half, dtype=mx.float32) / half) | |
| args = t[:, None].astype(mx.float32) * freqs[None] | |
| emb = mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1) | |
| if dim % 2: | |
| emb = mx.concatenate([emb, mx.zeros_like(emb[:, :1])], axis=-1) | |
| return emb | |
| def __call__(self, t: mx.array) -> mx.array: | |
| t_freq = self.timestep_embedding(t * 1000.0, self.frequency_embedding_size) | |
| return self.fc2(nn.silu(self.fc1(t_freq.astype(self.fc1.weight.dtype)))) | |
| class BottleneckPatchEmbed(nn.Module): | |
| def __init__(self, patch_size: int = 32, in_chans: int = 3, | |
| pca_dim: int = 1024, embed_dim: int = 4096): | |
| super().__init__() | |
| self.proj1 = nn.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False) | |
| self.proj2 = nn.Linear(pca_dim, embed_dim, bias=True) | |
| def __call__(self, x: mx.array) -> mx.array: | |
| return self.proj2(self.proj1(x)) | |
| class FinalLayer(nn.Module): | |
| def __init__(self, hidden_size: int, patch_size: int = 32, out_channels: int = 3): | |
| super().__init__() | |
| self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) | |
| def __call__(self, x: mx.array) -> mx.array: | |
| return self.linear(x) | |
| CUSTOM_HEAD_KEY_MAP = { | |
| "model.t_embedder1.mlp.0.weight": "t_embedder1.fc1.weight", | |
| "model.t_embedder1.mlp.0.bias": "t_embedder1.fc1.bias", | |
| "model.t_embedder1.mlp.2.weight": "t_embedder1.fc2.weight", | |
| "model.t_embedder1.mlp.2.bias": "t_embedder1.fc2.bias", | |
| "model.x_embedder.proj1.weight": "x_embedder.proj1.weight", | |
| "model.x_embedder.proj2.weight": "x_embedder.proj2.weight", | |
| "model.x_embedder.proj2.bias": "x_embedder.proj2.bias", | |
| "model.final_layer2.linear.weight": "final_layer2.linear.weight", | |
| "model.final_layer2.linear.bias": "final_layer2.linear.bias", | |
| } | |
| class HiDreamConfig: | |
| hidden_size: int = 4096 | |
| patch_size: int = 32 | |
| in_channels: int = 3 | |
| bottleneck_dim: int = 1024 | |
| tms_token_id: int = 151673 | |
| image_token_id: int = 151655 | |
| video_token_id: int = 151656 | |
| vision_start_token_id: int = 151652 | |
| def build_model(cfg: HiDreamConfig, mlx_vlm_qwen3_vl_model): | |
| class HiDream(nn.Module): | |
| def __init__(self): | |
| super().__init__() | |
| self.visual = mlx_vlm_qwen3_vl_model.vision_tower | |
| self.language_model = mlx_vlm_qwen3_vl_model.language_model | |
| self.t_embedder1 = TimestepEmbedder(cfg.hidden_size) | |
| self.x_embedder = BottleneckPatchEmbed( | |
| patch_size=cfg.patch_size, in_chans=cfg.in_channels, | |
| pca_dim=cfg.bottleneck_dim, embed_dim=cfg.hidden_size, | |
| ) | |
| self.final_layer2 = FinalLayer( | |
| hidden_size=cfg.hidden_size, | |
| patch_size=cfg.patch_size, | |
| out_channels=cfg.in_channels, | |
| ) | |
| return HiDream() | |
| def precompute_text_embeds_with_vision(model, cfg, input_ids, pixel_values=None, image_grid_thw=None): | |
| """Compute text embeddings + (in edit mode) inject vision features at image_token | |
| positions. Returns embeds [B, S_text, hidden]. Call once before the denoising | |
| loop — output is constant across timesteps. | |
| """ | |
| embed_tokens = model.language_model.model.embed_tokens | |
| inputs_embeds = embed_tokens(input_ids) | |
| if pixel_values is None or image_grid_thw is None: | |
| return inputs_embeds | |
| vt_out = model.visual(pixel_values, image_grid_thw) | |
| image_features = vt_out[0] if isinstance(vt_out, tuple) else vt_out | |
| if isinstance(image_features, (list, tuple)): | |
| image_features = mx.concatenate(image_features, axis=0) | |
| # Build a [B, S, H] tensor that has image_features at image_token positions | |
| # and inputs_embeds everywhere else, via mx.where on a broadcast mask. | |
| ids_np = np.asarray(input_ids) | |
| img_positions = np.where(ids_np[0] == cfg.image_token_id)[0] | |
| if img_positions.shape[0] != image_features.shape[0]: | |
| raise RuntimeError( | |
| f"image_features {image_features.shape[0]} != " | |
| f"image_token_id positions {img_positions.shape[0]} (input_ids was: {ids_np.shape})" | |
| ) | |
| B, S, H = inputs_embeds.shape | |
| # Build aligned-to-S features: zero everywhere except at image positions. | |
| aligned = np.zeros((B, S, H), dtype=np.float32) | |
| aligned[0, img_positions] = np.asarray(image_features.astype(mx.float32)) | |
| aligned_mx = mx.array(aligned).astype(inputs_embeds.dtype) | |
| # Mask: 1 at image positions, 0 elsewhere | |
| mask_2d = (ids_np == cfg.image_token_id).astype(np.bool_) | |
| mask_3d = np.broadcast_to(mask_2d[..., None], (B, S, H)) | |
| mask_mx = mx.array(mask_3d.copy()) | |
| return mx.where(mask_mx, aligned_mx, inputs_embeds) | |
| def forward_generation(model, cfg, inputs_embeds_with_vision, position_ids, vinputs, timestep, | |
| input_ids, token_types, attention_mask_4d): | |
| """Per-step forward. Takes the precomputed text+vision inputs_embeds, the | |
| fresh-noise vinputs, and the timestep. Returns x_pred [B, S_total, patch_dim]. | |
| Signature change vs the T2I-only version: pixel_values/image_grid_thw moved | |
| out (call precompute_text_embeds_with_vision once before the loop). input_ids | |
| is still needed inside because we look up tms_token positions for t_emb scatter. | |
| """ | |
| inputs_embeds = inputs_embeds_with_vision | |
| t_emb = model.t_embedder1(timestep) | |
| tms_mask = (input_ids == cfg.tms_token_id) | |
| tms_mask_3d = mx.broadcast_to(tms_mask[..., None], inputs_embeds.shape) | |
| t_emb_expanded = mx.broadcast_to(t_emb[:, None, :], inputs_embeds.shape) | |
| inputs_embeds = mx.where(tms_mask_3d, t_emb_expanded, inputs_embeds) | |
| vinputs_embedded = model.x_embedder(vinputs).astype(inputs_embeds.dtype) | |
| inputs_embeds = mx.concatenate([inputs_embeds, vinputs_embedded], axis=1) | |
| text_model = model.language_model.model | |
| # mlx-vlm Qwen3VLModel.__call__ accepts (inputs, inputs_embeds, mask, cache, position_ids, ...). | |
| # Pass our 4D additive mask directly; it bypasses the internal causal mask. | |
| # `inputs` is required positionally but ignored when inputs_embeds is set | |
| # in mlx-vlm's implementation — pass a placeholder of correct shape. | |
| placeholder = mx.zeros(inputs_embeds.shape[:2], dtype=mx.int32) | |
| h = text_model( | |
| placeholder, | |
| inputs_embeds=inputs_embeds, | |
| mask=attention_mask_4d, | |
| cache=None, | |
| position_ids=position_ids, | |
| ) | |
| # Apply final norm. mlx-vlm's Qwen3VLModel applies it internally and returns hidden_states. | |
| x_pred = model.final_layer2(h) | |
| return x_pred | |