Instructions to use mlx-community/HiDream-O1-Image-Dev-mlx-bf16 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use mlx-community/HiDream-O1-Image-Dev-mlx-bf16 with MLX:
```bash
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir HiDream-O1-Image-Dev-mlx-bf16 mlx-community/HiDream-O1-Image-Dev-mlx-bf16
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
File size: 4,251 Bytes
ffe929e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | """Verify precompute_text_embeds_with_vision actually scatters image features
into the right positions of inputs_embeds, without mangling text positions.
"""
from __future__ import annotations
import sys
from pathlib import Path
HERE = Path(__file__).parent
sys.path.insert(0, str(HERE))
import numpy as np
import mlx.core as mx
from mlx_vlm import load as mlx_vlm_load
from pipeline_helpers import build_edit_text_sample
from hidream_model import HiDreamConfig, build_model, precompute_text_embeds_with_vision
# Filesystem layout: LAB is the grandparent of the directory holding this script.
LAB = Path(__file__).resolve().parents[2]
MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q6"
REF = "sample_outputs/v3_1024_cat_q8.png"  # reference image, relative to LAB

print("loading model...")
backbone, processor = mlx_vlm_load(str(MODEL_PATH))
cfg = HiDreamConfig()
model = build_model(cfg, backbone)

# Overlay the extra weights stored next to the backbone checkpoint.
# strict=False: presumably the file holds only the custom heads, not every
# parameter of the model -- TODO confirm.
custom_heads_file = MODEL_PATH / "extras" / "custom_heads.safetensors"
custom = mx.load(str(custom_heads_file))
model.load_weights(list(custom.items()), strict=False)
mx.eval(model.parameters())

# Use the processor's own tokenizer when it exposes one; otherwise treat the
# processor itself as the tokenizer.
tokenizer = getattr(processor, "tokenizer", processor)
# Make sure the tokenizer exposes each "<name>_token" attribute, filling in a
# "<|name_token|>" placeholder string for any that are missing.
for tag in ("boi", "bor", "eor", "bot", "tms"):
    attr_name = f"{tag}_token"
    if hasattr(tokenizer, attr_name):
        continue
    setattr(tokenizer, attr_name, f"<|{tag}_token|>")
# Build one edit-style sample: the prompt "a cat" plus the reference image,
# with 1024 passed for both size arguments.
ref_image_path = str(LAB / REF)
sample = build_edit_text_sample(
    "a cat",
    [ref_image_path],
    1024,
    1024,
    tokenizer,
    processor,
    backbone.config,
)

# Convert the pieces of the sample used below into MLX arrays; pixel values
# are cast to bfloat16.
input_ids = mx.array(sample["input_ids"])
pixel_values = mx.array(sample["pixel_values"]).astype(mx.bfloat16)
image_grid_thw = mx.array(sample["image_grid_thw"])
# 1) Baseline: plain embed_tokens lookup on input_ids, with no image features
#    scattered in.
embed_tokens = model.language_model.model.embed_tokens
text_only_embeds = embed_tokens(input_ids)
mx.eval(text_only_embeds)
print(
    f"\ntext-only embeds shape: {text_only_embeds.shape} dtype: {text_only_embeds.dtype}"
)
# 2) Vision tower output. The tower may return either the features directly
#    or a tuple whose first element is the features.
vt_out = model.visual(pixel_values, image_grid_thw)
if isinstance(vt_out, tuple):
    img_features = vt_out[0]
else:
    img_features = vt_out
mx.eval(img_features)
print(f"image_features shape: {img_features.shape} dtype: {img_features.dtype}")

# Summary statistics are computed on a float32 copy of the features.
img_features_f32 = img_features.astype(mx.float32)
feat_mean = float(mx.mean(img_features_f32))
feat_std = float(mx.std(img_features_f32))
feat_min = float(mx.min(img_features_f32))
feat_max = float(mx.max(img_features_f32))
print(f" stats: mean={feat_mean:.4f} std={feat_std:.4f} min={feat_min:.3f} max={feat_max:.3f}")
# 3) The function under test: build the combined (text + vision) embeddings.
combined = precompute_text_embeds_with_vision(
    model,
    cfg,
    input_ids,
    pixel_values,
    image_grid_thw,
)
mx.eval(combined)
print(f"\ncombined embeds shape: {combined.shape} dtype: {combined.dtype}")
# 4) Inspect: at image positions, combined should equal image_features.
#    Positions are taken from the first row of input_ids.
ids_np = np.asarray(input_ids[0])
is_image_token = ids_np == cfg.image_token_id
img_pos = np.where(is_image_token)[0]
text_pos = np.where(~is_image_token)[0]
print(f"\nimage_token positions: {len(img_pos)} (first 5: {img_pos[:5].tolist()}, last 5: {img_pos[-5:].tolist()})")
print(f"text positions: {len(text_pos)} (first 5: {text_pos[:5].tolist()})")

# At image positions: combined should be image_features (in same order),
# i.e. combined[0, img_pos[i], :] should equal img_features[i, :].
combined_np = np.asarray(combined[0].astype(mx.float32))
img_feat_np = np.asarray(img_features.astype(mx.float32))

# A count mismatch between placeholder tokens and feature rows is itself a
# scatter problem, so report it instead of failing later with an IndexError.
n_feat_rows = img_feat_np.shape[0]
if len(img_pos) != n_feat_rows:
    print(f"\nWARNING: {len(img_pos)} image-token positions but {n_feat_rows} image feature rows")
n_comparable = min(len(img_pos), n_feat_rows)

# Spot checks on two individual rows, skipped when the sample has too few
# image positions to index them.
for i in (0, 5):
    if i >= n_comparable:
        print(f"\n--- skip: need at least {i + 1} image positions and feature rows for check {i} ---")
        continue
    pos = img_pos[i]
    print(f"\n--- check: combined[0, img_pos[{i}]] vs img_features[{i}] ---")
    print(f" combined[0, {pos}, :8] = {combined_np[pos, :8]}")
    print(f" image_features[{i}, :8] = {img_feat_np[i, :8]}")
    print(f" diff: {np.abs(combined_np[pos] - img_feat_np[i]).max():.4f}")

# Full check over every image position. Only meaningful when the counts line
# up and img_features is a 2-D (rows, features) array -- the row indexing
# above assumes that layout; TODO confirm against the vision tower.
if n_comparable and len(img_pos) == n_feat_rows and img_feat_np.ndim == 2:
    diff_at_img = np.abs(combined_np[img_pos] - img_feat_np).max()
    print(f"\n--- check: combined matches image_features at all {len(img_pos)} image positions ---")
    print(f" max abs diff at image positions: {diff_at_img:.6f} (expected ~0)")
# At every non-image position the combined embeddings should be exactly the
# plain embed_tokens output.
text_only_np = np.asarray(text_only_embeds[0].astype(mx.float32))
text_delta = combined_np[text_pos] - text_only_np[text_pos]
diff_at_text = np.abs(text_delta).max()
print("\n--- check: combined matches text embeddings at text positions ---")
print(f" max abs diff at text positions: {diff_at_text:.6f} (should be 0)")

# For reference: what embed_tokens alone produces at the first image position,
# i.e. the embedding of the placeholder image token itself.
placeholder_embedding = text_only_np[img_pos[0], :8]
print(f"\n embed_tokens at img_pos[0] (the placeholder embedding): {placeholder_embedding}")
|