"""Diagnose what build_edit_text_sample produces, no model load.""" from __future__ import annotations import sys from pathlib import Path HERE = Path(__file__).parent sys.path.insert(0, str(HERE)) import numpy as np from mlx_vlm import load as mlx_vlm_load LAB = Path(__file__).resolve().parents[2] MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q6" REF = "/tmp/hidream_edit_smoke/ref.png" # Use mlx-vlm to get a working processor that skips the video-processor dep issue backbone, processor = mlx_vlm_load(str(MODEL_PATH)) tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor for n in ("boi", "bor", "eor", "bot", "tms"): if not hasattr(tokenizer, f"{n}_token"): setattr(tokenizer, f"{n}_token", f"<|{n}_token|>") MC = backbone.config from pipeline_helpers import build_edit_text_sample, PATCH_SIZE prompt = "in the style of the reference image, a vibrant abstract composition, vivid colors, modern art" H = W = 512 sample = build_edit_text_sample(prompt, [REF], H, W, tokenizer, processor, MC) print("=== build_edit_text_sample shapes ===") for k, v in sample.items(): if hasattr(v, "shape"): print(f" {k}: shape={v.shape} dtype={v.dtype}") else: print(f" {k}: {v}") iid = sample["input_ids"][0] img_token_id = MC.image_token_id vs_token_id = MC.vision_start_token_id img_count = int((iid == img_token_id).sum()) vs_count = int((iid == vs_token_id).sum()) tms_count = int((iid == 151673).sum()) # tms_token_id print(f"\n=== input_ids breakdown (text-side, length {iid.shape[0]}) ===") print(f" image_token_id ({img_token_id}): {img_count} positions <-- vision tower fills these") print(f" vision_start_token_id ({vs_token_id}): {vs_count}") print(f" tms_token_id (151673): {tms_count}") print(f" first 30 ids: {list(iid[:30])}") print(f" last 5 ids: {list(iid[-5:])}") pix = sample["pixel_values"] g = sample["image_grid_thw"] print(f"\n=== vision tower input ===") print(f" pixel_values shape: {pix.shape}") print(f" image_grid_thw: {g}") # Per-image vision patch count = T*H*W, post-merge = T*H/m*W/m m = backbone.config.vision_config.spatial_merge_size for i, row in enumerate(g): t, h, w = row pre_merge = int(t * h * w) post_merge = int(t * (h//m) * (w//m)) print(f" ref {i}: pre-merge patches={pre_merge}, post-merge={post_merge}") print(f" TOTAL post-merge features (what vision tower outputs): {sum(int(r[0])*(int(r[1])//m)*(int(r[2])//m) for r in g)}") print(f" TOTAL image_token_id positions in input_ids: {img_count}") print(f" ** these must match for scatter to work **") vinput_mask = sample["vinput_mask"][0] vinput_mask_tgt = sample["vinput_mask_tgt_only"][0] print(f"\n=== mask checks ===") print(f" total vinput positions (tgt+refs): {int(vinput_mask.sum())} = {sample['tgt_image_len']} + {int(vinput_mask.sum()) - sample['tgt_image_len']}") print(f" total tgt-only positions: {int(vinput_mask_tgt.sum())} (expect {sample['tgt_image_len']})") # Position IDs pids = sample["position_ids"] print(f"\n=== position_ids ===") print(f" shape: {pids.shape} (3D mrope: rope_dim, batch, seq)") print(f" ranges per dim: {[(int(pids[d].min()), int(pids[d].max())) for d in range(pids.shape[0])]}") # Where are the discontinuities? Look at the boundary between text-side and vision-token-side txt_seq_len = iid.shape[0] print(f" text/vision boundary at position {txt_seq_len}") print(f" pids[:, 0, txt_seq_len-3:txt_seq_len+3] (around the boundary):") print(pids[:, 0, max(0, txt_seq_len-3):txt_seq_len+3])