File size: 3,238 Bytes
ffe929e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""Quick A/B: forward_generation with vs without mx.compile.

Times 5 forward passes (after 2 warm-up passes) for each variant, using the same input shapes as a 1024x1024 inference.
"""
from __future__ import annotations
import sys, time
from pathlib import Path

HERE = Path(__file__).parent
sys.path.insert(0, str(HERE))

import numpy as np
import mlx.core as mx
from mlx_vlm import load as mlx_vlm_load
from pipeline_helpers import build_t2i_text_sample, build_attention_mask, PATCH_SIZE
from hidream_model import HiDreamConfig, build_model, forward_generation

LAB = Path(__file__).resolve().parents[2]
MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q8"

# Load the quantized backbone + processor; perf_counter is the monotonic,
# high-resolution clock intended for measuring elapsed time.
print("loading model...")
t0 = time.perf_counter()
backbone, processor = mlx_vlm_load(str(MODEL_PATH))
print(f"  {time.perf_counter()-t0:.1f}s")

cfg = HiDreamConfig()
model = build_model(cfg, backbone)
# Extra head weights live in a separate safetensors file. strict=False lets
# load_weights accept a file that does not cover every model parameter
# (presumably the backbone weights are already loaded above — confirm).
custom = mx.load(str(MODEL_PATH / "extras" / "custom_heads.safetensors"))
model.load_weights(list(custom.items()), strict=False)
# MLX is lazy: materialize the parameters now so that cost is not folded
# into the first forward pass of the benchmarks below.
mx.eval(model.parameters())
print("model ready")

# --- Inputs: one text-to-image sample at 1024x1024 ---
WIDTH = HEIGHT = 1024
# Number of PATCH_SIZE x PATCH_SIZE tiles covering the image
# (1024 when PATCH_SIZE == 32).
N_PATCH = (HEIGHT // PATCH_SIZE) * (WIDTH // PATCH_SIZE)

# The processor either wraps a tokenizer or is the tokenizer itself.
tokenizer = getattr(processor, "tokenizer", processor)
# Give any missing "<name>_token" attribute a placeholder "<|<name>_token|>"
# string; presumably build_t2i_text_sample reads these (TODO confirm).
for n in ("boi", "bor", "eor", "bot", "tms"):
    if hasattr(tokenizer, f"{n}_token"):
        continue
    setattr(tokenizer, f"{n}_token", f"<|{n}_token|>")

sample = build_t2i_text_sample(
    "a small red mushroom on a bed of moss",
    HEIGHT,
    WIDTH,
    tokenizer,
    processor,
    backbone.config,
)
input_ids, position_ids, token_types = (
    mx.array(sample[key]) for key in ("input_ids", "position_ids", "token_types")
)
# -1e4 is passed as the mask fill value (presumably a finite "blocked"
# value that survives the bfloat16 cast — confirm against build_attention_mask).
mask4d = mx.array(build_attention_mask(sample["token_types"], -1e4)).astype(mx.bfloat16)
# Random patch inputs: this benchmark only needs realistic shapes/dtypes.
vinputs = mx.random.normal((1, N_PATCH, 3 * PATCH_SIZE**2)).astype(mx.bfloat16)
timestep = mx.array([0.5], dtype=mx.float32)

print(
    f"shapes: input_ids={input_ids.shape} pos={position_ids.shape} "
    f"vinputs={vinputs.shape} mask={mask4d.shape}"
)

# --- Uncompiled baseline ---
print("\n=== baseline (uncompiled) ===")
# Warm-up passes so one-time costs are excluded from the timed loop.
for _ in range(2):
    out = forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d)
    mx.eval(out)

# Timed loop. mx.eval forces the lazy graph to be computed before the next
# iteration, so each step's time covers the whole forward pass.
# perf_counter is monotonic and high-resolution, unlike time.time().
N = 5
t0 = time.perf_counter()
for _ in range(N):
    out = forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d)
    mx.eval(out)
elapsed = time.perf_counter() - t0
print(f"  baseline: {elapsed/N:.3f}s/step over {N} steps")

# --- Compiled ---
print("\n=== mx.compile ===")


def fwd(input_ids, position_ids, vinputs, timestep, token_types, mask4d):
    """Call forward_generation with the module-level ``model`` and ``cfg``.

    Only the per-call input arrays are parameters; the model and its config
    are captured from module scope, giving mx.compile a function of arrays.
    """
    result = forward_generation(
        model,
        cfg,
        input_ids,
        position_ids,
        vinputs,
        timestep,
        token_types,
        mask4d,
    )
    return result

try:
    fwd_c = mx.compile(fwd)
    # Warm-up: the first call traces and compiles, so a compile failure
    # surfaces here rather than inside the timed loop.
    for _ in range(2):
        out = fwd_c(input_ids, position_ids, vinputs, timestep, token_types, mask4d)
        mx.eval(out)
except Exception as e:  # top-level boundary: report and skip compiled timing
    print(f"  mx.compile failed: {type(e).__name__}: {e}")
else:
    # Time only once compilation is known to work; any error from here on is
    # a genuine failure and must not be mislabelled as a compile failure.
    t0 = time.perf_counter()
    for _ in range(N):
        out = fwd_c(input_ids, position_ids, vinputs, timestep, token_types, mask4d)
        mx.eval(out)
    elapsed_c = time.perf_counter() - t0
    print(f"  compiled: {elapsed_c/N:.3f}s/step over {N} steps  (speedup {elapsed/elapsed_c:.2f}x)")