Mrbizarro's picture
Initial release: code, docs, hero samples
ffe929e verified
"""Standalone MLX model wrapper for HiDream-O1-Image."""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
super().__init__()
self.frequency_embedding_size = frequency_embedding_size
self.fc1 = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True)
@staticmethod
def timestep_embedding(t: mx.array, dim: int, max_period: float = 10000.0) -> mx.array:
half = dim // 2
freqs = mx.exp(-math.log(max_period) * mx.arange(0, half, dtype=mx.float32) / half)
args = t[:, None].astype(mx.float32) * freqs[None]
emb = mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1)
if dim % 2:
emb = mx.concatenate([emb, mx.zeros_like(emb[:, :1])], axis=-1)
return emb
def __call__(self, t: mx.array) -> mx.array:
t_freq = self.timestep_embedding(t * 1000.0, self.frequency_embedding_size)
return self.fc2(nn.silu(self.fc1(t_freq.astype(self.fc1.weight.dtype))))
class BottleneckPatchEmbed(nn.Module):
def __init__(self, patch_size: int = 32, in_chans: int = 3,
pca_dim: int = 1024, embed_dim: int = 4096):
super().__init__()
self.proj1 = nn.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False)
self.proj2 = nn.Linear(pca_dim, embed_dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
return self.proj2(self.proj1(x))
class FinalLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int = 32, out_channels: int = 3):
super().__init__()
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
def __call__(self, x: mx.array) -> mx.array:
return self.linear(x)
CUSTOM_HEAD_KEY_MAP = {
"model.t_embedder1.mlp.0.weight": "t_embedder1.fc1.weight",
"model.t_embedder1.mlp.0.bias": "t_embedder1.fc1.bias",
"model.t_embedder1.mlp.2.weight": "t_embedder1.fc2.weight",
"model.t_embedder1.mlp.2.bias": "t_embedder1.fc2.bias",
"model.x_embedder.proj1.weight": "x_embedder.proj1.weight",
"model.x_embedder.proj2.weight": "x_embedder.proj2.weight",
"model.x_embedder.proj2.bias": "x_embedder.proj2.bias",
"model.final_layer2.linear.weight": "final_layer2.linear.weight",
"model.final_layer2.linear.bias": "final_layer2.linear.bias",
}
@dataclass
class HiDreamConfig:
hidden_size: int = 4096
patch_size: int = 32
in_channels: int = 3
bottleneck_dim: int = 1024
tms_token_id: int = 151673
image_token_id: int = 151655
video_token_id: int = 151656
vision_start_token_id: int = 151652
def build_model(cfg: HiDreamConfig, mlx_vlm_qwen3_vl_model):
class HiDream(nn.Module):
def __init__(self):
super().__init__()
self.visual = mlx_vlm_qwen3_vl_model.vision_tower
self.language_model = mlx_vlm_qwen3_vl_model.language_model
self.t_embedder1 = TimestepEmbedder(cfg.hidden_size)
self.x_embedder = BottleneckPatchEmbed(
patch_size=cfg.patch_size, in_chans=cfg.in_channels,
pca_dim=cfg.bottleneck_dim, embed_dim=cfg.hidden_size,
)
self.final_layer2 = FinalLayer(
hidden_size=cfg.hidden_size,
patch_size=cfg.patch_size,
out_channels=cfg.in_channels,
)
return HiDream()
def precompute_text_embeds_with_vision(model, cfg, input_ids, pixel_values=None, image_grid_thw=None):
"""Compute text embeddings + (in edit mode) inject vision features at image_token
positions. Returns embeds [B, S_text, hidden]. Call once before the denoising
loop — output is constant across timesteps.
"""
embed_tokens = model.language_model.model.embed_tokens
inputs_embeds = embed_tokens(input_ids)
if pixel_values is None or image_grid_thw is None:
return inputs_embeds
vt_out = model.visual(pixel_values, image_grid_thw)
image_features = vt_out[0] if isinstance(vt_out, tuple) else vt_out
if isinstance(image_features, (list, tuple)):
image_features = mx.concatenate(image_features, axis=0)
# Build a [B, S, H] tensor that has image_features at image_token positions
# and inputs_embeds everywhere else, via mx.where on a broadcast mask.
ids_np = np.asarray(input_ids)
img_positions = np.where(ids_np[0] == cfg.image_token_id)[0]
if img_positions.shape[0] != image_features.shape[0]:
raise RuntimeError(
f"image_features {image_features.shape[0]} != "
f"image_token_id positions {img_positions.shape[0]} (input_ids was: {ids_np.shape})"
)
B, S, H = inputs_embeds.shape
# Build aligned-to-S features: zero everywhere except at image positions.
aligned = np.zeros((B, S, H), dtype=np.float32)
aligned[0, img_positions] = np.asarray(image_features.astype(mx.float32))
aligned_mx = mx.array(aligned).astype(inputs_embeds.dtype)
# Mask: 1 at image positions, 0 elsewhere
mask_2d = (ids_np == cfg.image_token_id).astype(np.bool_)
mask_3d = np.broadcast_to(mask_2d[..., None], (B, S, H))
mask_mx = mx.array(mask_3d.copy())
return mx.where(mask_mx, aligned_mx, inputs_embeds)
def forward_generation(model, cfg, inputs_embeds_with_vision, position_ids, vinputs, timestep,
input_ids, token_types, attention_mask_4d):
"""Per-step forward. Takes the precomputed text+vision inputs_embeds, the
fresh-noise vinputs, and the timestep. Returns x_pred [B, S_total, patch_dim].
Signature change vs the T2I-only version: pixel_values/image_grid_thw moved
out (call precompute_text_embeds_with_vision once before the loop). input_ids
is still needed inside because we look up tms_token positions for t_emb scatter.
"""
inputs_embeds = inputs_embeds_with_vision
t_emb = model.t_embedder1(timestep)
tms_mask = (input_ids == cfg.tms_token_id)
tms_mask_3d = mx.broadcast_to(tms_mask[..., None], inputs_embeds.shape)
t_emb_expanded = mx.broadcast_to(t_emb[:, None, :], inputs_embeds.shape)
inputs_embeds = mx.where(tms_mask_3d, t_emb_expanded, inputs_embeds)
vinputs_embedded = model.x_embedder(vinputs).astype(inputs_embeds.dtype)
inputs_embeds = mx.concatenate([inputs_embeds, vinputs_embedded], axis=1)
text_model = model.language_model.model
# mlx-vlm Qwen3VLModel.__call__ accepts (inputs, inputs_embeds, mask, cache, position_ids, ...).
# Pass our 4D additive mask directly; it bypasses the internal causal mask.
# `inputs` is required positionally but ignored when inputs_embeds is set
# in mlx-vlm's implementation — pass a placeholder of correct shape.
placeholder = mx.zeros(inputs_embeds.shape[:2], dtype=mx.int32)
h = text_model(
placeholder,
inputs_embeds=inputs_embeds,
mask=attention_mask_4d,
cache=None,
position_ids=position_ids,
)
# Apply final norm. mlx-vlm's Qwen3VLModel applies it internally and returns hidden_states.
x_pred = model.final_layer2(h)
return x_pred