# Mrbizarro's picture
# Initial release: code, docs, hero samples
# ffe929e verified
"""Verify precompute_text_embeds_with_vision actually scatters image features
into the right positions of inputs_embeds, without mangling text positions.
"""
from __future__ import annotations
import sys
from pathlib import Path
HERE = Path(__file__).parent
sys.path.insert(0, str(HERE))
import numpy as np
import mlx.core as mx
from mlx_vlm import load as mlx_vlm_load
from pipeline_helpers import build_edit_text_sample
from hidream_model import HiDreamConfig, build_model, precompute_text_embeds_with_vision
# LAB is parents[2] of this file's resolved path, i.e. two levels above its directory.
LAB = Path(__file__).resolve().parents[2]
MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q6"
REF = "sample_outputs/v3_1024_cat_q8.png"  # reference image, relative to LAB

print("loading model...")
backbone, processor = mlx_vlm_load(str(MODEL_PATH))
cfg = HiDreamConfig()
model = build_model(cfg, backbone)
# The custom heads ship in a separate safetensors file. strict=False, presumably because
# that file holds only a subset of the model's parameters -- TODO confirm.
custom = mx.load(str(MODEL_PATH / "extras" / "custom_heads.safetensors"))
model.load_weights(list(custom.items()), strict=False)
mx.eval(model.parameters())

# The processor either exposes a `.tokenizer` attribute or is itself the tokenizer.
tokenizer = getattr(processor, "tokenizer", processor)
# Make sure each special-token attribute (boi_token, bor_token, ...) exists on the
# tokenizer. Any that are missing default to the literal string "<|{name}_token|>".
for n in ("boi", "bor", "eor", "bot", "tms"):
    if not hasattr(tokenizer, f"{n}_token"):
        setattr(tokenizer, f"{n}_token", f"<|{n}_token|>")
# Build the edit sample for prompt "a cat" conditioned on the reference image.
# The two 1024 arguments are presumably the target width/height -- TODO confirm.
sample = build_edit_text_sample(
    "a cat",
    [str(LAB / REF)],
    1024,
    1024,
    tokenizer,
    processor,
    backbone.config,
)
# Wrap the sample's fields as MLX arrays; pixel values are cast to bfloat16.
input_ids, pixel_values, image_grid_thw = (
    mx.array(sample["input_ids"]),
    mx.array(sample["pixel_values"]).astype(mx.bfloat16),
    mx.array(sample["image_grid_thw"]),
)
# 1) Just the embed_tokens output (no scatter) -- the reference for text positions.
embed_tokens = model.language_model.model.embed_tokens
text_only_embeds = embed_tokens(input_ids)
mx.eval(text_only_embeds)
print(f"\ntext-only embeds shape: {text_only_embeds.shape} dtype: {text_only_embeds.dtype}")

# 2) Vision tower output -- the reference for image positions.
vt_out = model.visual(pixel_values, image_grid_thw)
# model.visual may return a tuple; its first element is taken as the features (assumed).
img_features = vt_out[0] if isinstance(vt_out, tuple) else vt_out
mx.eval(img_features)
print(f"image_features shape: {img_features.shape} dtype: {img_features.dtype}")
# Cast once and evaluate all four reductions together, instead of re-casting and
# forcing a separate evaluation for every statistic.
feats_f32 = img_features.astype(mx.float32)
feat_stats = [mx.mean(feats_f32), mx.std(feats_f32), mx.min(feats_f32), mx.max(feats_f32)]
mx.eval(*feat_stats)
f_mean, f_std, f_min, f_max = (float(s) for s in feat_stats)
print(f" stats: mean={f_mean:.4f} std={f_std:.4f} min={f_min:.3f} max={f_max:.3f}")

# 3) Run our precompute (embed + scatter of image features into the token embeddings).
combined = precompute_text_embeds_with_vision(model, cfg, input_ids, pixel_values, image_grid_thw)
mx.eval(combined)
print(f"\ncombined embeds shape: {combined.shape} dtype: {combined.dtype}")
# 4) Inspect: at image positions, combined should equal image_features;
#    at every other position it should equal the plain embed_tokens output.
ids_np = np.asarray(input_ids[0])
is_img = ids_np == cfg.image_token_id
img_pos = np.flatnonzero(is_img)
text_pos = np.flatnonzero(~is_img)
print(f"\nimage_token positions: {len(img_pos)} (first 5: {img_pos[:5].tolist()}, last 5: {img_pos[-5:].tolist()})")
print(f"text positions: {len(text_pos)} (first 5: {text_pos[:5].tolist()})")

# At image positions, combined should be image_features in the same order:
#   combined[0, img_pos[i], :] == img_features[i, :]
combined_np = np.asarray(combined[0].astype(mx.float32))
img_feat_np = np.asarray(img_features.astype(mx.float32))
# Flatten any leading axes so that row i is the i-th image feature vector.
img_feat_np = img_feat_np.reshape(-1, img_feat_np.shape[-1])

n_tokens, n_feats = len(img_pos), img_feat_np.shape[0]
n_common = min(n_tokens, n_feats)
if n_tokens != n_feats:
    # The scatter can only be correct with exactly one feature row per placeholder token.
    print(f"\n!! MISMATCH: {n_tokens} image tokens but {n_feats} image feature rows")

# Spot-check a couple of individual positions (skipped if they don't exist).
for i in (0, 5):
    if i >= n_common:
        print(f"\n--- check: skipped index {i} (only {n_tokens} image tokens / {n_feats} feature rows) ---")
        continue
    p = img_pos[i]
    print(f"\n--- check: combined[0, img_pos[{i}]] vs img_features[{i}] ---")
    print(f" combined[0, {p}, :8] = {combined_np[p, :8]}")
    print(f" image_features[{i}, :8] = {img_feat_np[i, :8]}")
    print(f" diff: {np.abs(combined_np[p] - img_feat_np[i]).max():.4f}")

# Then compare every image position at once.
if n_common:
    diff_at_img = np.abs(combined_np[img_pos[:n_common]] - img_feat_np[:n_common]).max()
    print(f"\n--- check: combined matches image features at all {n_common} image positions ---")
    print(f" max abs diff at image positions: {diff_at_img:.4f}")

# At text positions: combined should equal embed_tokens output.
text_only_np = np.asarray(text_only_embeds[0].astype(mx.float32))
print("\n--- check: combined matches text embeddings at text positions ---")
if len(text_pos):
    diff_at_text = np.abs(combined_np[text_pos] - text_only_np[text_pos]).max()
    print(f" max abs diff at text positions: {diff_at_text:.6f} (should be 0)")
else:
    print(" no text positions to compare")

# For reference: at an image position, embed_tokens yields the placeholder token's own
# embedding, which the scatter is expected to overwrite.
if n_tokens:
    print(f"\n embed_tokens at img_pos[0] (the placeholder embedding): {text_only_np[img_pos[0], :8]}")