Instructions to use mlx-community/HiDream-O1-Image-Dev-mlx-bf16 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use mlx-community/HiDream-O1-Image-Dev-mlx-bf16 with MLX:
```bash
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir HiDream-O1-Image-Dev-mlx-bf16 mlx-community/HiDream-O1-Image-Dev-mlx-bf16
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
| """Quick A/B: forward_generation with vs without mx.compile. | |
| Times 5 forward passes after warm-up. Same shapes as a 1024x1024 inference. | |
| """ | |
| from __future__ import annotations | |
| import sys, time | |
| from pathlib import Path | |
| HERE = Path(__file__).parent | |
| sys.path.insert(0, str(HERE)) | |
| import numpy as np | |
| import mlx.core as mx | |
| from mlx_vlm import load as mlx_vlm_load | |
| from pipeline_helpers import build_t2i_text_sample, build_attention_mask, PATCH_SIZE | |
| from hidream_model import HiDreamConfig, build_model, forward_generation | |
| LAB = Path(__file__).resolve().parents[2] | |
| MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q8" | |
# Load the backbone and its processor, and report how long the load took.
print("loading model...")
t0 = time.time()
backbone, processor = mlx_vlm_load(str(MODEL_PATH))
load_seconds = time.time() - t0
print(f" {load_seconds:.1f}s")

# Wrap the backbone in the HiDream model using the default configuration.
cfg = HiDreamConfig()
model = build_model(cfg, backbone)

# Load the extra weights stored next to the backbone. strict=False is passed
# to load_weights -- presumably because this file only carries the custom
# heads rather than every parameter of the model; confirm against the export
# script.
custom_heads_file = MODEL_PATH / "extras" / "custom_heads.safetensors"
custom = mx.load(str(custom_heads_file))
model.load_weights(list(custom.items()), strict=False)

# Evaluate the parameters now so the work is done before any timing starts.
mx.eval(model.parameters())
print("model ready")
# Build inputs at 1024x1024
WIDTH, HEIGHT = 1024, 1024
patches_across = WIDTH // PATCH_SIZE
patches_down = HEIGHT // PATCH_SIZE
N_PATCH = patches_across * patches_down  # 1024 when PATCH_SIZE is 32

# Some processors expose the tokenizer as an attribute; otherwise the
# processor object itself is used as the tokenizer.
tokenizer = getattr(processor, "tokenizer", processor)

# Give the tokenizer a placeholder string for each special-token attribute it
# does not already define (presumably read by build_t2i_text_sample -- verify
# against pipeline_helpers).
for prefix in ("boi", "bor", "eor", "bot", "tms"):
    attr = f"{prefix}_token"
    if hasattr(tokenizer, attr):
        continue
    setattr(tokenizer, attr, f"<|{prefix}_token|>")

sample = build_t2i_text_sample(
    "a small red mushroom on a bed of moss",
    HEIGHT,
    WIDTH,
    tokenizer,
    processor,
    backbone.config,
)

# Convert the sample's fields to MLX arrays.
input_ids = mx.array(sample["input_ids"])
position_ids = mx.array(sample["position_ids"])
token_types = mx.array(sample["token_types"])

# Attention mask built from the token types, cast to bfloat16. The -1e4
# argument is presumably the additive value for masked positions -- confirm
# against build_attention_mask.
raw_mask = build_attention_mask(sample["token_types"], -1e4)
mask4d = mx.array(raw_mask).astype(mx.bfloat16)

# Random stand-in for the image patches: one batch entry, N_PATCH patches,
# 3 * PATCH_SIZE * PATCH_SIZE values per patch.
patch_dim = 3 * PATCH_SIZE * PATCH_SIZE
vinputs = mx.random.normal((1, N_PATCH, patch_dim)).astype(mx.bfloat16)
timestep = mx.array([0.5], dtype=mx.float32)

print(f"shapes: input_ids={input_ids.shape} pos={position_ids.shape} "
      f"vinputs={vinputs.shape} mask={mask4d.shape}")
# --- Uncompiled baseline ---
print("\n=== baseline (uncompiled) ===")

# Warm-up: two untimed passes so that one-time setup cost does not end up in
# the measurement below.
for _ in range(2):
    out = forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d)
    mx.eval(out)

# Timed passes. mx.eval() is called inside the loop so each step is fully
# computed before the next one starts and before the clock is read.
# time.perf_counter() is used instead of time.time(): it is monotonic and has
# the highest available resolution, so the interval cannot be distorted by
# system clock adjustments.
N = 5
t0 = time.perf_counter()
for _ in range(N):
    out = forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d)
    mx.eval(out)
elapsed = time.perf_counter() - t0
print(f" baseline: {elapsed/N:.3f}s/step over {N} steps")
# --- Compiled ---
print("\n=== mx.compile ===")


def fwd(input_ids, position_ids, vinputs, timestep, token_types, mask4d):
    """Call forward_generation with the module-level ``model`` and ``cfg``.

    Only the array inputs are exposed as parameters, so this wrapper is the
    function handed to ``mx.compile``; ``model`` and ``cfg`` are picked up
    from the enclosing module scope.

    Returns whatever ``forward_generation`` returns for these inputs.
    """
    return forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d)


# The whole compiled run is wrapped in one try block: the warm-up calls are
# where the function is first traced, so a failure to compile surfaces there
# rather than at mx.compile() itself. Any failure is reported and the script
# ends normally, keeping the baseline result printed above.
try:
    fwd_c = mx.compile(fwd)

    # warmup (first call compiles)
    for _ in range(2):
        out = fwd_c(input_ids, position_ids, vinputs, timestep, token_types, mask4d)
        mx.eval(out)

    # Timed passes, measured with the monotonic high-resolution clock so the
    # interval is not affected by system clock adjustments.
    t0 = time.perf_counter()
    for _ in range(N):
        out = fwd_c(input_ids, position_ids, vinputs, timestep, token_types, mask4d)
        mx.eval(out)
    elapsed_c = time.perf_counter() - t0
    print(f" compiled: {elapsed_c/N:.3f}s/step over {N} steps (speedup {elapsed/elapsed_c:.2f}x)")
except Exception as e:
    print(f" mx.compile failed: {type(e).__name__}: {e}")