"""Quick A/B: forward_generation with vs without mx.compile. Times 5 forward passes after warm-up. Same shapes as a 1024x1024 inference. """ from __future__ import annotations import sys, time from pathlib import Path HERE = Path(__file__).parent sys.path.insert(0, str(HERE)) import numpy as np import mlx.core as mx from mlx_vlm import load as mlx_vlm_load from pipeline_helpers import build_t2i_text_sample, build_attention_mask, PATCH_SIZE from hidream_model import HiDreamConfig, build_model, forward_generation LAB = Path(__file__).resolve().parents[2] MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q8" print("loading model...") t0 = time.time() backbone, processor = mlx_vlm_load(str(MODEL_PATH)) print(f" {time.time()-t0:.1f}s") cfg = HiDreamConfig() model = build_model(cfg, backbone) custom = mx.load(str(MODEL_PATH / "extras" / "custom_heads.safetensors")) model.load_weights(list(custom.items()), strict=False) mx.eval(model.parameters()) print("model ready") # Build inputs at 1024x1024 WIDTH, HEIGHT = 1024, 1024 N_PATCH = (WIDTH // PATCH_SIZE) * (HEIGHT // PATCH_SIZE) # 1024 tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor for n in ("boi", "bor", "eor", "bot", "tms"): if not hasattr(tokenizer, f"{n}_token"): setattr(tokenizer, f"{n}_token", f"<|{n}_token|>") sample = build_t2i_text_sample( "a small red mushroom on a bed of moss", HEIGHT, WIDTH, tokenizer, processor, backbone.config, ) input_ids = mx.array(sample["input_ids"]) position_ids = mx.array(sample["position_ids"]) token_types = mx.array(sample["token_types"]) mask4d = mx.array(build_attention_mask(sample["token_types"], -1e4)).astype(mx.bfloat16) vinputs = mx.random.normal((1, N_PATCH, 3 * PATCH_SIZE * PATCH_SIZE)).astype(mx.bfloat16) timestep = mx.array([0.5], dtype=mx.float32) print(f"shapes: input_ids={input_ids.shape} pos={position_ids.shape} " f"vinputs={vinputs.shape} mask={mask4d.shape}") # --- Uncompiled baseline --- print("\n=== baseline (uncompiled) ===") # warmup for _ in range(2): out = forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d) mx.eval(out) # time N = 5 t0 = time.time() for _ in range(N): out = forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d) mx.eval(out) elapsed = time.time() - t0 print(f" baseline: {elapsed/N:.3f}s/step over {N} steps") # --- Compiled --- print("\n=== mx.compile ===") def fwd(input_ids, position_ids, vinputs, timestep, token_types, mask4d): return forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d) try: fwd_c = mx.compile(fwd) # warmup (first call compiles) for _ in range(2): out = fwd_c(input_ids, position_ids, vinputs, timestep, token_types, mask4d) mx.eval(out) t0 = time.time() for _ in range(N): out = fwd_c(input_ids, position_ids, vinputs, timestep, token_types, mask4d) mx.eval(out) elapsed_c = time.time() - t0 print(f" compiled: {elapsed_c/N:.3f}s/step over {N} steps (speedup {elapsed/elapsed_c:.2f}x)") except Exception as e: print(f" mx.compile failed: {type(e).__name__}: {e}")