Instructions to use mlx-community/HiDream-O1-Image-Dev-mlx-bf16 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use mlx-community/HiDream-O1-Image-Dev-mlx-bf16 with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir HiDream-O1-Image-Dev-mlx-bf16 mlx-community/HiDream-O1-Image-Dev-mlx-bf16
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
File size: 3,238 Bytes
ffe929e
"""Quick A/B: forward_generation with vs without mx.compile.
Times 5 forward passes after warm-up. Same shapes as a 1024x1024 inference.
"""
from __future__ import annotations
import sys, time
from pathlib import Path
HERE = Path(__file__).parent
sys.path.insert(0, str(HERE))
import numpy as np
import mlx.core as mx
from mlx_vlm import load as mlx_vlm_load
from pipeline_helpers import build_t2i_text_sample, build_attention_mask, PATCH_SIZE
from hidream_model import HiDreamConfig, build_model, forward_generation
# LAB is the directory two levels above the folder that contains this script;
# the converted model folder is looked up under LAB / "mlx_models".
LAB = Path(__file__).resolve().parents[2]
MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q8"

# --- Load backbone + processor ---
# perf_counter() is the monotonic, high-resolution clock intended for
# measuring intervals; time.time() is wall-clock and may jump if the system
# clock is adjusted while the model is loading.
print("loading model...")
t0 = time.perf_counter()
backbone, processor = mlx_vlm_load(str(MODEL_PATH))
print(f" {time.perf_counter()-t0:.1f}s")

# --- Wrap the backbone and attach the extra weights ---
cfg = HiDreamConfig()
model = build_model(cfg, backbone)
custom = mx.load(str(MODEL_PATH / "extras" / "custom_heads.safetensors"))
# NOTE(review): strict=False presumably because this file only carries the
# custom heads, not every parameter of `model` -- confirm against build_model.
model.load_weights(list(custom.items()), strict=False)
# Evaluate the parameters up front so that materialising the weights is not
# charged to the first forward pass of the benchmark below.
mx.eval(model.parameters())
print("model ready")
# Benchmark inputs sized for a 1024x1024 generation.
WIDTH, HEIGHT = 1024, 1024
# One token per PATCH_SIZE x PATCH_SIZE tile (1024 when PATCH_SIZE == 32).
N_PATCH = (WIDTH // PATCH_SIZE) * (HEIGHT // PATCH_SIZE)

# Some processors wrap the tokenizer, others are the tokenizer themselves.
tokenizer = getattr(processor, "tokenizer", processor)

# Make sure the special-token attributes exist on the tokenizer, filling in a
# "<|xxx_token|>" placeholder for any that are missing.
for tag in ("boi", "bor", "eor", "bot", "tms"):
    attr = f"{tag}_token"
    if not hasattr(tokenizer, attr):
        setattr(tokenizer, attr, f"<|{tag}_token|>")

sample = build_t2i_text_sample(
    "a small red mushroom on a bed of moss",
    HEIGHT,
    WIDTH,
    tokenizer,
    processor,
    backbone.config,
)

# Convert the three per-token sequences from the sample into MLX arrays.
input_ids, position_ids, token_types = (
    mx.array(sample[key])
    for key in ("input_ids", "position_ids", "token_types")
)

# Attention mask built from the token types with -1e4 as the fill value,
# then cast to bfloat16.
mask4d = mx.array(build_attention_mask(sample["token_types"], -1e4))
mask4d = mask4d.astype(mx.bfloat16)

# Random stand-in for the patchified image: 3 * PATCH_SIZE**2 values per patch.
patch_dim = 3 * PATCH_SIZE * PATCH_SIZE
vinputs = mx.random.normal((1, N_PATCH, patch_dim)).astype(mx.bfloat16)
timestep = mx.array([0.5], dtype=mx.float32)

print(
    f"shapes: input_ids={input_ids.shape} pos={position_ids.shape} "
    f"vinputs={vinputs.shape} mask={mask4d.shape}"
)
# --- Uncompiled baseline ---
print("\n=== baseline (uncompiled) ===")

# Warm-up: two untimed passes so one-off start-up costs do not pollute the
# measurement. mx.eval() sits inside the loop because MLX is lazy -- without
# it the forward pass would only be recorded, not executed.
for _ in range(2):
    out = forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d)
    mx.eval(out)

# Timed section. N and elapsed are reused by the compiled benchmark below to
# report the speedup, so their names must stay as they are.
# perf_counter() is monotonic and high-resolution; time.time() is wall-clock
# and can be skewed by system clock adjustments during the run.
N = 5
t0 = time.perf_counter()
for _ in range(N):
    out = forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d)
    mx.eval(out)
elapsed = time.perf_counter() - t0
print(f" baseline: {elapsed/N:.3f}s/step over {N} steps")
# --- Compiled ---
print("\n=== mx.compile ===")


def fwd(input_ids, position_ids, vinputs, timestep, token_types, mask4d):
    """Run forward_generation on the module-level `model` and `cfg`.

    Only the per-call inputs are parameters; `model` and `cfg` are picked up
    from module scope, giving a plain function of its inputs that can be
    handed to mx.compile.
    """
    return forward_generation(
        model,
        cfg,
        input_ids,
        position_ids,
        vinputs,
        timestep,
        token_types,
        mask4d,
    )
# Compiled run. The whole section is wrapped in a best-effort try/except so
# that a compile or trace failure is reported instead of aborting the script
# after the baseline numbers have already been printed.
try:
    fwd_c = mx.compile(fwd)
    # Warm-up (first call compiles), evaluated eagerly so compilation cost
    # stays out of the timed loop.
    for _ in range(2):
        out = fwd_c(input_ids, position_ids, vinputs, timestep, token_types, mask4d)
        mx.eval(out)
    # perf_counter() is monotonic and high-resolution, unlike the wall-clock
    # time.time(), so the interval cannot be skewed by clock adjustments.
    t0 = time.perf_counter()
    for _ in range(N):
        out = fwd_c(input_ids, position_ids, vinputs, timestep, token_types, mask4d)
        mx.eval(out)
    elapsed_c = time.perf_counter() - t0
    # `elapsed` is the baseline duration measured over the same N steps.
    print(f" compiled: {elapsed_c/N:.3f}s/step over {N} steps (speedup {elapsed/elapsed_c:.2f}x)")
except Exception as e:
    print(f" mx.compile failed: {type(e).__name__}: {e}")
|