File size: 3,535 Bytes
ffe929e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Diagnose what build_edit_text_sample produces, no model load."""
from __future__ import annotations
import sys
from pathlib import Path
HERE = Path(__file__).parent
sys.path.insert(0, str(HERE))

import numpy as np
from mlx_vlm import load as mlx_vlm_load

LAB = Path(__file__).resolve().parents[2]
MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q6"
REF = "/tmp/hidream_edit_smoke/ref.png"

# Use mlx-vlm to get a working processor that skips the video-processor dep issue
backbone, processor = mlx_vlm_load(str(MODEL_PATH))
tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
for n in ("boi", "bor", "eor", "bot", "tms"):
    if not hasattr(tokenizer, f"{n}_token"):
        setattr(tokenizer, f"{n}_token", f"<|{n}_token|>")

MC = backbone.config

from pipeline_helpers import build_edit_text_sample, PATCH_SIZE

prompt = "in the style of the reference image, a vibrant abstract composition, vivid colors, modern art"
H = W = 512
# One edit sample built from the prompt plus the single reference image;
# H and W are presumably the target image size — confirm in pipeline_helpers.
sample = build_edit_text_sample(prompt, [REF], H, W, tokenizer, processor, MC)

# Dump each entry: anything with a .shape is summarized, the rest printed as-is.
print("=== build_edit_text_sample shapes ===")
for key, value in sample.items():
    if not hasattr(value, "shape"):
        print(f"  {key}: {value}")
        continue
    print(f"  {key}: shape={value.shape} dtype={value.dtype}")

# Count special-token occurrences in the first batch row of input_ids.
iid = sample["input_ids"][0]
img_token_id = MC.image_token_id
vs_token_id = MC.vision_start_token_id
# tms_token_id is hard-coded here rather than read from the config/tokenizer.
# NOTE(review): confirm it matches the tokenizer's id for the tms token.
TMS_TOKEN_ID = 151673
img_count = int((iid == img_token_id).sum())
vs_count = int((iid == vs_token_id).sum())
tms_count = int((iid == TMS_TOKEN_ID).sum())
# Convert to plain Python ints so the id lists print readably instead of as
# backend scalar objects (assumes iid is numpy-convertible, e.g. numpy/mlx).
iid_list = np.asarray(iid).tolist()
print(f"\n=== input_ids breakdown (text-side, length {iid.shape[0]}) ===")
print(f"  image_token_id ({img_token_id}): {img_count} positions  <-- vision tower fills these")
print(f"  vision_start_token_id ({vs_token_id}): {vs_count}")
print(f"  tms_token_id ({TMS_TOKEN_ID}): {tms_count}")
print(f"  first 30 ids: {iid_list[:30]}")
print(f"  last 5 ids:   {iid_list[-5:]}")

# Compare the number of post-merge vision features implied by image_grid_thw
# with the number of image-token placeholders counted in input_ids.
pix = sample["pixel_values"]
g = sample["image_grid_thw"]
print("\n=== vision tower input ===")
print(f"  pixel_values shape: {pix.shape}")
print(f"  image_grid_thw: {g}")
# Per-image vision patch count = T*H*W, post-merge = T*(H//m)*(W//m)
m = MC.vision_config.spatial_merge_size
total_post_merge = 0
for i, row in enumerate(g):
    # Convert each grid entry to a plain int up front (rows may be array-likes).
    t, h, w = (int(x) for x in row)
    pre_merge = t * h * w
    post_merge = t * (h // m) * (w // m)
    total_post_merge += post_merge
    print(f"  ref {i}: pre-merge patches={pre_merge}, post-merge={post_merge}")
print(f"  TOTAL post-merge features (what vision tower outputs): {total_post_merge}")
print(f"  TOTAL image_token_id positions in input_ids: {img_count}")
print("  ** these must match for scatter to work **")
# Report the outcome explicitly instead of leaving the comparison to the reader.
if total_post_merge == img_count:
    print("  -> OK")
else:
    print(f"  -> MISMATCH (features - positions = {total_post_merge - img_count:+d})")

# Compare mask counts (first batch row) against sample["tgt_image_len"]:
# all vinput positions split as tgt + refs, and the tgt-only mask on its own.
vinput_mask = sample["vinput_mask"][0]
vinput_mask_tgt = sample["vinput_mask_tgt_only"][0]
tgt_len = int(sample["tgt_image_len"])
n_vinput = int(vinput_mask.sum())
n_tgt = int(vinput_mask_tgt.sum())
print("\n=== mask checks ===")
print(f"  total vinput positions (tgt+refs): {n_vinput} = {tgt_len} + {n_vinput - tgt_len}")
tgt_status = "OK" if n_tgt == tgt_len else "MISMATCH"
print(f"  total tgt-only positions: {n_tgt} (expect {tgt_len}) {tgt_status}")

# Position IDs: overall shape, per-dimension value ranges, and a small
# window around the end of the text-side input_ids.
pids = sample["position_ids"]
print("\n=== position_ids ===")
print(f"  shape: {pids.shape}  (3D mrope: rope_dim, batch, seq)")
print(f"  ranges per dim: {[(int(pids[d].min()), int(pids[d].max())) for d in range(pids.shape[0])]}")
# Where are the discontinuities? Look at the boundary between text-side and
# vision-token-side positions (assumes position_ids extend past input_ids).
txt_seq_len = iid.shape[0]
seq_len = pids.shape[-1]
# Clamp the window so the printed label matches the slice actually taken.
lo = max(0, txt_seq_len - 3)
hi = min(seq_len, txt_seq_len + 3)
print(f"  text/vision boundary at position {txt_seq_len} (position_ids seq len {seq_len})")
print(f"  pids[:, 0, {lo}:{hi}] (around the boundary):")
print(pids[:, 0, lo:hi])