File size: 4,251 Bytes
ffe929e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Verify precompute_text_embeds_with_vision actually scatters image features
into the right positions of inputs_embeds, without mangling text positions.
"""
from __future__ import annotations
import sys
from pathlib import Path
HERE = Path(__file__).parent
sys.path.insert(0, str(HERE))

import numpy as np
import mlx.core as mx
from mlx_vlm import load as mlx_vlm_load

from pipeline_helpers import build_edit_text_sample
from hidream_model import HiDreamConfig, build_model, precompute_text_embeds_with_vision

# Checkpoint / sample locations, resolved relative to this file (parents[2]).
LAB = Path(__file__).resolve().parents[2]
MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q6"
REF = "sample_outputs/v3_1024_cat_q8.png"

print("loading model...")
backbone, processor = mlx_vlm_load(str(MODEL_PATH))
cfg = HiDreamConfig()
model = build_model(cfg, backbone)
# Overlay the extra custom-head weights onto the built model. strict=False is
# passed through to load_weights — NOTE(review): presumably tolerates keys that
# don't match one-to-one; confirm no required head weights are silently skipped.
heads_file = MODEL_PATH / "extras" / "custom_heads.safetensors"
custom = mx.load(str(heads_file))
model.load_weights(list(custom.items()), strict=False)
mx.eval(model.parameters())

# Use the wrapped tokenizer when the processor exposes one; otherwise treat the
# processor itself as the tokenizer.
tokenizer = getattr(processor, "tokenizer", processor)
# Give the tokenizer a placeholder "<|<name>_token|>" string for every special
# token attribute it does not already define.
for tok_name in ("boi", "bor", "eor", "bot", "tms"):
    attr = f"{tok_name}_token"
    if hasattr(tokenizer, attr):
        continue
    setattr(tokenizer, attr, f"<|{tok_name}_token|>")

sample = build_edit_text_sample(
    "a cat",
    [str(LAB / REF)],
    1024,
    1024,
    tokenizer,
    processor,
    backbone.config,
)

input_ids = mx.array(sample["input_ids"])
pixel_values = mx.array(sample["pixel_values"]).astype(mx.bfloat16)
image_grid_thw = mx.array(sample["image_grid_thw"])

# 1) Just the embed_tokens output (no scatter)
embed_tokens = model.language_model.model.embed_tokens
text_only_embeds = embed_tokens(input_ids)
mx.eval(text_only_embeds)
print(f"\ntext-only embeds shape: {text_only_embeds.shape} dtype: {text_only_embeds.dtype}")

# 2) Vision tower output
vt_out = model.visual(pixel_values, image_grid_thw)
# The tower may return a tuple; in that case the features are its first element.
img_features = vt_out[0] if isinstance(vt_out, tuple) else vt_out
mx.eval(img_features)
print(f"image_features shape: {img_features.shape} dtype: {img_features.dtype}")
# Cast to float32 once (and materialize it) rather than re-casting for each of
# the four summary statistics below.
img_f32 = img_features.astype(mx.float32)
mx.eval(img_f32)
print(
    f"  stats: mean={float(mx.mean(img_f32)):.4f} std={float(mx.std(img_f32)):.4f} "
    f"min={float(mx.min(img_f32)):.3f} max={float(mx.max(img_f32)):.3f}"
)

# 3) Run our precompute
combined = precompute_text_embeds_with_vision(model, cfg, input_ids, pixel_values, image_grid_thw)
mx.eval(combined)
print(f"\ncombined embeds shape: {combined.shape} dtype: {combined.dtype}")

# 4) Inspect: at image positions, combined should equal image_features.
# Only batch row 0 is inspected (input_ids is indexed with [0]), so it is
# presumably shaped (batch, seq) — TODO confirm the batch size is 1.
ids_np = np.asarray(input_ids[0])
img_pos = np.where(ids_np == cfg.image_token_id)[0]
text_pos = np.where(ids_np != cfg.image_token_id)[0]
print(f"\nimage_token positions: {len(img_pos)} (first 5: {img_pos[:5].tolist()}, last 5: {img_pos[-5:].tolist()})")
print(f"text positions: {len(text_pos)} (first 5: {text_pos[:5].tolist()})")

# At image positions: combined should be image_features (in the same order),
# i.e. combined[0, img_pos[i], :] should equal img_features[i, :].
combined_np = np.asarray(combined[0].astype(mx.float32))
img_feat_np = np.asarray(img_features.astype(mx.float32))
# Collapse any leading dims so row i is the i-th feature vector (a no-op when
# the vision tower already returns a 2-D array).
img_feat_np = img_feat_np.reshape(-1, img_feat_np.shape[-1])
n_feat = img_feat_np.shape[0]
n_cmp = min(len(img_pos), n_feat)
if len(img_pos) != n_feat:
    # A count mismatch means placeholders and features cannot map 1:1.
    print(f"\n  WARNING: {len(img_pos)} image_token positions but {n_feat} image feature rows")

# Spot-check two individual positions; skip any that do not exist.
for i in (0, 5):
    if i >= n_cmp:
        print(f"\n--- skip: index {i} out of range ({n_cmp} comparable image positions) ---")
        continue
    p = img_pos[i]
    print(f"\n--- check: combined[0, img_pos[{i}]] vs img_features[{i}] ---")
    print(f"  combined[0, {p}, :8] = {combined_np[p, :8]}")
    print(f"  image_features[{i}, :8]         = {img_feat_np[i, :8]}")
    print(f"  diff: {np.abs(combined_np[p] - img_feat_np[i]).max():.4f}")

# Verify every comparable image position, not just the two spot checks.
if n_cmp:
    diff_at_img = np.abs(combined_np[img_pos[:n_cmp]] - img_feat_np[:n_cmp]).max()
    print(f"\n--- check: combined matches image_features at all {n_cmp} image positions ---")
    print(f"  max abs diff at image positions: {diff_at_img:.6f} (should be ~0)")

# At text positions: combined should equal embed_tokens output
text_only_np = np.asarray(text_only_embeds[0].astype(mx.float32))
print("\n--- check: combined matches text embeddings at text positions ---")
if len(text_pos):
    diff_at_text = np.abs(combined_np[text_pos] - text_only_np[text_pos]).max()
    print(f"  max abs diff at text positions: {diff_at_text:.6f} (should be 0)")
else:
    print("  no text positions to compare")

# For reference: at image positions embed_tokens yields the image token's own
# embedding, because that token is only a placeholder for the scattered features.
if len(img_pos):
    print(f"\n  embed_tokens at img_pos[0] (the placeholder embedding): {text_only_np[img_pos[0], :8]}")