"""Verify precompute_text_embeds_with_vision actually scatters image features into the right positions of inputs_embeds, without mangling text positions. """ from __future__ import annotations import sys from pathlib import Path HERE = Path(__file__).parent sys.path.insert(0, str(HERE)) import numpy as np import mlx.core as mx from mlx_vlm import load as mlx_vlm_load from pipeline_helpers import build_edit_text_sample from hidream_model import HiDreamConfig, build_model, precompute_text_embeds_with_vision LAB = Path(__file__).resolve().parents[2] MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q6" REF = "sample_outputs/v3_1024_cat_q8.png" print("loading model...") backbone, processor = mlx_vlm_load(str(MODEL_PATH)) cfg = HiDreamConfig() model = build_model(cfg, backbone) custom = mx.load(str(MODEL_PATH / "extras" / "custom_heads.safetensors")) model.load_weights(list(custom.items()), strict=False) mx.eval(model.parameters()) tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor for n in ("boi", "bor", "eor", "bot", "tms"): if not hasattr(tokenizer, f"{n}_token"): setattr(tokenizer, f"{n}_token", f"<|{n}_token|>") sample = build_edit_text_sample( "a cat", [str(LAB / REF)], 1024, 1024, tokenizer, processor, backbone.config, ) input_ids = mx.array(sample["input_ids"]) pixel_values = mx.array(sample["pixel_values"]).astype(mx.bfloat16) image_grid_thw = mx.array(sample["image_grid_thw"]) # 1) Just the embed_tokens output (no scatter) embed_tokens = model.language_model.model.embed_tokens text_only_embeds = embed_tokens(input_ids) mx.eval(text_only_embeds) print(f"\ntext-only embeds shape: {text_only_embeds.shape} dtype: {text_only_embeds.dtype}") # 2) Vision tower output vt_out = model.visual(pixel_values, image_grid_thw) img_features = vt_out[0] if isinstance(vt_out, tuple) else vt_out mx.eval(img_features) print(f"image_features shape: {img_features.shape} dtype: {img_features.dtype}") print(f" stats: mean={float(mx.mean(img_features.astype(mx.float32))):.4f} std={float(mx.std(img_features.astype(mx.float32))):.4f} min={float(mx.min(img_features.astype(mx.float32))):.3f} max={float(mx.max(img_features.astype(mx.float32))):.3f}") # 3) Run our precompute combined = precompute_text_embeds_with_vision(model, cfg, input_ids, pixel_values, image_grid_thw) mx.eval(combined) print(f"\ncombined embeds shape: {combined.shape} dtype: {combined.dtype}") # 4) Inspect: at image positions, combined should equal image_features ids_np = np.asarray(input_ids[0]) img_pos = np.where(ids_np == cfg.image_token_id)[0] text_pos = np.where(ids_np != cfg.image_token_id)[0] print(f"\nimage_token positions: {len(img_pos)} (first 5: {img_pos[:5].tolist()}, last 5: {img_pos[-5:].tolist()})") print(f"text positions: {len(text_pos)} (first 5: {text_pos[:5].tolist()})") # At image positions: combined should be image_features (in same order) # combined[0, img_pos[i], :] should equal img_features[i, :] combined_np = np.asarray(combined[0].astype(mx.float32)) img_feat_np = np.asarray(img_features.astype(mx.float32)) print("\n--- check: combined[0, img_pos[0]] vs img_features[0] ---") print(f" combined[0, {img_pos[0]}, :8] = {combined_np[img_pos[0], :8]}") print(f" image_features[0, :8] = {img_feat_np[0, :8]}") print(f" diff: {np.abs(combined_np[img_pos[0]] - img_feat_np[0]).max():.4f}") print("\n--- check: combined[0, img_pos[5]] vs img_features[5] ---") print(f" combined[0, {img_pos[5]}, :8] = {combined_np[img_pos[5], :8]}") print(f" image_features[5, :8] = {img_feat_np[5, :8]}") print(f" diff: {np.abs(combined_np[img_pos[5]] - img_feat_np[5]).max():.4f}") # At text positions: combined should equal embed_tokens output text_only_np = np.asarray(text_only_embeds[0].astype(mx.float32)) diff_at_text = np.abs(combined_np[text_pos] - text_only_np[text_pos]).max() print(f"\n--- check: combined matches text embeddings at text positions ---") print(f" max abs diff at text positions: {diff_at_text:.6f} (should be 0)") # Also: at image positions, embed_tokens gives the image_token's WEIRD embedding (since the token is just a placeholder) print(f"\n embed_tokens at img_pos[0] (the placeholder embedding): {text_only_np[img_pos[0], :8]}")