# Uploaded by Mrbizarro — commit ffe929e (verified):
# "Initial release: code, docs, hero samples"
"""Diagnose what ``build_edit_text_sample`` produces for one edit-mode prompt.

The HiDream-O1 backbone is loaded through mlx-vlm only to obtain a working
processor/tokenizer and the model config. Nothing is generated: the script
builds a single sample and prints its tensor shapes, the token-id breakdown of
``input_ids``, vision-tower patch counts, mask sizes and position-id ranges.
"""
from __future__ import annotations
import sys
from pathlib import Path
HERE = Path(__file__).parent
sys.path.insert(0, str(HERE))
import numpy as np
from mlx_vlm import load as mlx_vlm_load
LAB = Path(__file__).resolve().parents[2]
MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q6"
REF = "/tmp/hidream_edit_smoke/ref.png"

# Fail fast: REF is first consumed by build_edit_text_sample, which only runs
# after the (slow) model load below.
if not Path(REF).is_file():
    sys.exit(f"reference image not found: {REF}")

# Use mlx-vlm to get a working processor that skips the video-processor dep issue
backbone, processor = mlx_vlm_load(str(MODEL_PATH))
# The processor may wrap a tokenizer or be the tokenizer itself.
tokenizer = getattr(processor, "tokenizer", processor)
# Make sure the special-token string attributes exist on the tokenizer;
# presumably build_edit_text_sample reads them (TODO confirm in pipeline_helpers).
for n in ("boi", "bor", "eor", "bot", "tms"):
    if not hasattr(tokenizer, f"{n}_token"):
        setattr(tokenizer, f"{n}_token", f"<|{n}_token|>")
MC = backbone.config
# --- Build one edit-mode sample (single reference image) and dump it ---
from pipeline_helpers import build_edit_text_sample, PATCH_SIZE
prompt = "in the style of the reference image, a vibrant abstract composition, vivid colors, modern art"
H = W = 512
sample = build_edit_text_sample(prompt, [REF], H, W, tokenizer, processor, MC)

print("=== build_edit_text_sample shapes ===")
# Anything exposing `.shape` is reported by shape/dtype; other entries verbatim.
for key, value in sample.items():
    if not hasattr(value, "shape"):
        print(f" {key}: {value}")
        continue
    print(f" {key}: shape={value.shape} dtype={value.dtype}")
# --- Token-id breakdown of the text-side input_ids (first batch row) ---
iid = sample["input_ids"][0]
img_token_id = MC.image_token_id
vs_token_id = MC.vision_start_token_id
# Prefer the config's tms id if it exposes one; otherwise fall back to the id
# this script previously hard-coded.
tms_token_id = getattr(MC, "tms_token_id", 151673)
img_count = int((iid == img_token_id).sum())
vs_count = int((iid == vs_token_id).sum())
tms_count = int((iid == tms_token_id).sum())
print(f"\n=== input_ids breakdown (text-side, length {iid.shape[0]}) ===")
print(f" image_token_id ({img_token_id}): {img_count} positions <-- vision tower fills these")
print(f" vision_start_token_id ({vs_token_id}): {vs_count}")
print(f" tms_token_id ({tms_token_id}): {tms_count}")
# .tolist() yields plain Python ints, so the ids print as numbers instead of
# per-element array-scalar reprs (e.g. np.int64(...)).
print(f" first 30 ids: {iid[:30].tolist()}")
print(f" last 5 ids: {iid[-5:].tolist()}")
# --- Vision-tower input: per-image patch counts vs. image placeholder tokens ---
pix = sample["pixel_values"]
g = sample["image_grid_thw"]
print("\n=== vision tower input ===")
print(f" pixel_values shape: {pix.shape}")
print(f" image_grid_thw: {g}")
# Per-image vision patch count = T*H*W; after spatial merging with factor m
# it becomes T*(H//m)*(W//m).
m = MC.vision_config.spatial_merge_size
total_post_merge = 0
for i, row in enumerate(g):
    # Cast to Python ints up front so all arithmetic below is plain integer math.
    t, h, w = (int(x) for x in row)
    pre_merge = t * h * w
    post_merge = t * (h // m) * (w // m)
    total_post_merge += post_merge
    print(f" ref {i}: pre-merge patches={pre_merge}, post-merge={post_merge}")
print(f" TOTAL post-merge features (what vision tower outputs): {total_post_merge}")
print(f" TOTAL image_token_id positions in input_ids: {img_count}")
print(" ** these must match for scatter to work **")
if total_post_merge == img_count:
    print(" OK: vision feature count matches image_token_id positions")
else:
    print(
        f" MISMATCH: {total_post_merge} vision features vs "
        f"{img_count} image_token_id positions"
    )
# --- Mask sanity checks (first batch row) ---
vinput_mask = sample["vinput_mask"][0]
vinput_mask_tgt = sample["vinput_mask_tgt_only"][0]
print("\n=== mask checks ===")
# All vision-input positions split into target part + remainder (refs).
n_vinput = int(vinput_mask.sum())
tgt_len = sample['tgt_image_len']
print(f" total vinput positions (tgt+refs): {n_vinput} = {tgt_len} + {n_vinput - tgt_len}")
# The target-only mask should cover exactly tgt_image_len positions.
n_vinput_tgt = int(vinput_mask_tgt.sum())
print(f" total tgt-only positions: {n_vinput_tgt} (expect {sample['tgt_image_len']})")
# --- Position IDs: shape, per-rope-dim value ranges, and the ids around the
# end of the text-side input_ids (indexed below as [rope_dim, batch, seq]) ---
pids = sample["position_ids"]
print("\n=== position_ids ===")
print(f" shape: {pids.shape} (3D mrope: rope_dim, batch, seq)")
dim_ranges = [(int(pids[d].min()), int(pids[d].max())) for d in range(pids.shape[0])]
print(f" ranges per dim: {dim_ranges}")
# The text-side length marks where the vision-token side presumably starts
# (TODO confirm sequence layout in build_edit_text_sample).
txt_seq_len = iid.shape[0]
print(f" text/vision boundary at position {txt_seq_len}")
print(" pids[:, 0, txt_seq_len-3:txt_seq_len+3] (around the boundary):")
window_lo = max(0, txt_seq_len - 3)
window_hi = txt_seq_len + 3
print(pids[:, 0, window_lo:window_hi])