Mrbizarro commited on
Commit
ffe929e
·
verified ·
1 Parent(s): 269dd41

Initial release: code, docs, hero samples

Browse files
.gitattributes CHANGED
@@ -1,35 +1,20 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
 
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
  *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
1
+ # Git LFS configuration for HuggingFace model release.
2
+ # Weights live under mlx_models/ and are git-LFS-tracked.
3
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
5
+ *.gguf filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
6
  *.ot filter=lfs diff=lfs merge=lfs -text
7
+ *.onnx filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
8
  *.tar filter=lfs diff=lfs merge=lfs -text
9
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
10
+
11
+ # Sample images stay regular git (small enough for plain repo storage)
12
+ sample_outputs/**/*.png -filter -diff -merge text=auto
13
+ sample_outputs/hero/01_tea_master.png filter=lfs diff=lfs merge=lfs -text
14
+ sample_outputs/hero/02_tropical_beach.png filter=lfs diff=lfs merge=lfs -text
15
+ sample_outputs/hero/03_astronaut.png filter=lfs diff=lfs merge=lfs -text
16
+ sample_outputs/hero/04_construction_worker.png filter=lfs diff=lfs merge=lfs -text
17
+ sample_outputs/hero/05_mountain_peak.png filter=lfs diff=lfs merge=lfs -text
18
+ sample_outputs/hero/06_alice_cyberpunk.png filter=lfs diff=lfs merge=lfs -text
19
+ sample_outputs/hero/07_kitchen_morning.png filter=lfs diff=lfs merge=lfs -text
20
+ sample_outputs/hero/08_fitness_BF16.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ .venv/
5
+ env/
6
+
7
+ # Model weights — too large for code repo. Push to HF Hub separately
8
+ # (see UPLOAD_TO_HF.md). Hero samples are kept under sample_outputs/hero/
9
+ # for the model card; per-batch outputs are gitignored.
10
+ mlx_models/
11
+
12
+ # Per-batch sample outputs are scratch; only the curated hero set is checked in.
13
+ sample_outputs/*.png
14
+ sample_outputs/showcase/
15
+ sample_outputs/showcase_q6/
16
+ sample_outputs/showcase_q8/
17
+ sample_outputs/showcase_creative/
18
+ sample_outputs/showcase_realism/
19
+ sample_outputs/showcase_antiplastic/
20
+ sample_outputs/cinematic_*/
21
+ sample_outputs/artifact_test/
22
+ sample_outputs/ab_mflux/
23
+
24
+ # Logs are local-only
25
+ logs/
CLAUDE.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HIDREAM-O1-MLX-LAB — agent manual
2
+
3
+ **Read this first** when entering this lab.
4
+
5
+ ## What this is
6
+
7
+ A standalone lab for porting **HiDream-O1-Image-Dev** (8B Qwen3-VL-based unified pixel-patch transformer, MIT licence) to **MLX** for fast local image generation on Apple Silicon. Status as of 2026-05-09: **shipped to Phosphene `dev` branch**. Lab continues to host the conversion + inference scripts and serve as the home for future work (edit/multi-ref, 2048 generation, post-process experiments).
8
+
9
+ ## Where it lives
10
+
11
+ - **This dir**: `/Users/salo/HIDREAM-O1-MLX-LAB-active/`
12
+ - **Branch**: `perf-lab-hidream-o1-mlx` (local-only git, no remote yet)
13
+ - **Outside `~/pinokio/`** deliberately so Pinokio cleanup can't touch it
14
+ - **README.md** marker at root: "DO NOT DELETE"
15
+ - **Phosphene integration** lives in `~/pinokio/api/phosphene-dev.git/agent/image_engine.py` (`kind="hidream"`), shipped on `dev` branch
16
+
17
+ ## Session-start protocol
18
+
19
+ 1. `git fetch && git status -sb` — check the branch
20
+ 2. Read **STATE.md** — current state, recent work, open items
21
+ 3. Read **docs/EVALUATION.md** — what HiDream is good/weak at, A/B vs mflux, perf numbers
22
+ 4. Read **docs/HIDREAM_O1_MLX_PORT_REPORT.md** — architecture details, weight conversion, Q4 vs Q8 finding
23
+ 5. Read **docs/PHOSPHENE_INTEGRATION_PLAN.md** — what we shipped to Phosphene and how
24
+
25
+ ## Layout
26
+
27
+ ```
28
+ .
29
+ ├── README.md DO NOT DELETE marker
30
+ ├── CLAUDE.md this file
31
+ ├── STATE.md current state, open items
32
+ ├── docs/
33
+ │ ├── EVALUATION.md quality + perf, A/B vs mflux, blend experiment
34
+ │ ├── HIDREAM_O1_MLX_PORT_REPORT.md architecture, weight conversion
35
+ │ └── PHOSPHENE_INTEGRATION_PLAN.md integration plan + actual diff
36
+ ├── scripts/hidream_o1/
37
+ │ ├── flow_match.py FlashFlowMatch scheduler in MLX
38
+ │ ├── pipeline_helpers.py T2I sample, mrope, mask, patchify
39
+ │ ├── hidream_model.py custom heads + forward_generation
40
+ │ ├── convert_hidream_o1_to_mlx.py HF safetensors -> MLX, Q4/6/8
41
+ │ ├── generate_hidream_o1_mlx.py T2I generator (CLI entry-point)
42
+ │ ├── _compile_bench.py mx.compile A/B bench (0% gain — bandwidth-bound)
43
+ │ └── showcase_batch.sh 10-prompt showcase battery
44
+ ├── notes/weight_map.json cached HF safetensors index
45
+ ├── mlx_models/
46
+ │ ├── hidream-o1-dev-q4/ (5.6 GB backbone + 75 MB custom heads)
47
+ │ └── hidream-o1-dev-q8/ (9.96 GB backbone + 75 MB custom heads)
48
+ ├── sample_outputs/ generated samples (gitignored)
49
+ ├── logs/ run logs
50
+ └── .venv/ uv venv (mlx 0.31.2, mlx-vlm 0.5.0, transformers 5.8.0)
51
+ ```
52
+
53
+ ## How to run
54
+
55
+ ```bash
56
+ # Generate one image (Q8 recommended)
57
+ .venv/bin/python scripts/hidream_o1/generate_hidream_o1_mlx.py \
58
+ --model-path mlx_models/hidream-o1-dev-q8 \
59
+ --prompt "your prompt here" \
60
+ --width 1024 --height 1024 \
61
+ --output sample_outputs/whatever.png \
62
+ --seed 42
63
+
64
+ # Re-convert from HF (only needed if you delete mlx_models/)
65
+ .venv/bin/python scripts/hidream_o1/convert_hidream_o1_to_mlx.py \
66
+ --hf-source HiDream-ai/HiDream-O1-Image-Dev \
67
+ --out-dir mlx_models/hidream-o1-dev-q8 \
68
+ --bits 8 --check-disk
69
+ ```
70
+
71
+ ## Hard rules
72
+
73
+ 1. **Q8 only.** Q4 ships dark; the bright/colourful ground truth comes back at Q8. Documented in EVALUATION.md.
74
+ 2. **`s_noise=7.5` is load-bearing.** Lowering it collapses the image. FlashFlowMatch tuned for the Dev distillation.
75
+ 3. **28 steps.** Dev was distilled to 28; lower is undertrained.
76
+ 4. **Splitting safetensors after conversion = land mine.** The original converter overwrote source mmap mid-read and zeroed every weight silently. Now split happens inside the converter in one pass; never re-read+overwrite the same file.
77
+ 5. **Custom heads go into `extras/custom_heads.safetensors`** (subfolder so mlx-vlm's `glob *.safetensors` doesn't pick them up).
78
+ 6. **Phosphene `agent/image_engine.py` calls this lab via subprocess** — don't import mlx-vlm into Phosphene's interpreter.
79
+ 7. **No edit/multi-ref support yet.** Architecture supports it, lab pipeline doesn't. Refs through Phosphene continue to use `mflux qwen-edit`.
80
+
81
+ ## Performance ceiling
82
+
83
+ `mx.compile` on the forward pass = **0% gain**. We are bandwidth-bound on the 36-layer Q8 decoder. **2.36 s/step at 1024 is the floor** on this hardware. To go faster you need a smaller distillation, fewer steps, or text-cache reuse across denoising steps (~2-5% gain at most, very invasive).
84
+
85
+ ## Identity rules
86
+
87
+ - Lab repo is local-only, **no remote** — commit author is `hidream-o1-mlx-lab <lab@local>` (cosmetic; doesn't matter)
88
+ - **Phosphene-dev.git commits**: identity is `mrbizarro <mrbizarro@users.noreply.github.com>`. **No Co-Authored-By trailer.** Branch is `dev`, never `main` without explicit OK.
89
+
90
+ ## Cross-references
91
+
92
+ - Phosphene CLAUDE.md: `~/pinokio/api/phosphene-dev.git/CLAUDE.md`
93
+ - HF model: https://huggingface.co/HiDream-ai/HiDream-O1-Image-Dev
94
+ - Reference repo: https://github.com/HiDream-ai/HiDream-O1-Image
95
+ - mlx-vlm qwen3_vl: https://github.com/Blaizzy/mlx-vlm/tree/main/mlx_vlm/models/qwen3_vl
LICENSE ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 mrbizarro and contributors
4
+
5
+ This project (hidream-o1-mlx) is an MLX port of HiDream-O1-Image-Dev for Apple
6
+ Silicon. The upstream HiDream-O1-Image source code (https://github.com/HiDream-ai/HiDream-O1-Image)
7
+ and the model weights (https://huggingface.co/HiDream-ai/HiDream-O1-Image-Dev)
8
+ are released under the MIT License by HiDream-ai. This port preserves that
9
+ license.
10
+
11
+ Permission is hereby granted, free of charge, to any person obtaining a copy
12
+ of this software and associated documentation files (the "Software"), to deal
13
+ in the Software without restriction, including without limitation the rights
14
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
+ copies of the Software, and to permit persons to whom the Software is
16
+ furnished to do so, subject to the following conditions:
17
+
18
+ The above copyright notice and this permission notice shall be included in all
19
+ copies or substantial portions of the Software.
20
+
21
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ base_model: HiDream-ai/HiDream-O1-Image-Dev
4
+ tags:
5
+ - mlx
6
+ - mlx-vlm
7
+ - hidream
8
+ - text-to-image
9
+ - apple-silicon
10
+ - bf16
11
+ language:
12
+ - en
13
+ pipeline_tag: text-to-image
14
+ library_name: mlx
15
+ inference: false
16
+ ---
17
+
18
+ # HiDream-O1-Image-Dev — MLX port for Apple Silicon
19
+
20
+ A native MLX port of [HiDream-ai/HiDream-O1-Image-Dev](https://huggingface.co/HiDream-ai/HiDream-O1-Image-Dev) for fast local image generation on Apple Silicon Macs. **No PyTorch, no CUDA, no flash-attn required at inference time.**
21
+
22
+ HiDream-O1 is an 8B Qwen3-VL-based **unified pixel-patch transformer** — it predicts raw 32×32 RGB patches directly through the same backbone that handles text, with no separate VAE. The Dev variant is a 28-step distillation of the 50-step Full model, released under the MIT license.
23
+
24
+ This port:
25
+ - Reuses [`mlx-vlm`](https://github.com/Blaizzy/mlx-vlm)'s Qwen3-VL backbone (vision tower, decoder layers, mrope-3D)
26
+ - Adds the three diffusion-side custom heads (`t_embedder1`, `x_embedder`, `final_layer2`)
27
+ - Ports the `FlashFlowMatchEulerDiscreteScheduler` and the unified-token-sequence builder
28
+ - Ships **BF16 weights** (no quantization — see "Why BF16" below)
29
+
30
+ ## Hero samples
31
+
32
+ All generated by the included generator script on a 64 GB Mac Studio. Click any image to open full-resolution.
33
+
34
+ <table>
35
+ <tr>
36
+ <td><a href="sample_outputs/hero/04_construction_worker.png"><img src="sample_outputs/hero/04_construction_worker.png" width="350"/></a></td>
37
+ <td><a href="sample_outputs/hero/01_tea_master.png"><img src="sample_outputs/hero/01_tea_master.png" width="350"/></a></td>
38
+ </tr>
39
+ <tr>
40
+ <td>Construction worker on a rainy rooftop, Kodak Tri-X B&amp;W. 2048×2048, BF16, 213s.</td>
41
+ <td>Elderly Japanese tea master holding a ceramic cup. 1024×1024, Q6 (showcase), 36s.</td>
42
+ </tr>
43
+
44
+ <tr>
45
+ <td><a href="sample_outputs/hero/02_tropical_beach.png"><img src="sample_outputs/hero/02_tropical_beach.png" width="350"/></a></td>
46
+ <td><a href="sample_outputs/hero/07_kitchen_morning.png"><img src="sample_outputs/hero/07_kitchen_morning.png" width="350"/></a></td>
47
+ </tr>
48
+ <tr>
49
+ <td>Tropical beach with turquoise water and palms. 1024×1024, Q8, 67s.</td>
50
+ <td>Candid morning portrait, woman with coffee + toast, soft window light. 1440×2560, BF16, 127s.</td>
51
+ </tr>
52
+
53
+ <tr>
54
+ <td><a href="sample_outputs/hero/03_astronaut.png"><img src="sample_outputs/hero/03_astronaut.png" width="350"/></a></td>
55
+ <td><a href="sample_outputs/hero/05_mountain_peak.png"><img src="sample_outputs/hero/05_mountain_peak.png" width="350"/></a></td>
56
+ </tr>
57
+ <tr>
58
+ <td>Astronaut in space-station corridor, anamorphic lens flare. 2560×1440, BF16, 187s.</td>
59
+ <td>Snow-capped mountain peak at sunset. 2048×2048, Q4 (early), 236s.</td>
60
+ </tr>
61
+
62
+ <tr>
63
+ <td><a href="sample_outputs/hero/06_alice_cyberpunk.png"><img src="sample_outputs/hero/06_alice_cyberpunk.png" width="350"/></a></td>
64
+ <td><a href="sample_outputs/hero/08_fitness_BF16.png"><img src="sample_outputs/hero/08_fitness_BF16.png" width="350"/></a></td>
65
+ </tr>
66
+ <tr>
67
+ <td>Alice in cyberpunk, neon Cheshire cat hologram. 2048×2048, Q8, 276s.</td>
68
+ <td>Fitness influencer mid-deadlift in industrial gym. 1440×2560, BF16, 127s.</td>
69
+ </tr>
70
+ </table>
71
+
72
+ More: [`sample_outputs/hero/`](sample_outputs/hero/).
73
+
74
+ ## Why BF16, not Q4/Q6/Q8
75
+
76
+ | Quant | Backbone size | 1024×1024 wall | Quality |
77
+ |---|---|---|---|
78
+ | Q4 | 5.6 GB | 25 s | ❌ Brightness collapses — ships dark |
79
+ | Q6 | 8 GB | 36 s | ⚠ Visible 32-px patch grid at non-square dims |
80
+ | Q8 | 10 GB | 67 s | ⚠ Same — works only at square 2048×2048 |
81
+ | **BF16** | **17.55 GB** | **67 s** | ✅ Clean across all trained dimensions |
82
+
83
+ Per-group dequantization rounding compounds across the 36 decoder layers and shows as a 32-pixel grid in flat regions (skies, walls, water). BF16 matches the upstream's `torch_dtype=torch.float32 + autocast(bfloat16)` precision and is the only quant we tested that produces clean output across all trained dimensions. On a 64 GB Mac the 16 GB working set is comfortable; on 32 GB it's tight — use Q8 at square 2048×2048 there.
84
+
85
+ ## Install
86
+
87
+ Requires macOS on Apple Silicon (M1 or newer). Tested on macOS 14+ with a 64 GB Mac Studio.
88
+
89
+ ```bash
90
+ git clone https://github.com/<you>/hidream-o1-mlx
91
+ cd hidream-o1-mlx
92
+ uv venv --python 3.11
93
+ uv pip install -r requirements.txt
94
+
95
+ # Convert the upstream HF weights to MLX BF16 (~5 minutes, requires ~50 GB free disk)
96
+ .venv/bin/python scripts/hidream_o1/convert_hidream_o1_to_mlx.py \
97
+ --hf-source HiDream-ai/HiDream-O1-Image-Dev \
98
+ --out-dir mlx_models/hidream-o1-dev-bf16 \
99
+ --bits 16
100
+ ```
101
+
102
+ ## Usage
103
+
104
+ ```bash
105
+ # Single image, default 1024×1024 BF16
106
+ .venv/bin/python scripts/hidream_o1/generate_hidream_o1_mlx.py \
107
+ --model-path mlx_models/hidream-o1-dev-bf16 \
108
+ --prompt "your prompt here" \
109
+ --output sample_outputs/whatever.png \
110
+ --seed 42
111
+
112
+ # Higher resolution (2048×2048 = upstream default)
113
+ .venv/bin/python scripts/hidream_o1/generate_hidream_o1_mlx.py \
114
+ --model-path mlx_models/hidream-o1-dev-bf16 \
115
+ --prompt "..." \
116
+ --width 2048 --height 2048 \
117
+ --output sample_outputs/big.png
118
+
119
+ # Vertical / cinema (auto-snaps to nearest trained ratio)
120
+ .venv/bin/python scripts/hidream_o1/generate_hidream_o1_mlx.py \
121
+ --model-path mlx_models/hidream-o1-dev-bf16 \
122
+ --prompt "..." \
123
+ --width 1440 --height 2560 \
124
+ --output sample_outputs/portrait.png
125
+ ```
126
+
127
+ ### Trained resolutions
128
+
129
+ HiDream-O1 was trained on a fixed list of resolutions. The generator auto-snaps to the closest. Off-spec dims produce visible patch artifacts. The trained list:
130
+
131
+ ```
132
+ 2048×2048, 2304×1728, 1728×2304, 2560×1440, 1440×2560,
133
+ 2496×1664, 1664×2496, 3104×1312, 1312×3104, 2304×1792, 1792×2304
134
+ ```
135
+
136
+ ## Prompt tips for realism
137
+
138
+ HiDream is responsive to camera/film terminology. To avoid the AI-glossy look:
139
+
140
+ - Lead with `masterpiece, best quality` (community-found responder phrase)
141
+ - Subject + Actions → Setting → Style → Details ordering
142
+ - Specify equipment: `Leica M6 with Kodak Tri-X 400`, `Pentax K1000 + Cinestill 800T`, `Hasselblad H6D medium format`
143
+ - Reference real photographers: Sebastião Salgado, Saul Leiter, Wim Wenders, Annie Leibovitz, Anders Petersen
144
+ - Spell out skin imperfection: "natural pores", "faint laugh lines", "weathered hands", "no retouching"
145
+ - Avoid "stunning", "perfect", "beautiful" — they push toward AI-glamour aesthetics
146
+
147
+ The Dev model uses `guidance_scale=0.0` so negative prompts have no effect — push positive prompts harder instead.
148
+
149
+ ## What's in this repo
150
+
151
+ ```
152
+ hidream-o1-mlx/
153
+ ├── README.md (this file)
154
+ ├── LICENSE (MIT)
155
+ ├── requirements.txt (mlx-vlm 0.5.0, transformers 5.8+, deps)
156
+ ├── scripts/hidream_o1/
157
+ │ ├── convert_hidream_o1_to_mlx.py (HF → MLX, BF16 / Q4 / Q6 / Q8)
158
+ │ ├── generate_hidream_o1_mlx.py (T2I generator + experimental edit/multi-ref)
159
+ │ ├── hidream_model.py (custom heads + forward_generation)
160
+ │ ├── pipeline_helpers.py (T2I sample, mrope, mask, patchify)
161
+ │ └── flow_match.py (FlashFlowMatchScheduler in MLX)
162
+ ├── docs/
163
+ │ ├── EVALUATION.md (perf + quality findings, A/B vs mflux)
164
+ │ ├── HIDREAM_O1_MLX_PORT_REPORT.md (architecture + weight conversion details)
165
+ │ └── PHOSPHENE_INTEGRATION_PLAN.md (how it slots into a host app)
166
+ ├── sample_outputs/ (gallery)
167
+ └── mlx_models/ (where converted weights land)
168
+ ```
169
+
170
+ ## Performance
171
+
172
+ | Resolution | Per step | Total (28 steps) | Peak RAM |
173
+ |---|---|---|---|
174
+ | 1024×1024 | 2.4 s | 67 s | 16 GB |
175
+ | 1440×2560 | 4.5 s | 127 s | 16 GB |
176
+ | 2048×2048 | 6.7 s | 187 s | 16 GB |
177
+ | 3104×1312 | 7.6 s | 213 s | 16 GB |
178
+
179
+ `mx.compile` gives 0% speedup — the inference loop is bandwidth-bound on the 36-layer BF16 decoder. To go faster you'd need a smaller distillation (none public) or text-cache reuse across denoising steps.
180
+
181
+ ## Status
182
+
183
+ - ✅ Text-to-image: production-quality, BF16 default
184
+ - ✅ Native MLX, no PyTorch / CUDA / flash-attn at inference time
185
+ - ⚠ Edit / multi-reference: scaffolding present (`--ref-images` flag) but produces degenerate output — needs debugging. Refs through other engines (e.g. `mflux qwen-edit`) work correctly.
186
+ - ❌ Multi-reference subject personalization: same as above
187
+
188
+ ## Acknowledgements
189
+
190
+ - [HiDream-ai](https://github.com/HiDream-ai) for the original HiDream-O1-Image model + MIT license
191
+ - [Blaizzy/mlx-vlm](https://github.com/Blaizzy/mlx-vlm) for the Qwen3-VL MLX backbone (this port reuses their vision tower + decoder layers + mrope-3D wholesale)
192
+ - [Apple ml-explore/mlx](https://github.com/ml-explore/mlx) for the MLX framework
193
+ - The Civitai community's [HiDream prompt-engineering guide](https://civitai.com/articles/16050/hi-dream-prompt-engineering)
194
+
195
+ ## Citation
196
+
197
+ If you use this in research, cite the upstream model:
198
+
199
+ ```bibtex
200
+ @misc{hidream-o1-image,
201
+ author = {HiDream-ai},
202
+ title = {HiDream-O1-Image: Pixel-Level Unified Transformer},
203
+ year = {2026},
204
+ url = {https://github.com/HiDream-ai/HiDream-O1-Image}
205
+ }
206
+ ```
207
+
208
+ ## License
209
+
210
+ MIT — see [LICENSE](LICENSE).
STATE.md ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HIDREAM-O1-MLX-LAB — STATE
2
+
3
+ **Last updated:** 2026-05-09 (session that landed Q8 + Phosphene integration)
4
+
5
+ ## TL;DR — where we are
6
+
7
+ - **Q6 is the new sweet spot.** 1.30 s/step at 1024×1024, ~36 s per image, ~8.5 GB RAM. 2× faster than Q8 with equivalent quality.
8
+ - Q8 still works (2.36 s/step, 11.5 GB RAM) — keep it for deterministic upper-bound RAM use cases.
9
+ - Q4 deleted from disk: ships dark, no reason to keep around (regenerable in 5 min if needed).
10
+ - Backbone sizes: Q6 backbone 7.95 GB, Q8 backbone 9.96 GB. Custom heads 75 MB.
11
+ - **Shipped to Phosphene `dev`** as `kind="hidream"` in `agent/image_engine.py` (commits `45cad69`, `962b353`). Default model on Phosphene side will be updated to Q6.
12
+ - Showcase battery + A/B vs mflux Z-Image-Turbo done at Q8. At Q6, HiDream is now ~2× faster than Z-Image-Turbo (36s vs 80s) AND has lower deterministic RAM (8.5 GB vs 5.9–29.4 GB variable).
13
+ - 19+ sample images in `sample_outputs/`.
14
+ - Lab branch: `perf-lab-hidream-o1-mlx`, **no remote**.
15
+
16
+ ## What's been done
17
+
18
+ | Date | Work | Commit |
19
+ |---|---|---|
20
+ | 2026-05-09 | Initial scaffolding (Path B chosen) | `746efe9` |
21
+ | 2026-05-09 | Wire mlx-vlm Qwen3VLModel directly (4D mask path) | `53eb605` |
22
+ | 2026-05-09 | First working images (mushroom 512, cat/beach/portrait 1024 at Q4) | `d944a31` |
23
+ | 2026-05-09 | Q8 conversion + samples (dark aesthetic was Q4, not the model) | `2bf029a` |
24
+ | 2026-05-09 | Showcase battery + evaluation + Phosphene plan | `0bac049` |
25
+ | 2026-05-09 | Phosphene integration shipped to `dev` | phos `45cad69` |
26
+ | 2026-05-09 | A/B vs mflux Z-Image-Turbo on 3 prompts | `2761ad8` |
27
+ | 2026-05-09 | Phosphene IMAGE_GEN_RESEARCH doc updated | phos `962b353` |
28
+ | 2026-05-09 | --blend-seams post-process (opt-in, below-threshold at Q8) | `0583356` |
29
+ | 2026-05-09 | Q6 = sweet spot (2× faster than Q8, same quality) — Phosphene default switched | `f4fb0ba` + phos `8a48953` |
30
+ | 2026-05-09 | Q6 verified across 10-prompt showcase battery | `4d3f18c` |
31
+ | 2026-05-09 | Edit/multi-ref scaffold (WIP — runs but output degenerate) | `525b7ec` |
32
+ | 2026-05-09 | BF16 default — Q4/Q6/Q8 all show patch-grid at non-square dims | (next) |
33
+ | 2026-05-09 | Phosphene default switched to BF16 | phos `af94bd0` |
34
+ | 2026-05-09 | OSS release prep: HF model card, LICENSE, requirements, gitignore | (next) |
35
+
36
+ ## Known characteristics (not bugs)
37
+
38
+ - **Patch grid in flat regions** — architectural (PATCH_SIZE=32 with no overlap). Mild at Q8. `--blend-seams 1` is opt-in but doesn't visibly help.
39
+ - **Text rendering** — short, structured signs work ("BLOOM CAFE"). Long text falls apart.
40
+ - **Deterministic per-prompt RAM** — 11.5 GB at 1024 Q8 regardless of prompt complexity. Z-Image-Turbo varies wildly (5.9–29.4 GB).
41
+
42
+ ## Open work / next session candidates
43
+
44
+ Pick from these, listed roughly cheapest-first:
45
+
46
+ 1. ~~**2048×2048 Q8 generation pass**~~ — DONE 2026-05-09. `sample_outputs/v4_2048_alice_q8.png` — 276 s (9.86 s/step), peak RAM 10.8 GB. Q8 at 2048 is slower per step than Q4 (10s vs 8.4s) due to bandwidth. Output is showcase-grade: detailed cybernetic dress, holographic Cheshire cat, near-legible neon signs.
47
+ 2. **Test the Phosphene integration through the dev panel UI** (port 8199). Generate one shot via the Image Studio dropdown, confirm pill goes green, the PNG lands.
48
+ 3. **Edit / multi-reference path** — SCAFFOLD LANDED, NEEDS DEBUGGING.
49
+ - `build_edit_text_sample`, `resize_pilimage`, `calculate_dimensions`, `patchify_ref_image` all ported from upstream pipeline.py + utils.py.
50
+ - `--ref-images` flag wired in generate_hidream_o1_mlx.py.
51
+ - `precompute_text_embeds_with_vision` precomputes the text+vision embeds once before the loop (since they don't change with timestep) — a meaningful perf win.
52
+ - **Smoke test (synthesized two-color ref, K=1, 28 steps, Q6) runs end-to-end without errors but output is uniform tan/khaki.** T2I path with same prompt+seed produces a vibrant abstract correctly, so the model and weights are fine.
53
+ - Debugging done so far (see `scripts/hidream_o1/_edit_diag.py` and `_precompute_diag.py`):
54
+ - **All shapes verified correct** (input_ids 174 with 144 image-placeholders, vision tower outputs 144 features, vinput_mask = 256 tgt + 256 ref, position_ids 686 covering all spans).
55
+ - **Vision feature scatter verified mathematically correct** — at image_token positions `combined` equals `image_features` exactly (diff=0); at text positions `combined` equals `embed_tokens(input_ids)` exactly (diff=0). Vision features are well-behaved (mean ~0, std ~0.4).
56
+ - **Position_ids structure looks right** — text positions are sequential, target span gets fix_point=4096 base (per upstream), ref diffusion span continues sequentially.
57
+ - **Remaining suspects** (in order of likelihood):
58
+ - Mask construction: maybe text-row causal needs to ALSO see the K image_placeholder positions inside proc.input_ids? Upstream `_run_decoder_flash` has special handling — the non-flash 4D mask path may treat text positions as needing to see embedded vision features. Worth re-reading qwen3_vl_transformers.py:1486-1520.
59
+ - Position_ids semantic alignment: my appended-vinputs at positions [174..686) get mrope codes from input_ids_pad's vision_tokens portion, but maybe these need to match the appended embedding ORDER not just their positions in input_ids_pad.
60
+ - bf16 underflow in attention with the larger 686-token sequence vs T2I's 268.
61
+ - Samples: `sample_outputs/v6_edit_smoke.png` (degenerate, synthesized 2-color ref), `sample_outputs/v6_edit_cat_real.png` (degenerate, real cat photo as ref), `sample_outputs/v6_edit_t2i_baseline.png` (T2I works fine same prompt+seed).
62
+ - This is the single biggest open item. Would let HiDream replace mflux qwen-edit functionally.
63
+ 4. **Promote Phosphene integration to `main`** after the user has tested on dev panel.
64
+ 5. **Quality-aware post-process** — try a cheap learned upscaler instead of the seam blend (e.g. SeedVR2 via mflux's `mflux-upscale-seedvr2` to take 1024 → 2048).
65
+ 6. **Text-cache reuse across denoising steps** — fork mlx-vlm's Qwen3VLModel to cache the text-portion KV across the 28 denoising calls. ~2-5% speedup max but a real architectural improvement.
66
+
67
+ ## Hard stop conditions (still relevant)
68
+
69
+ - Q4 ships dark — established. Use Q8.
70
+ - mx.compile = 0% gain — established. Inference loop is at the floor.
71
+ - Splitting safetensors mid-read zeroed weights — fixed in converter; don't re-introduce.
72
+
73
+ ## How to ramp up fast (next session)
74
+
75
+ 1. `cd /Users/salo/HIDREAM-O1-MLX-LAB-active`
76
+ 2. `cat README.md CLAUDE.md STATE.md docs/EVALUATION.md` (in that order)
77
+ 3. `git log --oneline | head -10` to see where we are
78
+ 4. `ls sample_outputs/` to see what's been generated
79
+ 5. To regenerate or extend: see the commands in CLAUDE.md
80
+
81
+ ## Disk situation snapshot
82
+
83
+ As of 2026-05-09 the data volume `/dev/disk3s5` had ~45 GB free of 926 GB after the user's mid-session cleanup (deleted `phosphene-model-lab.git` and `comfy.git`, freed ~83 GB). The lab itself is ~16 GB on disk (10 GB Q8 + 6 GB Q4 models + 1.5 GB venv + samples + lab code). **Do not re-download** the HiDream HF source unless `mlx_models/hidream-o1-dev-q4/` AND `mlx_models/hidream-o1-dev-q8/` both go missing — both can be regenerated from one HF download.
UPLOAD_TO_HF.md ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # How to publish this to Hugging Face
2
+
3
+ Two-repo plan (recommended):
4
+
5
+ 1. **Code repo** on GitHub: `https://github.com/<you>/hidream-o1-mlx`
6
+ 2. **Weights repo** on HF Hub: `https://huggingface.co/<you-or-org>/HiDream-O1-Image-Dev-mlx-bf16`
7
+
8
+ Linking them keeps clones fast (people who just want the code don't pull the 17 GB safetensors via LFS) while still making the weights one-click pullable from `huggingface_hub`.
9
+
10
+ ## Option A — separate code + weights repos
11
+
12
+ ### 1. Code repo on GitHub
13
+
14
+ ```bash
15
+ cd /Users/salo/HIDREAM-O1-MLX-LAB-active
16
+
17
+ # Initialize a new public-friendly remote
18
+ git remote add origin git@github.com:<you>/hidream-o1-mlx.git
19
+
20
+ # .gitignore mlx_models/ so weights don't go to GitHub (HF Hub will host them)
21
+ echo "mlx_models/" >> .gitignore
22
+
23
+ git add -A
24
+ git commit -m "Initial public release"
25
+ git branch -M main
26
+ git push -u origin main
27
+ ```
28
+
29
+ ### 2. Weights repo on HF Hub
30
+
31
+ ```bash
32
+ # Install hf CLI if needed
33
+ pip install huggingface_hub
34
+
35
+ # Login once
36
+ hf auth login # paste a write-token from https://huggingface.co/settings/tokens
37
+
38
+ # Create the repo
39
+ hf repo create HiDream-O1-Image-Dev-mlx-bf16 --type model
40
+
41
+ # Upload only the weights dir + config + tokenizer + custom heads
42
+ cd mlx_models/hidream-o1-dev-bf16
43
+ hf upload <you>/HiDream-O1-Image-Dev-mlx-bf16 . . \
44
+ --repo-type model \
45
+ --commit-message "Initial BF16 release"
46
+ ```
47
+
48
+ What gets uploaded (~17.5 GB total):
49
+ - `model.safetensors` (17 GB) — backbone, mlx-vlm-loadable
50
+ - `extras/custom_heads.safetensors` (75 MB) — diffusion-side heads
51
+ - `config.json` — Qwen3-VL config (no `quantization` field for BF16)
52
+ - `tokenizer.json`, `tokenizer_config.json`, `vocab.json`, `merges.txt`, `chat_template.json`
53
+ - `preprocessor_config.json`, `video_preprocessor_config.json`
54
+ - `mlx_lab_meta.json` — provenance marker
55
+
56
+ ### 3. Cross-reference
57
+
58
+ In the GitHub README, point to the HF weights repo. In the HF model card README (which we already prepped), point to the GitHub code.
59
+
60
+ ## Option B — single HF repo with everything
61
+
62
+ If you want the simplest user experience (`hf download <repo>` → ready to run):
63
+
64
+ ```bash
65
+ hf repo create hidream-o1-mlx --type model
66
+ cd /Users/salo/HIDREAM-O1-MLX-LAB-active
67
+
68
+ # Track .py + .md as plain files; .safetensors via LFS (already in .gitattributes)
69
+ git remote add hf https://huggingface.co/<you>/hidream-o1-mlx
70
+ git lfs install
71
+ git lfs track "*.safetensors"
72
+ git add -A
73
+ git commit -m "Initial release"
74
+ git push hf main
75
+ ```
76
+
77
+ People then do:
78
+
79
+ ```bash
80
+ hf download <you>/hidream-o1-mlx
81
+ cd hidream-o1-mlx
82
+ uv venv --python 3.11 && uv pip install -r requirements.txt
83
+ .venv/bin/python scripts/hidream_o1/generate_hidream_o1_mlx.py --prompt "..." --output out.png
84
+ ```
85
+
86
+ ## What NOT to upload
87
+
88
+ - `.venv/` — gitignored
89
+ - `logs/` — gitignored
90
+ - `notes/` — internal scratch, optional
91
+ - `__pycache__/` — gitignored
92
+ - `mlx_models/hidream-o1-dev-q4/` and `q6/`, `q8/` — only ship BF16. They're regenerable with `--bits 4|6|8` and have known quality issues at non-square dims.
93
+
94
+ ## Pre-flight checklist
95
+
96
+ - [ ] LICENSE file (MIT) at root — done
97
+ - [ ] README.md as HF-format model card — done
98
+ - [ ] requirements.txt with pinned versions — done
99
+ - [ ] .gitattributes for LFS — done
100
+ - [ ] No personal paths (`/Users/salo/...`) hardcoded in scripts that aren't optional — verify with `grep -r "/Users/salo" scripts/`
101
+ - [ ] Sample images included for the model card — copy 4-6 best to `sample_outputs/hero/`
102
+ - [ ] Test fresh clone install on a different machine if possible
docs/EVALUATION.md ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HiDream-O1-Image-Dev (Q8 MLX) — evaluation
2
+
3
+ **Setup:** lab branch `perf-lab-hidream-o1-mlx`, mlx-vlm 0.5.0 + mlx 0.31.2, Mac Studio (64 GB).
4
+ **Recipe:** Dev — 28 steps, FlashFlowMatch, `s_noise=7.5`, `noise_clip_std=2.5`, `shift=1.0`.
5
+ **All times** are honest wall-clock with `mx.eval` per step. **All RAM** is peak `maximum resident set size`.
6
+
7
+ ## Q6 showcase verification (2026-05-09 evening)
8
+
9
+ Re-ran the same 10-prompt battery at Q6 with identical seeds. **All 10 are visually equivalent or better than the Q8 versions:**
10
+
11
+ - 9/10 are near-pixel-identical aesthetics (different latent noise from quant differences yields same compositions / lighting / subjects)
12
+ - **10 (text rendering) is visibly better at Q6** — "BLOOM CAFE" neon sign is crisp at Q6 vs a glitched "M" at Q8
13
+
14
+ Per-image timing was rock-steady at **35.9 s** (1.28 s/step). Total battery time: ~6 minutes vs ~12 minutes at Q8.
15
+
16
+ Outputs: `sample_outputs/showcase_q6/` (compare against `sample_outputs/showcase/` for the Q8 originals).
17
+
18
+ ## Battery: 10 prompts, 1024×1024, all Q8
19
+
20
+ | # | Genre | Prompt summary | Result | Time |
21
+ |---|---|---|---|---|
22
+ | 01 | photo portrait | elderly Japanese tea master | **Excellent** — face character, gentle smile, paper screens, calligraphy | 81.5 s* |
23
+ | 02 | anime / illustration | pink-haired girl on Tokyo rooftop at dusk | **Excellent** — anime style + cherry blossoms + neon city below | 65.3 s |
24
+ | 03 | macro photo | dewdrop on spiderweb | **Excellent** — refractions, blurred leaf bg, crisp web detail | 65.9 s |
25
+ | 04 | architecture | futuristic library, holographic displays | **Excellent** — vaulted ceiling, stained glass, holo screens | 66.3 s |
26
+ | 05 | surreal painting | whale floating over desert at sunset | **Excellent** — magical realism, painterly clouds | 65.8 s |
27
+ | 06 | food flatlay | rustic Italian breakfast on marble | **Excellent** — golden croissants, espresso, berries, soft light | 66.4 s |
28
+ | 07 | cinematic action | samurai mid-leap with katana, Mt. Fuji bg | **Excellent** — dynamic pose, cherry blossoms, real mountain | 66.1 s |
29
+ | 08 | fantasy | dragon on crystal mountain with aurora | **Excellent** — iridescent scales, snow swirling, aurora visible | 66.4 s |
30
+ | 09 | wildlife photo | snow leopard staring at camera | **Excellent** — direct gaze, falling snow, mountain bg | 67.1 s |
31
+ | 10 | text rendering | "BLOOM CAFE" pink neon diner | **Good** — sign legible (small "M" glitch), retro diner, rainy street | 67.1 s |
32
+
33
+ *Image 01 included cold model load (~12-15 s).
34
+
35
+ **Steady-state per-image: 65-67 s at 1024×1024 Q8.** Dead-consistent across genres.
36
+
37
+ ## Honest timings
38
+
39
+ | Resolution | Quant | Per step | Total (28 steps) | Peak RAM |
40
+ |---|---|---|---|---|
41
+ | 512×512 | Q4 | 0.89 s | 24.9 s | ~6 GB |
42
+ | 1024×1024 | Q4 | 2.37 s | 66 s | ~6 GB |
43
+ | 1024×1024 | **Q6** | **1.30 s** | **36 s** | **~8.5 GB** |
44
+ | 1024×1024 | Q8 | 2.36 s | 66 s | ~11.5 GB |
45
+ | 1280×704 | Q8 | 2.53 s | 70.7 s | ~7 GB |
46
+ | 704×1280 | Q8 | 2.35 s | 65.9 s | ~3 GB (warm cache) |
47
+ | 2048×2048 | Q4 | 8.44 s | 236 s | ~7.2 GB |
48
+ | 2048×2048 | Q8 | 9.86 s | 276 s | ~10.8 GB |
49
+
50
+ **Q6 is the sweet spot.** 2× faster than Q8 at 1024 with the same prompt fidelity (cat in sunlit kitchen + beach with palm trees both rendered identically to Q8 outputs). 30% less RAM. The bandwidth-bound theory holds: fewer bits per param → less weight bandwidth → faster per-step.
51
+
52
+ **Q4 corrupts brightness** (ships dark) so the speed of Q4 vs Q6 is academic — never use Q4 for production. Q6 has the speed and Q8 has the steady-state safety; Q6 wins on perf, Q8 wins on a deterministic upper bound on RAM.
53
+
54
+ ## Where HiDream-O1-Image-Dev shines
55
+
56
+ - **Subject identity** — every prompt subject was rendered correctly. No "vibrant orange tabby" → cat-shape-blob. The model knows what things look like.
57
+ - **Multi-element scenes** — samurai + Fuji + cherry blossoms; cyberpunk Alice + neon Cheshire cat + circuit dress + rain. Composition stays coherent.
58
+ - **Style adherence** — anime ≠ photorealism ≠ oil painting ≠ macro. Got all four right.
59
+ - **Light realism** — the architecture image's light through stained glass; the food flatlay's morning warmth; the action scene's sunset rim lighting. Light feels real, not stamped on.
60
+ - **Text rendering** (limited) — "BLOOM CAFE" in neon was readable. Better than most diffusion models; not as clean as a model with explicit OCR pretraining.
61
+
62
+ ## Where it's weak
63
+
64
+ - **Patch-grid artifact** in flat regions. PATCH_SIZE=32 with no overlap → visible 32×32 grid in skies, water, walls. Most visible at low-frequency content. Architectural — not fixable without retraining or an overlap-blending postprocess.
65
+ - **Q4 brightness collapse** — Q4 desaturates and darkens everything. Q8 fixes it. **Ship Q8.**
66
+ - **Hands** — hands when present in scenes (e.g. tea master holding cup) look fine at moderate detail, but the model isn't immune to the standard diffusion hand failure modes; haven't stress-tested.
67
+ - **Dense long text** — "BLOOM CAFE" is short and structured. A paragraph of text would likely fall apart.
68
+ - **Speed at 2048** — 4 minutes per image is slow for iterative work. Fine for a final pass.
69
+
70
+ ## Sweet spot
71
+
72
+ **1024×1024, Q6, default Dev recipe, ~36 s/image, ~8.5 GB RAM.** Bright/colourful output equivalent to Q8, half the wall time, 30% less RAM. 512 is fast (~25 s) but loses detail. 2048 is gorgeous but iterative-unfriendly.
73
+
74
+ **Quant decision tree:**
75
+ - 16 GB Mac → don't run HiDream; use mflux Z-Image-Turbo
76
+ - 32 GB Mac → Q6 is comfortable, Q8 leaves no headroom alongside LTX
77
+ - 64 GB Mac → Q6 default; Q8 only when you want deterministic upper-bound RAM
78
+
79
+ ## A/B vs mflux Z-Image-Turbo
80
+
81
+ Same prompts, same seeds, both at 1024×1024.
82
+
83
+ | # | Prompt | HiDream Q8 | Z-Image-Turbo Q4 (mflux) | Subjective winner |
84
+ |---|---|---|---|---|
85
+ | 1 | tea master | [v3](../sample_outputs/showcase/01_portrait_photo.png) — wide scene, paper screens, calligraphy | [zimg](../sample_outputs/ab_mflux/01_portrait_zimage.png) — tighter portrait, gray garment, smile | **Tie** — different framings, both excellent |
86
+ | 2 | sunlit beach | [v3](../sample_outputs/v3_1024_beach_q8.png) — turquoise water, palm trees, beach chair | [zimg](../sample_outputs/ab_mflux/02_beach_zimage.png) — vivid blue water, palms, big sand foreground | **Tie** — both nail the prompt |
87
+ | 3 | alice cyberpunk | [v3](../sample_outputs/v3_alice_horizontal_q8.png) (horizontal) — clear dress + face + Cheshire | [zimg](../sample_outputs/ab_mflux/03_alice_zimage.png) — more painterly, atmospheric Cheshire silhouette | **HiDream** for face/dress detail; **Z-Image** for atmosphere |
88
+
89
+ ### Speed + RAM (measured, not estimated)
90
+
91
+ | Engine | Steps | Wall (1024) | Per step | Peak RAM |
92
+ |---|---|---|---|---|
93
+ | HiDream-O1-Dev / Q8 | 28 | **67 s** | 2.41 s | **11.5 GB** |
94
+ | Z-Image-Turbo / Q4 | 9 | 80 s | 8.85 s | **5.9–29.4 GB** (varies by prompt) |
95
+
96
+ Surprises:
97
+ - HiDream is **faster per image** despite needing 28 steps vs Z-Image-Turbo's 9 — Z-Image's per-step cost is ~3.7× HiDream's.
98
+ - Z-Image's peak RAM **varied wildly across prompts** (5.9 GB for portrait, 29.4 GB for the alice cyberpunk). HiDream's peak was steady at ~11.5 GB regardless of prompt complexity.
99
+
100
+ ### Verdict
101
+
102
+ Both are excellent local engines. Pick by the workload:
103
+
104
+ - **Default/compact**: keep **Z-Image-Turbo** — 5.9 GB RAM on most prompts, runs anywhere.
105
+ - **Hero shots / max prompt fidelity**: **HiDream-O1-Q8** — faster wall time, deterministic memory, more environmental detail in the output.
106
+ - **Editing / multi-ref**: keep **mflux qwen-edit** — HiDream lab pipeline doesn't support refs yet.
107
+
108
+ ## Patch-grid post-blend experiment
109
+
110
+ Implemented `--blend-seams <radius>` post-process in `generate_hidream_o1_mlx.py`: after decoding the final image, average a thin band across each 32-pixel patch boundary line (radius=1 → blend the seam row with one neighbour on each side, then 50% blend back into the seam itself).
111
+
112
+ **Result on the same beach prompt + seed 11 + Q8:**
113
+
114
+ | Comparison | Mean abs diff (out of 255) |
115
+ |---|---|
116
+ | baseline vs blend r=1 | 0.18 |
117
+ | baseline vs blend r=2 | 0.23 |
118
+
119
+ Per-row breakdown confirms the blend is **surgical** — only seam rows (every 32) change, by 1–2.7 pixel values; non-seam rows shift by <0.2. So the math is doing exactly what it says.
120
+
121
+ **But visually**: at Q8 the seam artifact is already mild. The blend's 1–2 pixel-value smoothing is below visual threshold. No win, but no harm — and zero added latency (numpy vector ops on a 1024×1024 image are sub-ms).
122
+
123
+ Bottom line: kept as opt-in flag `--blend-seams 1`. Did not enable by default. The real fix for the patch grid would need overlap-blended patches (architectural change) or a stronger spatial filter (which would visibly blur the image).
124
+
125
+ ## Software-side speed: nothing left
126
+
127
+ Tested `mx.compile` on the forward pass: **0% improvement** (2.366 s/step compiled vs 2.368 s/step uncompiled). The forward is already bandwidth-bound by the 36-layer Q8 decoder's matmul stream — MLX is already at near-GPU-saturation. Same conclusion for `mx.fast.scaled_dot_product_attention` (already used inside mlx-vlm's Qwen3VLAttention).
128
+
129
+ **The path to faster is architectural, not algorithmic:**
130
+ - Fewer steps (would need a smaller distillation; Dev is already the distilled variant)
131
+ - Smaller backbone (would need re-distillation onto a 4B Qwen3-VL — no public version)
132
+ - Caching the text-portion hidden states across denoising steps — possible but invasive (would need to subclass mlx-vlm's Qwen3VLModel; ~2-5% speedup at best since text is <2% of seq length)
133
+
134
+ ## Verdict
135
+
136
+ - **Working.** Q8 produces real, prompt-faithful, high-quality images at ~67 s/1024.
137
+ - **No more easy speedups.** The lab's inference loop is already at the floor for this architecture on this hardware.
138
+ - **Patch artifacts are real but mild.** Low-frequency regions show a 32-pixel grid. Subjects-with-content scenes hide it well.
139
+ - **Q8 is the only acceptable quant.** Q4 ships dark. If we ever want a smaller variant, would need different bit packing or selective Q6.
140
+
141
+ ## Recommendation for Phosphene
142
+
143
+ Slot it in as a third local engine alongside `mflux Z-Image-Turbo` (compact tier) and `mflux FLUX.2-klein-4B` (comfortable tier). Mark HiDream as **comfortable+** (32 GB+) due to the 11.5 GB working set. Don't make it the default — it's slower per image and uses more RAM than Z-Image-Turbo. Make it **the option** for users who want max prompt fidelity and license clarity (MIT, no NC restriction).
144
+
145
+ See [PHOSPHENE_INTEGRATION_PLAN.md](PHOSPHENE_INTEGRATION_PLAN.md) for the patch.
docs/HIDREAM_O1_MLX_PORT_REPORT.md ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HiDream-O1-Image MLX port — working
2
+
3
+ Lab branch: `perf-lab-hidream-o1-mlx`
4
+ Lab path: `/Users/salo/HIDREAM-O1-MLX-LAB-active/`
5
+ Date: 2026-05-09
6
+ Status: **Shipped. Q8 inference end-to-end on Apple Silicon. Phosphene `dev` integration live (commit 45cad69).**
7
+
8
+ ---
9
+
10
+ ## TL;DR
11
+
12
+ - Path B (standalone MLX wrapper around `mlx-vlm` Qwen3-VL backbone) — confirmed viable.
13
+ - **Q8 is the right configuration.** Q4 corrupts the brightness distribution badly enough to make every image dark/moody. Q8 produces clean, fully prompt-faithful images.
14
+ - Sizes: Q4 backbone 5.6 GB, Q8 backbone 10 GB. Custom heads 75 MB.
15
+ - 512×512 in **24.9 s** Q4 (28 steps × 0.89 s). 1024×1024 in **65–67 s** at both Q4 and Q8 (28 steps × 2.4 s — Q8 not measurably slower because the bottleneck is bandwidth, not arithmetic). 2048×2048 in **236 s** Q4 (8.4 s/step).
16
+ - Peak RAM: Q4 ≈ 6 GB at 1024, ≈ 7.2 GB at 2048. Q8 ≈ 11.5 GB at 1024.
17
+ - 32 GB Mac plausible at Q4 only; 64 GB comfortable at Q8 + 2048.
18
+ - **Phosphene integration shipped to `dev`** (`agent/image_engine.py` `kind="hidream"`, commits `45cad69` + `962b353`). Available in the Image Studio engine dropdown.
19
+ - A/B vs `mflux` Z-Image-Turbo done — see [EVALUATION.md](EVALUATION.md). Both engines competitive; HiDream is faster per image (67s vs 80s) and uses deterministic memory; Z-Image is leaner most of the time but spikes on complex prompts.
20
+ - Edit + multi-reference paths still TODO — refs continue to flow through `mflux qwen-edit` per existing convention.
21
+
22
+ ## What landed
23
+
24
+ ```
25
+ /Users/salo/HIDREAM-O1-MLX-LAB-active/
26
+ ├── README.md (DO NOT DELETE marker)
27
+ ├── docs/HIDREAM_O1_MLX_PORT_REPORT.md (this file)
28
+ ├── notes/weight_map.json (HF safetensors index)
29
+ ├── scripts/hidream_o1/
30
+ │ ├── flow_match.py (FlashFlowMatch in MLX)
31
+ │ ├── pipeline_helpers.py (T2I sample, mrope, mask)
32
+ │ ├── hidream_model.py (custom heads + forward)
33
+ │ ├── convert_hidream_o1_to_mlx.py (HF -> MLX, Q4/6/8)
34
+ │ └── generate_hidream_o1_mlx.py (T2I generator)
35
+ ├── mlx_models/hidream-o1-dev-q4/
36
+ │ ├── model.safetensors (5.6 GB, mlx-vlm-loadable)
37
+ │ ├── extras/custom_heads.safetensors (75 MB)
38
+ │ ├── config.json (with "quantization" field)
39
+ │ └── tokenizer/processor metadata
40
+ ├── sample_outputs/
41
+ │ ├── v2_512_mushroom.png (24.9 s, "red mushroom on moss")
42
+ │ ├── v2_1024_cat.png (67.1 s, "tabby on wooden chair")
43
+ │ ├── v2_1024_beach.png (66.6 s, "sunlit beach with palms")
44
+ │ └── v2_1024_portrait.png (65.8 s, "portrait, red curly hair")
45
+ └── .venv/ (uv venv: mlx 0.31.2, mlx-vlm 0.5.0)
46
+ ```
47
+
48
+ ## How to run
49
+
50
+ ```bash
51
+ cd /Users/salo/HIDREAM-O1-MLX-LAB-active
52
+ .venv/bin/python scripts/hidream_o1/generate_hidream_o1_mlx.py \
53
+ --model-path mlx_models/hidream-o1-dev-q4 \
54
+ --prompt "your prompt here" \
55
+ --width 1024 --height 1024 \
56
+ --output sample_outputs/your_image.png \
57
+ --seed 42
58
+ ```
59
+
60
+ To re-convert from a fresh HF download:
61
+
62
+ ```bash
63
+ .venv/bin/python scripts/hidream_o1/convert_hidream_o1_to_mlx.py \
64
+ --hf-source HiDream-ai/HiDream-O1-Image-Dev \
65
+ --out-dir mlx_models/hidream-o1-dev-q4 \
66
+ --bits 4 --check-disk
67
+ ```
68
+
69
+ ## Implementation summary
70
+
71
+ - **Backbone**: mlx-vlm `qwen3_vl.Model` (vision tower + text decoder + mrope-3D), unchanged. 36 layers, hidden 4096, 32 heads, 8 KV heads, head_dim 128. Vision: 27 blocks, hidden 1152, deepstack at [8, 16, 24].
72
+ - **Custom heads** (under `model.` in HF, mapped to root in MLX):
73
+ - `t_embedder1` — sinusoidal-256 → SiLU → 4096 (timestep embedding)
74
+ - `x_embedder` — 32×32×3 → 1024 → 4096 (patch embedding)
75
+ - `final_layer2` — 4096 → 32×32×3 (patch output)
76
+ - **Forward**: text tokens via embed_tokens, replace tms positions with `t_emb`, append `x_embedder(vinputs)` to the sequence, run all decoder layers with a custom 4D additive mask (text causal, image bidirectional), apply `final_layer2`, slice at `vinput_mask`. Calls `mlx-vlm`'s `Qwen3VLModel.__call__` directly — it already accepts the 4D mask.
77
+ - **Scheduler**: `FlashFlowMatchScheduler` ported verbatim from `models/flash_scheduler.py` (Euler with optional fresh-noise injection). Dev recipe: 28 steps, custom `DEFAULT_TIMESTEPS`, `s_noise=7.5`, `noise_clip_std=2.5`.
78
+ - **Quantisation**: `mx.quantize(group_size=64, bits=4)` on Linear weights where the inner dim is divisible by 64. Vision MLP `linear_fc2` (1152, 4304) doesn't qualify and stays bf16 (~270 MB extra). Custom heads kept bf16 (small + sensitive).
79
+
80
+ ## Bugs found and fixed
81
+
82
+ 1. **mlx-vlm strict-load rejects the 9 custom-head keys.** Fix: write the diffusion-side weights to `extras/custom_heads.safetensors` (subdir, so mlx-vlm's `glob *.safetensors` doesn't pick it up). Wrapper loads both.
83
+ 2. **mlx-vlm needs `quantization` in config.json** to wrap `Linear → QuantizedLinear` before loading weights. Converter writes it.
84
+ 3. **Splitting the safetensors AFTER conversion overwrote the source mmap mid-read**, zeroing all weights silently. Fix: do the split inside the converter (write backbone and custom heads to different paths in one pass; never re-read and overwrite the same path).
85
+ 4. **bf16 → numpy raises** ("PEP 3118 buffer format string"). numpy has no bf16 dtype. Cast to fp32 first.
86
+ 5. **`mx.array([float], dtype=mx.float32)` is invalid syntax** in mlx 0.31.2. Use `mx.array(np.asarray([float], dtype=np.float32))`.
87
+ 6. **`vinput_mask` included the tms position**, causing `gen_patches` to be one row too long. Fix: tag tms positions as `3` so `(token_types == 1)` excludes them.
88
+ 7. **512×512 was being snapped to 2048×2048** by the predefined-resolution table (smallest entry is 1440×2560). Fix: snapping is now opt-in via `--snap-resolution`. By default we just patch-align (multiple of 32) and use the requested size.
89
+
90
+ ## Numbers
91
+
92
+ | Resolution | Quant | Steps | Wall time | s/step | Patches | Peak RAM |
93
+ |---|---|---|---|---|---|---|
94
+ | 512×512 | Q4 | 28 | 24.9 s | 0.89 | 256 | ~6 GB |
95
+ | 1024×1024 | Q4 | 28 | 65–67 s | 2.36 | 1024 | ~6 GB |
96
+ | 1024×1024 | Q8 | 28 | 67–68 s | 2.41 | 1024 | ~11.5 GB |
97
+ | 1280×704 | Q8 | 28 | 70.7 s | 2.53 | 880 | ~7 GB |
98
+ | 704×1280 | Q8 | 28 | 65.9 s | 2.35 | 880 | ~3 GB (warm cache) |
99
+ | 2048×2048 | Q4 | 28 | 236 s | 8.44 | 4096 | ~7.2 GB |
100
+
101
+ Model load: 0.5–4.6 s. Custom-head load: <0.1 s. Disk: Q4 backbone 5.6 GB, Q8 backbone 9.96 GB, custom heads 75 MB.
102
+
103
+ Q8 is **not measurably slower than Q4** at the same resolution — bandwidth-bound, not compute-bound. Use Q8 unless RAM is tight.
104
+
105
+ ## Aesthetic notes
106
+
107
+ **The "dark mood" was Q4 quantisation, not the model.** Q8 of the same prompt + seed produces fully prompt-faithful images:
108
+
109
+ - Cat prompt: Q4 → tabby in dim room. Q8 → vibrant orange tabby in bright sunlit kitchen with plant on windowsill.
110
+ - Beach prompt: Q4 → moonlit silhouette beach. Q8 → bright tropical beach with turquoise water, white sand, blue sky, beach chair.
111
+
112
+ Bottom line: **Q4 distorts the brightness/colour distribution of HiDream-O1's outputs significantly. Q8 is fine.** If you need Q4, expect dark images.
113
+
114
+ A small remaining artifact in flat regions (sky, water): a **patch grid** at the 32×32 boundary. This is intrinsic to the architecture — `final_layer2` predicts each patch independently with no overlap. Not fixable without architectural changes (e.g. a lightweight overlap-blending pass, or finetuning with patch-edge loss).
115
+
116
+ `s_noise=7.5` is load-bearing across both Q4 and Q8 — lowering it collapses the image to a near-uniform colour. This is the FlashFlowMatch scheduler's tuned configuration for the Dev distillation; don't change it.
117
+
118
+ ## Showcase prompts (Q8)
119
+
120
+ - "alice in cyberpunk" — vertical 704×1280 ([sample_outputs/v3_alice_vertical_q8.png](sample_outputs/v3_alice_vertical_q8.png), 65.9 s)
121
+ - "alice in cyberpunk" — horizontal 1280×704 ([sample_outputs/v3_alice_horizontal_q8.png](sample_outputs/v3_alice_horizontal_q8.png), 70.7 s)
122
+ - "vibrant orange tabby in sunlit kitchen" 1024×1024 ([sample_outputs/v3_1024_cat_q8.png](sample_outputs/v3_1024_cat_q8.png))
123
+ - "bright sunlit beach" 1024×1024 ([sample_outputs/v3_1024_beach_q8.png](sample_outputs/v3_1024_beach_q8.png))
124
+
125
+ ## Open questions / next steps if you want to keep going
126
+
127
+ 1. **Compare with full-precision reference** on the same prompts to isolate Q4 vs Dev-distillation effects on brightness.
128
+ 2. **Try Q6 or Q8** of just the decoder layers (vision can stay Q4) to see if attention values get under-represented.
129
+ 3. **Implement edit + multi-reference paths** (the build_*_sample helpers + ref_patches concat from `pipeline.py`).
130
+ 4. **Higher resolution (2048×2048)** — should fit on 64 GB. ~4 min predicted (4× the seq length, but attention is O(S²), so closer to ~6 min).
131
+ 5. **Promote a path forward**: package as `hidream-o1-mlx` Python module that can be imported into a Phosphene engine. NOT yet — wait for an apples-to-apples vs `mflux` Qwen-Image-Edit comparison on ≥5 prompts.
132
+
133
+ ## Hard-stop conditions — where we landed
134
+
135
+ - ✅ mlx-vlm Qwen3-VL is reusable (Path B confirmed).
136
+ - ✅ Q4 output is recognisable and not slow.
137
+ - ✅ Memory stays well under 64 GB.
138
+ - ✅ No new MLX kernels needed.
139
+ - ✅ No CUDA, no PyTorch at runtime.
140
+ - 🟡 Quality leans dark — needs comparison to confirm acceptable.
141
+
142
+ ## Recommendation
143
+
144
+ **Continue.** The hard parts are done: backbone reuse works, custom heads load, the forward pass produces real predictions, the flow-matching loop converges, and we have a working converter + generator. The next session can iterate on quality (compare with reference, try Q6, run more prompts) without any new architectural work.
145
+
146
+ Don't promote to Phosphene yet. Wait for: (a) the brightness question to be resolved, (b) at least 5 side-by-side prompts vs `mflux` Qwen-Image-Edit, (c) edit and multi-ref paths to be wired up so it can replace Qwen-Image-Edit functionally.
docs/PHOSPHENE_INTEGRATION_PLAN.md ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HiDream-O1 → Phosphene integration plan
2
+
3
+ **Status:** plan only. No edits to Phosphene yet. Show this to Salo for approval first.
4
+
5
+ ## Where it slots in
6
+
7
+ Phosphene's `agent/image_engine.py` already abstracts image generation behind
8
+ `generate(prompt, n, output_dir, ..., config)` with a `kind` discriminator.
9
+ Three kinds exist today: `mock`, `mflux`, `bfl`. We add a fourth: `hidream`.
10
+
11
+ Pattern matches `mflux`: subprocess invocation of an external Python that owns
12
+ its own venv. Phosphene stays clean, dependencies stay isolated.
13
+
14
+ ## Files touched (3)
15
+
16
+ ### 1. `agent/image_engine.py` — add config fields, dispatch, generator
17
+
18
+ ```python
19
+ # Inside ImageEngineConfig (after mflux_quantize):
20
+ hidream_python: str = "" # path to lab venv python; empty = autodetect
21
+ hidream_model_path: str = "" # path to converted MLX model dir; empty = autodetect
22
+ hidream_steps: int = 28
23
+ hidream_noise_scale: float = 7.5 # Dev recipe default; do not change
24
+ hidream_noise_clip_std: float = 2.5
25
+ ```
26
+
27
+ ```python
28
+ # Inside generate():
29
+ if config.kind == "hidream":
30
+ return _generate_hidream(prompt, n, width, height, output_dir, base_seed, config, on_log=on_log)
31
+ ```
32
+
33
+ ```python
34
+ # Inside health_check():
35
+ if config.kind == "hidream":
36
+ py = _resolve_hidream_python(config)
37
+ model = _resolve_hidream_model(config)
38
+ if not py:
39
+ return False, "HiDream python not found. Install lab at /Users/salo/HIDREAM-O1-MLX-LAB-active/"
40
+ if not model:
41
+ return False, f"HiDream model dir not found at {config.hidream_model_path or 'autodetect'}"
42
+ return True, f"HiDream ready: {py} + {model}"
43
+ ```
44
+
45
+ ```python
46
+ # New module-level constants + helpers:
47
+ HIDREAM_LAB_DIR = Path("/Users/salo/HIDREAM-O1-MLX-LAB-active")
48
+ HIDREAM_DEFAULT_PY = HIDREAM_LAB_DIR / ".venv" / "bin" / "python"
49
+ HIDREAM_DEFAULT_MODEL = HIDREAM_LAB_DIR / "mlx_models" / "hidream-o1-dev-q8"
50
+ HIDREAM_GENERATE_SCRIPT = HIDREAM_LAB_DIR / "scripts" / "hidream_o1" / "generate_hidream_o1_mlx.py"
51
+
52
+ def _resolve_hidream_python(config) -> str | None:
53
+ p = Path(config.hidream_python) if config.hidream_python else HIDREAM_DEFAULT_PY
54
+ return str(p) if p.is_file() and os.access(p, os.X_OK) else None
55
+
56
+ def _resolve_hidream_model(config) -> str | None:
57
+ p = Path(config.hidream_model_path) if config.hidream_model_path else HIDREAM_DEFAULT_MODEL
58
+ return str(p) if (p / "model.safetensors").exists() else None
59
+
60
+ def _generate_hidream(prompt, n, width, height, output_dir, base_seed, config, on_log=None):
61
+ """Subprocess pattern matching _generate_mflux. One PNG per call to the
62
+ generator script, n calls total. Each candidate uses base_seed+i."""
63
+ py = _resolve_hidream_python(config) or sys.exit("HiDream python missing")
64
+ model = _resolve_hidream_model(config) or sys.exit("HiDream model missing")
65
+ script = str(HIDREAM_GENERATE_SCRIPT)
66
+
67
+ out: list[dict] = []
68
+ for i in range(n):
69
+ seed = (base_seed + i) if base_seed is not None else random.randint(0, 2**31 - 1)
70
+ png = output_dir / f"hidream_{int(time.time()*1000)}_{i:02d}.png"
71
+ cmd = [
72
+ py, script,
73
+ "--model-path", model,
74
+ "--prompt", prompt,
75
+ "--width", str(width),
76
+ "--height", str(height),
77
+ "--output", str(png),
78
+ "--seed", str(seed),
79
+ "--num-inference-steps", str(config.hidream_steps),
80
+ "--noise-scale-start", str(config.hidream_noise_scale),
81
+ "--noise-scale-end", str(config.hidream_noise_scale),
82
+ "--noise-clip-std", str(config.hidream_noise_clip_std),
83
+ ]
84
+ proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
85
+ for line in proc.stdout:
86
+ if on_log: on_log(line.rstrip())
87
+ rc = proc.wait()
88
+ if rc != 0 or not png.exists():
89
+ raise RuntimeError(f"hidream gen failed (rc={rc})")
90
+ out.append({
91
+ "png_path": str(png),
92
+ "seed": seed,
93
+ "engine": "hidream-o1-dev-q8",
94
+ "width": width,
95
+ "height": height,
96
+ })
97
+ return out
98
+ ```
99
+
100
+ ### 2. `mlx_ltx_panel.py` — settings UI option (one dropdown entry)
101
+
102
+ `update_settings()` and `_load_agent_image_config()` already accept `kind`
103
+ strings. Just add `"hidream"` to whatever validation lists exist (likely a
104
+ single line). The panel already shows config.kind in the agent settings card.
105
+
106
+ ### 3. `docs/IMAGE_GEN_RESEARCH_2026-05.md` — note the new option
107
+
108
+ Add a row to the engine comparison table:
109
+
110
+ | Engine | Local | Speed (1024) | RAM | Quality | License |
111
+ |---|---|---|---|---|---|
112
+ | FLUX.2 klein 4B / mflux | yes | ~50 s | ~16 GB | great | Apache 2.0 |
113
+ | Z-Image-Turbo / mflux | yes | ~30 s | ~6 GB | good | Apache 2.0 |
114
+ | **HiDream-O1-Image-Dev / Q8** | **yes** | **~67 s** | **~11 GB** | **great** | **MIT** |
115
+
116
+ ## What does NOT need to change
117
+
118
+ - `start.js` / `install.js` / `pinokio.js` — HiDream's lab is **outside**
119
+ Pinokio; Phosphene just shells out to the lab's python. No new install step.
120
+ - `mlx_warm_helper.py` — that's LTX-only. HiDream is sub-minute, no warm
121
+ helper needed for now (could add one later if we go to a long session of
122
+ many shots).
123
+ - Phosphene's venv (`ltx-2-mlx/env`) — untouched. mlx-vlm is in the lab's
124
+ separate `.venv`.
125
+
126
+ ## Risks & mitigations
127
+
128
+ | Risk | Mitigation |
129
+ |---|---|
130
+ | Lab path is hard-coded — moves break it | Configurable via `hidream_python` / `hidream_model_path`. Defaults are absolute; users can override in `state/agent_image_config.json`. |
131
+ | HiDream + LTX run at the same time (both want GPU) | Already a problem with mflux + LTX; Phosphene queue serialises shot generation. No new mitigation needed. |
132
+ | Lab dir gets nuked again | `README.md` marker is in place; user is aware. If it goes, Phosphene's `health_check` returns clearly and panel surfaces it. |
133
+ | Quality-tier defaults: most users won't have a 64 GB Mac | Mark HiDream as **Comfortable+ (32 GB+)** tier in the docs. Don't make it the default — keep mflux Z-Image-Turbo as default for compact tier, FLUX.2 klein as default for comfortable. |
134
+
135
+ ## Cost / size
136
+
137
+ - Disk: ~10 GB additional in lab (already there)
138
+ - RAM at 1024×1024: ~11.5 GB (Q8). Same RAM tier as FLUX.2 klein.
139
+ - One-time setup: lab venv install (~1.5 GB, already done).
140
+
141
+ ## Roll-out
142
+
143
+ 1. Patch `image_engine.py` (above).
144
+ 2. Add `"hidream"` to settings validation in `mlx_ltx_panel.py`.
145
+ 3. Switch agent_image_config.json kind to `"hidream"` in a single test session.
146
+ 4. Generate one shot through the agent UI; confirm PNG lands.
147
+ 5. Compare to the same prompt through `mflux qwen-image-edit`.
148
+ 6. If quality wins on at least 3 prompts → make it a real option in docs.
149
+ 7. Don't switch the default until we have ≥5 prompts where HiDream is clearly better than mflux Z-Image-Turbo, AND the dark-aesthetic concern is fully ruled out.
150
+
151
+ ## What I'd want before merging this
152
+
153
+ 1. ✅ Q8 conversion of HiDream-O1-Image-Dev (DONE)
154
+ 2. ✅ Stable single-shot text-to-image (DONE — sample images in `sample_outputs/`)
155
+ 3. 🟡 Showcase pass to characterise quality across genres (RUNNING)
156
+ 4. ❌ Side-by-side vs Phosphene's existing mflux engines on ≥5 matched prompts (NOT YET — needs the showcase to finish + a parallel run on mflux)
157
+ 5. ❌ One real agent-flow render that uses HiDream as the anchor engine and
158
+ feeds the result into LTX 2.3 (NOT YET — easy once health_check passes)
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Apple Silicon only. Tested on macOS 14+ with Python 3.11 in a uv venv.
2
+ # These are the exact pins the lab was developed against — newer minor
3
+ # versions of mlx-vlm and transformers have been observed to break the
4
+ # Qwen3-VL backbone import path; if you upgrade, re-test before shipping.
5
+
6
+ mlx>=0.31.2
7
+ mlx-vlm>=0.5.0
8
+ transformers>=4.57.0,<6.0
9
+ huggingface_hub>=0.30
10
+ safetensors>=0.6
11
+ numpy>=2.0
12
+ pillow>=10.0
13
+ tqdm>=4.66
14
+ sentencepiece>=0.2.0
15
+ hf_transfer>=0.1.9 # optional, speeds up the HF download
sample_outputs/hero/01_tea_master.png ADDED

Git LFS Details

  • SHA256: 38ac862dd666dfca0ef0f632661a4f93298e6113eb851c3a048ddd3c1b36422f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.47 MB
sample_outputs/hero/02_tropical_beach.png ADDED

Git LFS Details

  • SHA256: 20743ab3e89312e3e348a42b43eb3c405986ced8792968f8c772fde2363a7a4d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
sample_outputs/hero/03_astronaut.png ADDED

Git LFS Details

  • SHA256: 7e876f6c286fc12f521308cc24acfe39647aff29e820b2c81b595825ed7f6de1
  • Pointer size: 132 Bytes
  • Size of remote file: 5.14 MB
sample_outputs/hero/04_construction_worker.png ADDED

Git LFS Details

  • SHA256: 03649c1acb92b57175c9055245f3e5da04c6e2eb5ccd3aa01a8ff1d7edc6287c
  • Pointer size: 132 Bytes
  • Size of remote file: 4.38 MB
sample_outputs/hero/05_mountain_peak.png ADDED

Git LFS Details

  • SHA256: fdb3a292322f91601f3204d6f507d8a496c146533dc6d5c33c881355d909b93a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.05 MB
sample_outputs/hero/06_alice_cyberpunk.png ADDED

Git LFS Details

  • SHA256: 7e842b91de564975224ef9097b8ad91b26bd0253d58beac20ea31644a0463f70
  • Pointer size: 132 Bytes
  • Size of remote file: 5.68 MB
sample_outputs/hero/07_kitchen_morning.png ADDED

Git LFS Details

  • SHA256: 178da6df36debb128ce756bd05e165e5ae8af10789e321cd3c205e030f3c541d
  • Pointer size: 132 Bytes
  • Size of remote file: 3.65 MB
sample_outputs/hero/08_fitness_BF16.png ADDED

Git LFS Details

  • SHA256: 515265ae082e4596fdd6f2373be9f29aee0c3a2f408b614963396c8176d44289
  • Pointer size: 132 Bytes
  • Size of remote file: 3.61 MB
scripts/hidream_o1/__init__.py ADDED
File without changes
scripts/hidream_o1/_compile_bench.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Quick A/B: forward_generation with vs without mx.compile.
2
+
3
+ Times 5 forward passes after warm-up. Same shapes as a 1024x1024 inference.
4
+ """
5
+ from __future__ import annotations
6
+ import sys, time
7
+ from pathlib import Path
8
+
9
+ HERE = Path(__file__).parent
10
+ sys.path.insert(0, str(HERE))
11
+
12
+ import numpy as np
13
+ import mlx.core as mx
14
+ from mlx_vlm import load as mlx_vlm_load
15
+ from pipeline_helpers import build_t2i_text_sample, build_attention_mask, PATCH_SIZE
16
+ from hidream_model import HiDreamConfig, build_model, forward_generation
17
+
18
+ LAB = Path(__file__).resolve().parents[2]
19
+ MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q8"
20
+
21
+ print("loading model...")
22
+ t0 = time.time()
23
+ backbone, processor = mlx_vlm_load(str(MODEL_PATH))
24
+ print(f" {time.time()-t0:.1f}s")
25
+
26
+ cfg = HiDreamConfig()
27
+ model = build_model(cfg, backbone)
28
+ custom = mx.load(str(MODEL_PATH / "extras" / "custom_heads.safetensors"))
29
+ model.load_weights(list(custom.items()), strict=False)
30
+ mx.eval(model.parameters())
31
+ print("model ready")
32
+
33
+ # Build inputs at 1024x1024
34
+ WIDTH, HEIGHT = 1024, 1024
35
+ N_PATCH = (WIDTH // PATCH_SIZE) * (HEIGHT // PATCH_SIZE) # 1024
36
+
37
+ tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
38
+ for n in ("boi", "bor", "eor", "bot", "tms"):
39
+ if not hasattr(tokenizer, f"{n}_token"):
40
+ setattr(tokenizer, f"{n}_token", f"<|{n}_token|>")
41
+
42
+ sample = build_t2i_text_sample(
43
+ "a small red mushroom on a bed of moss",
44
+ HEIGHT, WIDTH, tokenizer, processor, backbone.config,
45
+ )
46
+ input_ids = mx.array(sample["input_ids"])
47
+ position_ids = mx.array(sample["position_ids"])
48
+ token_types = mx.array(sample["token_types"])
49
+ mask4d = mx.array(build_attention_mask(sample["token_types"], -1e4)).astype(mx.bfloat16)
50
+ vinputs = mx.random.normal((1, N_PATCH, 3 * PATCH_SIZE * PATCH_SIZE)).astype(mx.bfloat16)
51
+ timestep = mx.array([0.5], dtype=mx.float32)
52
+
53
+ print(f"shapes: input_ids={input_ids.shape} pos={position_ids.shape} "
54
+ f"vinputs={vinputs.shape} mask={mask4d.shape}")
55
+
56
+ # --- Uncompiled baseline ---
57
+ print("\n=== baseline (uncompiled) ===")
58
+ # warmup
59
+ for _ in range(2):
60
+ out = forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d)
61
+ mx.eval(out)
62
+
63
+ # time
64
+ N = 5
65
+ t0 = time.time()
66
+ for _ in range(N):
67
+ out = forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d)
68
+ mx.eval(out)
69
+ elapsed = time.time() - t0
70
+ print(f" baseline: {elapsed/N:.3f}s/step over {N} steps")
71
+
72
+ # --- Compiled ---
73
+ print("\n=== mx.compile ===")
74
+ def fwd(input_ids, position_ids, vinputs, timestep, token_types, mask4d):
75
+ return forward_generation(model, cfg, input_ids, position_ids, vinputs, timestep, token_types, mask4d)
76
+
77
+ try:
78
+ fwd_c = mx.compile(fwd)
79
+ # warmup (first call compiles)
80
+ for _ in range(2):
81
+ out = fwd_c(input_ids, position_ids, vinputs, timestep, token_types, mask4d)
82
+ mx.eval(out)
83
+ t0 = time.time()
84
+ for _ in range(N):
85
+ out = fwd_c(input_ids, position_ids, vinputs, timestep, token_types, mask4d)
86
+ mx.eval(out)
87
+ elapsed_c = time.time() - t0
88
+ print(f" compiled: {elapsed_c/N:.3f}s/step over {N} steps (speedup {elapsed/elapsed_c:.2f}x)")
89
+ except Exception as e:
90
+ print(f" mx.compile failed: {type(e).__name__}: {e}")
scripts/hidream_o1/_edit_diag.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Diagnose what build_edit_text_sample produces, no model load."""
2
+ from __future__ import annotations
3
+ import sys
4
+ from pathlib import Path
5
+ HERE = Path(__file__).parent
6
+ sys.path.insert(0, str(HERE))
7
+
8
+ import numpy as np
9
+ from mlx_vlm import load as mlx_vlm_load
10
+
11
+ LAB = Path(__file__).resolve().parents[2]
12
+ MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q6"
13
+ REF = "/tmp/hidream_edit_smoke/ref.png"
14
+
15
+ # Use mlx-vlm to get a working processor that skips the video-processor dep issue
16
+ backbone, processor = mlx_vlm_load(str(MODEL_PATH))
17
+ tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
18
+ for n in ("boi", "bor", "eor", "bot", "tms"):
19
+ if not hasattr(tokenizer, f"{n}_token"):
20
+ setattr(tokenizer, f"{n}_token", f"<|{n}_token|>")
21
+
22
+ MC = backbone.config
23
+
24
+ from pipeline_helpers import build_edit_text_sample, PATCH_SIZE
25
+
26
+ prompt = "in the style of the reference image, a vibrant abstract composition, vivid colors, modern art"
27
+ H = W = 512
28
+ sample = build_edit_text_sample(prompt, [REF], H, W, tokenizer, processor, MC)
29
+
30
+ print("=== build_edit_text_sample shapes ===")
31
+ for k, v in sample.items():
32
+ if hasattr(v, "shape"):
33
+ print(f" {k}: shape={v.shape} dtype={v.dtype}")
34
+ else:
35
+ print(f" {k}: {v}")
36
+
37
+ iid = sample["input_ids"][0]
38
+ img_token_id = MC.image_token_id
39
+ vs_token_id = MC.vision_start_token_id
40
+ img_count = int((iid == img_token_id).sum())
41
+ vs_count = int((iid == vs_token_id).sum())
42
+ tms_count = int((iid == 151673).sum()) # tms_token_id
43
+ print(f"\n=== input_ids breakdown (text-side, length {iid.shape[0]}) ===")
44
+ print(f" image_token_id ({img_token_id}): {img_count} positions <-- vision tower fills these")
45
+ print(f" vision_start_token_id ({vs_token_id}): {vs_count}")
46
+ print(f" tms_token_id (151673): {tms_count}")
47
+ print(f" first 30 ids: {list(iid[:30])}")
48
+ print(f" last 5 ids: {list(iid[-5:])}")
49
+
50
+ pix = sample["pixel_values"]
51
+ g = sample["image_grid_thw"]
52
+ print(f"\n=== vision tower input ===")
53
+ print(f" pixel_values shape: {pix.shape}")
54
+ print(f" image_grid_thw: {g}")
55
+ # Per-image vision patch count = T*H*W, post-merge = T*H/m*W/m
56
+ m = backbone.config.vision_config.spatial_merge_size
57
+ for i, row in enumerate(g):
58
+ t, h, w = row
59
+ pre_merge = int(t * h * w)
60
+ post_merge = int(t * (h//m) * (w//m))
61
+ print(f" ref {i}: pre-merge patches={pre_merge}, post-merge={post_merge}")
62
+ print(f" TOTAL post-merge features (what vision tower outputs): {sum(int(r[0])*(int(r[1])//m)*(int(r[2])//m) for r in g)}")
63
+ print(f" TOTAL image_token_id positions in input_ids: {img_count}")
64
+ print(f" ** these must match for scatter to work **")
65
+
66
+ vinput_mask = sample["vinput_mask"][0]
67
+ vinput_mask_tgt = sample["vinput_mask_tgt_only"][0]
68
+ print(f"\n=== mask checks ===")
69
+ print(f" total vinput positions (tgt+refs): {int(vinput_mask.sum())} = {sample['tgt_image_len']} + {int(vinput_mask.sum()) - sample['tgt_image_len']}")
70
+ print(f" total tgt-only positions: {int(vinput_mask_tgt.sum())} (expect {sample['tgt_image_len']})")
71
+
72
+ # Position IDs
73
+ pids = sample["position_ids"]
74
+ print(f"\n=== position_ids ===")
75
+ print(f" shape: {pids.shape} (3D mrope: rope_dim, batch, seq)")
76
+ print(f" ranges per dim: {[(int(pids[d].min()), int(pids[d].max())) for d in range(pids.shape[0])]}")
77
+ # Where are the discontinuities? Look at the boundary between text-side and vision-token-side
78
+ txt_seq_len = iid.shape[0]
79
+ print(f" text/vision boundary at position {txt_seq_len}")
80
+ print(f" pids[:, 0, txt_seq_len-3:txt_seq_len+3] (around the boundary):")
81
+ print(pids[:, 0, max(0, txt_seq_len-3):txt_seq_len+3])
scripts/hidream_o1/_precompute_diag.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Verify precompute_text_embeds_with_vision actually scatters image features
2
+ into the right positions of inputs_embeds, without mangling text positions.
3
+ """
4
+ from __future__ import annotations
5
+ import sys
6
+ from pathlib import Path
7
+ HERE = Path(__file__).parent
8
+ sys.path.insert(0, str(HERE))
9
+
10
+ import numpy as np
11
+ import mlx.core as mx
12
+ from mlx_vlm import load as mlx_vlm_load
13
+
14
+ from pipeline_helpers import build_edit_text_sample
15
+ from hidream_model import HiDreamConfig, build_model, precompute_text_embeds_with_vision
16
+
17
+ LAB = Path(__file__).resolve().parents[2]
18
+ MODEL_PATH = LAB / "mlx_models" / "hidream-o1-dev-q6"
19
+ REF = "sample_outputs/v3_1024_cat_q8.png"
20
+
21
+ print("loading model...")
22
+ backbone, processor = mlx_vlm_load(str(MODEL_PATH))
23
+ cfg = HiDreamConfig()
24
+ model = build_model(cfg, backbone)
25
+ custom = mx.load(str(MODEL_PATH / "extras" / "custom_heads.safetensors"))
26
+ model.load_weights(list(custom.items()), strict=False)
27
+ mx.eval(model.parameters())
28
+
29
+ tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
30
+ for n in ("boi", "bor", "eor", "bot", "tms"):
31
+ if not hasattr(tokenizer, f"{n}_token"):
32
+ setattr(tokenizer, f"{n}_token", f"<|{n}_token|>")
33
+
34
+ sample = build_edit_text_sample(
35
+ "a cat", [str(LAB / REF)], 1024, 1024, tokenizer, processor, backbone.config,
36
+ )
37
+
38
+ input_ids = mx.array(sample["input_ids"])
39
+ pixel_values = mx.array(sample["pixel_values"]).astype(mx.bfloat16)
40
+ image_grid_thw = mx.array(sample["image_grid_thw"])
41
+
42
+ # 1) Just the embed_tokens output (no scatter)
43
+ embed_tokens = model.language_model.model.embed_tokens
44
+ text_only_embeds = embed_tokens(input_ids)
45
+ mx.eval(text_only_embeds)
46
+ print(f"\ntext-only embeds shape: {text_only_embeds.shape} dtype: {text_only_embeds.dtype}")
47
+
48
+ # 2) Vision tower output
49
+ vt_out = model.visual(pixel_values, image_grid_thw)
50
+ img_features = vt_out[0] if isinstance(vt_out, tuple) else vt_out
51
+ mx.eval(img_features)
52
+ print(f"image_features shape: {img_features.shape} dtype: {img_features.dtype}")
53
+ print(f" stats: mean={float(mx.mean(img_features.astype(mx.float32))):.4f} std={float(mx.std(img_features.astype(mx.float32))):.4f} min={float(mx.min(img_features.astype(mx.float32))):.3f} max={float(mx.max(img_features.astype(mx.float32))):.3f}")
54
+
55
+ # 3) Run our precompute
56
+ combined = precompute_text_embeds_with_vision(model, cfg, input_ids, pixel_values, image_grid_thw)
57
+ mx.eval(combined)
58
+ print(f"\ncombined embeds shape: {combined.shape} dtype: {combined.dtype}")
59
+
60
+ # 4) Inspect: at image positions, combined should equal image_features
61
+ ids_np = np.asarray(input_ids[0])
62
+ img_pos = np.where(ids_np == cfg.image_token_id)[0]
63
+ text_pos = np.where(ids_np != cfg.image_token_id)[0]
64
+ print(f"\nimage_token positions: {len(img_pos)} (first 5: {img_pos[:5].tolist()}, last 5: {img_pos[-5:].tolist()})")
65
+ print(f"text positions: {len(text_pos)} (first 5: {text_pos[:5].tolist()})")
66
+
67
+ # At image positions: combined should be image_features (in same order)
68
+ # combined[0, img_pos[i], :] should equal img_features[i, :]
69
+ combined_np = np.asarray(combined[0].astype(mx.float32))
70
+ img_feat_np = np.asarray(img_features.astype(mx.float32))
71
+ print("\n--- check: combined[0, img_pos[0]] vs img_features[0] ---")
72
+ print(f" combined[0, {img_pos[0]}, :8] = {combined_np[img_pos[0], :8]}")
73
+ print(f" image_features[0, :8] = {img_feat_np[0, :8]}")
74
+ print(f" diff: {np.abs(combined_np[img_pos[0]] - img_feat_np[0]).max():.4f}")
75
+
76
+ print("\n--- check: combined[0, img_pos[5]] vs img_features[5] ---")
77
+ print(f" combined[0, {img_pos[5]}, :8] = {combined_np[img_pos[5], :8]}")
78
+ print(f" image_features[5, :8] = {img_feat_np[5, :8]}")
79
+ print(f" diff: {np.abs(combined_np[img_pos[5]] - img_feat_np[5]).max():.4f}")
80
+
81
+ # At text positions: combined should equal embed_tokens output
82
+ text_only_np = np.asarray(text_only_embeds[0].astype(mx.float32))
83
+ diff_at_text = np.abs(combined_np[text_pos] - text_only_np[text_pos]).max()
84
+ print(f"\n--- check: combined matches text embeddings at text positions ---")
85
+ print(f" max abs diff at text positions: {diff_at_text:.6f} (should be 0)")
86
+
87
+ # Also: at image positions, embed_tokens gives the image_token's WEIRD embedding (since the token is just a placeholder)
88
+ print(f"\n embed_tokens at img_pos[0] (the placeholder embedding): {text_only_np[img_pos[0], :8]}")
scripts/hidream_o1/anti_plastic_batch.sh ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Anti-plastic batch v2 — incorporates HiDream-specific prompt tips:
3
+ # - "masterpiece, best quality" prefix (Civitai community finding)
4
+ # - Subject + Actions → Setting → Style → Details ordering
5
+ # - Specific cameras (Leica 50mm, Pentax K1000, Hasselblad)
6
+ # - Specific film stocks (Tri-X 400, Portra 400, Cinestill 800T)
7
+ # - Documentary photographer references
8
+ # - BF16 weights (no quantization)
9
+ set -euo pipefail
10
+
11
+ LAB="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
12
+ PY="$LAB/.venv/bin/python"
13
+ MODEL="$LAB/mlx_models/hidream-o1-dev-bf16"
14
+ OUT="$LAB/sample_outputs/showcase_antiplastic"
15
+ mkdir -p "$OUT"
16
+
17
+ run() {
18
+ local name="$1" w="$2" h="$3" prompt="$4" seed="${5:-42}"
19
+ echo "=== $name ${w}x${h} (seed=$seed) ==="
20
+ cd "$LAB" && "$PY" scripts/hidream_o1/generate_hidream_o1_mlx.py \
21
+ --model-path "$MODEL" \
22
+ --prompt "$prompt" \
23
+ --width "$w" --height "$h" \
24
+ --output "$OUT/$name.png" \
25
+ --seed "$seed" 2>&1 | grep -E "loaded|using|generation:|saved" | tail -3
26
+ echo ""
27
+ }
28
+
29
+ run "01_construction_rain" 2048 2048 \
30
+ "masterpiece, best quality, 35mm DSLR photograph. A construction worker leans against a steel I-beam, taking a long drag from a cigarette between gloved fingers. On a half-built skyscraper rooftop in heavy rain, water streaming off his hard hat and reflective vest. Documentary photojournalism, Sebastião Salgado aesthetic, shot on Leica M6 with Kodak Tri-X 400 black and white film, harsh overcast daylight, deep grain, raw skin texture with rain droplets and stubble visible, no retouching, 50mm Summicron lens" \
31
+ 701
32
+
33
+ run "02_pub_musician" 2048 2048 \
34
+ "masterpiece, best quality, 35mm DSLR photograph. A bearded musician in his late thirties sings into a vintage Shure SM58 microphone, eyes closed mid-note, fingers callused on a worn acoustic guitar. In a dim London pub on a Tuesday night, three half-empty pint glasses on a small wooden stage edge, a single warm tungsten spotlight from above creating sharp shadows. Cinematic documentary, shot on Pentax K1000 with Cinestill 800T film, visible grain and halation around the spotlight, real sweat on his forehead, natural skin pores and laugh lines, Anders Petersen mood" \
35
+ 702
36
+
37
+ run "03_mechanic_garage" 3104 1312 \
38
+ "masterpiece, best quality, ultrawide editorial photograph. A female mechanic in her mid forties wipes engine grease from her hands with a faded red rag, standing beside the open hood of a 1967 Pontiac GTO. In her cluttered garage on a quiet Sunday afternoon, tool chests and stacks of car magazines along the wall, sun streaming through high windows catching dust motes in the air. Annie Leibovitz Vanity Fair aesthetic, shot on Hasselblad H6D medium format with natural skin texture retention, soft fill light, visible pores and faint freckles, weathered hands with chipped nail polish, no glamour retouching, real and lived-in" \
39
+ 703
40
+
41
+ echo "=== batch complete ==="
42
+ ls -la "$OUT"
scripts/hidream_o1/cinematic_batch.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Cinematic batch — people doing things, specific dress, photoreal/cinematic style.
3
+ # Usage: cinematic_batch.sh <quant> <width> <height>
4
+ set -euo pipefail
5
+
6
+ LAB="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
7
+ PY="$LAB/.venv/bin/python"
8
+ QUANT="${1:-q6}"
9
+ WIDTH="${2:-1920}"
10
+ HEIGHT="${3:-1088}"
11
+ MODEL="$LAB/mlx_models/hidream-o1-dev-${QUANT}"
12
+ OUT="$LAB/sample_outputs/cinematic_${QUANT}_${WIDTH}x${HEIGHT}"
13
+ mkdir -p "$OUT"
14
+ echo "cinematic batch: quant=${QUANT}, ${WIDTH}x${HEIGHT}, model=${MODEL}, out=${OUT}"
15
+
16
+ run() {
17
+ local name="$1" prompt="$2" seed="${3:-42}"
18
+ echo "=== $name (seed=$seed) ==="
19
+ cd "$LAB" && /usr/bin/time -l "$PY" scripts/hidream_o1/generate_hidream_o1_mlx.py \
20
+ --model-path "$MODEL" \
21
+ --prompt "$prompt" \
22
+ --width "$WIDTH" --height "$HEIGHT" \
23
+ --output "$OUT/$name.png" \
24
+ --seed "$seed" 2>&1 | grep -E "loaded|using|generation:|saved|maximum resident" | tail -5
25
+ echo ""
26
+ }
27
+
28
+ run "01_jazz_pianist" \
29
+ "cinematic medium shot of a jazz pianist in his fifties at a baby grand piano, wearing a navy three-piece suit with a thin gold pocket watch chain, fingers blurred mid-arpeggio, dim smoky club lighting from above, deep shadows, anamorphic 35mm film grain, shallow depth of field" \
30
+ 3
31
+
32
+ run "02_street_photographer" \
33
+ "cinematic shot of a young street photographer crouched on a Shinjuku crosswalk at night, holding a Leica M11 camera up to her eye, wearing an olive-green oversized trench coat over black jeans and black leather boots, neon signs in Japanese reflecting off wet asphalt, wide-angle lens, blade runner color palette, photorealistic" \
34
+ 17
35
+
36
+ run "03_michelin_chef" \
37
+ "close-up cinematic shot of a Michelin-star chef in a crisp white double-breasted chef coat with rolled sleeves, tweezers placing a single edible flower onto a black slate plate, steam rising, kitchen brigade in soft focus behind, warm copper-pan lighting, food cinematography, hyperreal" \
38
+ 29
39
+
40
+ run "04_ballet_dancer" \
41
+ "cinematic full-body shot of a ballerina mid-grand-jeté across an empty rehearsal studio, wearing a slate-grey leotard and pink satin pointe shoes, hair in a tight bun, golden afternoon sunlight streaming through tall windows, dust particles visible in the light beams, motion blur on her trailing foot" \
42
+ 41
43
+
44
+ run "05_astronaut" \
45
+ "cinematic wide shot of an astronaut in a battered orange ACES launch-and-entry suit walking down a long curved corridor inside a space station, helmet tucked under one arm, clipboard in the other, fluorescent overhead strip lighting, scratched white wall panels, anamorphic lens flare, sci-fi realism" \
46
+ 53
47
+
48
+ echo "=== batch complete ==="
49
+ ls -la "$OUT"
scripts/hidream_o1/convert_hidream_o1_to_mlx.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Convert HiDream-O1-Image-Dev safetensors -> MLX format."""
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import json
7
+ import shutil
8
+ import sys
9
+ import time
10
+ from pathlib import Path
11
+ from typing import Iterable
12
+
13
+ import mlx.core as mx
14
+
15
+
16
+ HF_TO_MLX_CUSTOM = {
17
+ "model.t_embedder1.mlp.0.weight": "t_embedder1.fc1.weight",
18
+ "model.t_embedder1.mlp.0.bias": "t_embedder1.fc1.bias",
19
+ "model.t_embedder1.mlp.2.weight": "t_embedder1.fc2.weight",
20
+ "model.t_embedder1.mlp.2.bias": "t_embedder1.fc2.bias",
21
+ "model.x_embedder.proj1.weight": "x_embedder.proj1.weight",
22
+ "model.x_embedder.proj2.weight": "x_embedder.proj2.weight",
23
+ "model.x_embedder.proj2.bias": "x_embedder.proj2.bias",
24
+ "model.final_layer2.linear.weight": "final_layer2.linear.weight",
25
+ "model.final_layer2.linear.bias": "final_layer2.linear.bias",
26
+ }
27
+
28
+ CUSTOM_HEAD_PREFIXES = ("t_embedder1.", "x_embedder.", "final_layer2.")
29
+
30
+
31
+ def remap_hf_to_mlx(hf_key: str) -> str | None:
32
+ if hf_key == "lm_head.weight":
33
+ # mlx-vlm Qwen3-VL Model has language_model.lm_head when tie_word_embeddings=False.
34
+ # We don't call it (HiDream uses final_layer2 for image patches), but keeping the
35
+ # weight avoids strict-load failures when mlx-vlm imports the checkpoint.
36
+ return "language_model.lm_head.weight"
37
+ if hf_key in HF_TO_MLX_CUSTOM:
38
+ return HF_TO_MLX_CUSTOM[hf_key]
39
+ if hf_key.startswith("model.language_model."):
40
+ return "language_model.model." + hf_key[len("model.language_model."):]
41
+ if hf_key.startswith("model.visual."):
42
+ return "vision_tower." + hf_key[len("model.visual."):]
43
+ return hf_key
44
+
45
+
46
+ def stream_safetensors(shard_paths: Iterable[Path]) -> dict[str, mx.array]:
47
+ out: dict[str, mx.array] = {}
48
+ for p in shard_paths:
49
+ print(f" loading {p.name} ({p.stat().st_size / 1e9:.2f} GB) ...", flush=True)
50
+ shard = mx.load(str(p))
51
+ for k, v in shard.items():
52
+ mlx_key = remap_hf_to_mlx(k)
53
+ if mlx_key is None:
54
+ continue
55
+ out[mlx_key] = v
56
+ return out
57
+
58
+
59
+ def quantise(weights: dict[str, mx.array], bits: int, group_size: int = 64) -> dict[str, mx.array]:
60
+ if bits == 16:
61
+ # No quantization — return weights unchanged. Caller has already cast f32 -> bf16.
62
+ return weights
63
+ if bits not in (4, 6, 8):
64
+ raise ValueError(f"--bits must be 4, 6, 8, or 16 (got {bits})")
65
+
66
+ out: dict[str, mx.array] = {}
67
+ quantised = 0
68
+ skipped = 0
69
+
70
+ for k, v in weights.items():
71
+ if any(k.startswith(p) for p in CUSTOM_HEAD_PREFIXES):
72
+ out[k] = v
73
+ skipped += 1
74
+ continue
75
+ if v.ndim != 2 or "embed_tokens" in k or "norm" in k or k.endswith(".bias"):
76
+ out[k] = v
77
+ skipped += 1
78
+ continue
79
+ try:
80
+ qw, scales, biases = mx.quantize(v, group_size=group_size, bits=bits)
81
+ base = k[: -len(".weight")] if k.endswith(".weight") else k
82
+ out[k] = qw
83
+ out[base + ".scales"] = scales
84
+ out[base + ".biases"] = biases
85
+ quantised += 1
86
+ except Exception as e:
87
+ print(f" [warn] could not quantise {k}: {e!s}; keeping fp", flush=True)
88
+ out[k] = v
89
+ skipped += 1
90
+
91
+ print(f"quantised {quantised} tensors, skipped {skipped} (kept as-is)")
92
+ return out
93
+
94
+
95
+ def copy_metadata(src_dir: Path, dst_dir: Path):
96
+ keep = {
97
+ "config.json", "configuration.json", "generation_config.json",
98
+ "preprocessor_config.json", "video_preprocessor_config.json",
99
+ "chat_template.json", "tokenizer.json", "tokenizer_config.json",
100
+ "vocab.json", "merges.txt", "README.md",
101
+ }
102
+ for name in keep:
103
+ src = src_dir / name
104
+ if src.exists():
105
+ shutil.copy2(src, dst_dir / name)
106
+
107
+
108
+ def resolve_source(arg: str, cache_dir: Path | None) -> Path:
109
+ p = Path(arg)
110
+ if p.exists() and p.is_dir():
111
+ return p
112
+ try:
113
+ from huggingface_hub import snapshot_download
114
+ except ImportError as e:
115
+ sys.exit(f"huggingface_hub is required to download {arg!r}: {e}")
116
+ print(f"downloading {arg!r} ...", flush=True)
117
+ return Path(snapshot_download(arg, cache_dir=str(cache_dir) if cache_dir else None))
118
+
119
+
120
+ def main(argv=None):
121
+ ap = argparse.ArgumentParser()
122
+ ap.add_argument("--hf-source", default="HiDream-ai/HiDream-O1-Image-Dev")
123
+ ap.add_argument("--cache-dir", default=None)
124
+ ap.add_argument("--out-dir", default="./mlx_models/hidream-o1-dev-q4")
125
+ ap.add_argument("--bits", type=int, default=4, choices=[4, 6, 8, 16],
126
+ help="16 = no quantization, store as bf16 (matches upstream's master-weight precision)")
127
+ ap.add_argument("--group-size", type=int, default=64)
128
+ ap.add_argument("--check-disk", action="store_true")
129
+ ap.add_argument("--dry-run", action="store_true")
130
+ args = ap.parse_args(argv)
131
+
132
+ out_dir = Path(args.out_dir).resolve()
133
+ out_dir.mkdir(parents=True, exist_ok=True)
134
+
135
+ if args.check_disk:
136
+ free_gb = shutil.disk_usage(out_dir).free / 1e9
137
+ if free_gb < 40:
138
+ sys.exit(f"free disk on {out_dir.parent}: {free_gb:.1f} GB; need >= 40 GB")
139
+
140
+ src_dir = resolve_source(args.hf_source, Path(args.cache_dir) if args.cache_dir else None)
141
+ print(f"source: {src_dir}")
142
+
143
+ idx_path = src_dir / "model.safetensors.index.json"
144
+ if not idx_path.exists():
145
+ sys.exit(f"no model.safetensors.index.json under {src_dir}")
146
+ idx = json.loads(idx_path.read_text())
147
+ shard_names = sorted({v for v in idx["weight_map"].values()})
148
+ shard_paths = [src_dir / name for name in shard_names]
149
+ total_gb = sum(p.stat().st_size for p in shard_paths) / 1e9
150
+ print(f"shards: {len(shard_paths)}, total {total_gb:.2f} GB")
151
+
152
+ if args.dry_run:
153
+ for p in shard_paths:
154
+ print(f" {p.stat().st_size / 1e9:6.2f} GB {p.name}")
155
+ return
156
+
157
+ t0 = time.time()
158
+ weights = stream_safetensors(shard_paths)
159
+ print(f"loaded {len(weights)} tensors in {time.time() - t0:.1f}s")
160
+
161
+ for k in list(weights.keys()):
162
+ if weights[k].dtype == mx.float32:
163
+ weights[k] = weights[k].astype(mx.bfloat16)
164
+
165
+ weights = quantise(weights, bits=args.bits, group_size=args.group_size)
166
+
167
+ # Split: mlx-vlm-loadable backbone goes to model.safetensors;
168
+ # our diffusion-side heads go to custom_heads.safetensors so mlx-vlm's
169
+ # strict load doesn't reject them.
170
+ backbone = {k: v for k, v in weights.items()
171
+ if not any(k.startswith(p) for p in CUSTOM_HEAD_PREFIXES)}
172
+ custom = {k: v for k, v in weights.items()
173
+ if any(k.startswith(p) for p in CUSTOM_HEAD_PREFIXES)}
174
+
175
+ out_path = out_dir / "model.safetensors"
176
+ print(f"saving {len(backbone)} backbone tensors to {out_path} ...")
177
+ mx.save_safetensors(str(out_path), backbone)
178
+
179
+ extras_dir = out_dir / "extras"
180
+ extras_dir.mkdir(exist_ok=True)
181
+ custom_path = extras_dir / "custom_heads.safetensors"
182
+ print(f"saving {len(custom)} custom-head tensors to {custom_path} ...")
183
+ mx.save_safetensors(str(custom_path), custom)
184
+
185
+ copy_metadata(src_dir, out_dir)
186
+
187
+ # Update config.json. For bits<16, write the quantization field so mlx-vlm
188
+ # wraps Linear -> QuantizedLinear before loading. For bits=16 (BF16, no quant),
189
+ # remove any pre-existing quantization field so mlx-vlm loads as plain Linear.
190
+ cfg_path = out_dir / "config.json"
191
+ cfg = json.loads(cfg_path.read_text())
192
+ if args.bits == 16:
193
+ cfg.pop("quantization", None)
194
+ print("config.json: bits=16 (BF16), no quantization field written")
195
+ else:
196
+ cfg["quantization"] = {"group_size": args.group_size, "bits": args.bits}
197
+ print(f"config.json: quantization={{group_size: {args.group_size}, bits: {args.bits}}}")
198
+ cfg_path.write_text(json.dumps(cfg, indent=2))
199
+
200
+ (out_dir / "mlx_lab_meta.json").write_text(json.dumps({
201
+ "format": "hidream-o1-mlx-lab/v0",
202
+ "bits": args.bits,
203
+ "group_size": args.group_size,
204
+ "source": str(src_dir),
205
+ }, indent=2))
206
+ print(f"done; output dir size: {sum(f.stat().st_size for f in out_dir.glob('*')) / 1e9:.2f} GB")
207
+
208
+
209
+ if __name__ == "__main__":
210
+ main()
scripts/hidream_o1/creative_showcase.sh ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Creative showcase — push the envelope.
3
+ # Vertical social-media format (1440x2560) + cinema ultrawide (3104x1312).
4
+ set -euo pipefail
5
+
6
+ LAB="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
7
+ PY="$LAB/.venv/bin/python"
8
+ MODEL="$LAB/mlx_models/hidream-o1-dev-q6"
9
+ OUT_V="$LAB/sample_outputs/showcase_creative/vertical_1440x2560"
10
+ OUT_W="$LAB/sample_outputs/showcase_creative/wide_3104x1312"
11
+ mkdir -p "$OUT_V" "$OUT_W"
12
+
13
+ run() {
14
+ local out_dir="$1" name="$2" w="$3" h="$4" prompt="$5" seed="${6:-42}"
15
+ echo "=== $name ${w}x${h} (seed=$seed) ==="
16
+ cd "$LAB" && /usr/bin/time -l "$PY" scripts/hidream_o1/generate_hidream_o1_mlx.py \
17
+ --model-path "$MODEL" \
18
+ --prompt "$prompt" \
19
+ --width "$w" --height "$h" \
20
+ --output "$out_dir/$name.png" \
21
+ --seed "$seed" 2>&1 | grep -E "loaded|using|generation:|saved" | tail -3
22
+ echo ""
23
+ }
24
+
25
+ # === VERTICAL — social media influencer aesthetic (9:16 trained = 1440x2560) ===
26
+ run "$OUT_V" "01_fitness_influencer" 1440 2560 \
27
+ "vertical full-body shot of a strong female fitness influencer in matching black lululemon top and high-waist leggings, mid-deadlift in a sunlit industrial gym, hair in a slick high ponytail, chalk dust in the air around her hands, hyperreal sweat detail, dramatic side lighting from large factory windows, professional sports photography, hyper sharp focus, 50mm lens" \
28
+ 101
29
+
30
+ run "$OUT_V" "02_glamour_mirror" 1440 2560 \
31
+ "vertical full-body mirror selfie of a stylish female fashion influencer in a luxury Paris hotel bathroom, wearing a black silk slip dress and gold hoop earrings, holding a vintage Polaroid camera up to take the shot, marble walls and warm sconce lighting behind her, vogue magazine aesthetic, soft glamour, photoreal" \
32
+ 202
33
+
34
+ run "$OUT_V" "03_travel_iceland" 1440 2560 \
35
+ "vertical full-body shot of a male travel photographer standing alone on a black volcanic Iceland beach with massive basalt sea stacks behind him, wearing a heavy grey wool turtleneck, dark technical pants, and a beanie, breath visible in the cold air, dramatic overcast moody light, churning North Atlantic waves, cinematic landscape photography, 35mm" \
36
+ 303
37
+
38
+ run "$OUT_V" "04_streetwear_tokyo" 1440 2560 \
39
+ "vertical full-body shot of a Japanese streetwear influencer in front of a graffiti-covered wall in Harajuku, wearing oversized black raf simons hoodie with white text, baggy washed denim, chunky asics sneakers, ear-length neon green dyed hair, hands in pockets, head tilted with a slight smirk, golden-hour rim light, fashion editorial, photoreal" \
40
+ 404
41
+
42
+ # === CINEMA — ultrawide environments (21:9 trained = 3104x1312) ===
43
+ run "$OUT_W" "05_samurai_bamboo" 3104 1312 \
44
+ "ultrawide cinematic shot of two samurai facing off in a misty bamboo forest at dawn, swords drawn, kimonos in deep red and indigo, golden first light filtering through tall bamboo stalks, particles of mist floating between them, anamorphic 35mm film, Akira Kurosawa composition, Roger Deakins cinematography" \
45
+ 505
46
+
47
+ run "$OUT_W" "06_astronaut_mars" 3104 1312 \
48
+ "ultrawide cinematic shot of a single astronaut walking across a vast Martian canyon floor, deep red rocks rising hundreds of meters on either side, footprints behind in the rust dust, faint planet earth visible as a blue dot in the salmon sky, helmet visor reflecting the alien landscape, sense of cosmic scale and isolation, Denis Villeneuve aesthetic" \
49
+ 606
50
+
51
+ run "$OUT_W" "07_dragon_kingdom" 3104 1312 \
52
+ "ultrawide cinematic shot of an enormous black dragon flying low over a medieval mountain kingdom at dusk, scales glistening, wings catching the last orange sunlight, peasants in the foreground looking up in awe from a stone bridge, the kingdom castle and snow-capped peaks in the deep background, fantasy epic, Peter Jackson composition, wide-angle lens, painterly atmosphere" \
53
+ 707
54
+
55
+ echo "=== creative showcase complete ==="
56
+ ls -la "$OUT_V" "$OUT_W"
scripts/hidream_o1/flow_match.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """MLX port of FlashFlowMatchEulerDiscreteScheduler from HiDream-O1.
2
+
3
+ Reference: HiDream-ai/HiDream-O1-Image @ models/flash_scheduler.py.
4
+ Trimmed to the path the Dev recipe actually uses:
5
+ - num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False
6
+ - timesteps overridden by DEFAULT_TIMESTEPS after construction
7
+ - karras/exponential/beta sigmas not used
8
+ - step() with s_churn/s_tmin/s_tmax stripped (always defaults)
9
+
10
+ The math is verbatim from upstream — only the framework swap (torch -> mlx).
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import mlx.core as mx
15
+ import numpy as np
16
+
17
+
18
+ # Verbatim from HiDream-O1 models/pipeline.py
19
+ DEFAULT_TIMESTEPS = [
20
+ 999, 987, 974, 960, 945, 929, 913, 895, 877, 857, 836, 814, 790, 764, 737,
21
+ 707, 675, 640, 602, 560, 515, 464, 409, 347, 278, 199, 110, 8,
22
+ ]
23
+
24
+
25
+ class FlashFlowMatchScheduler:
26
+ """Euler scheduler for flow matching, with optional noise injection."""
27
+
28
+ def __init__(self, num_train_timesteps: int = 1000, shift: float = 1.0):
29
+ self.num_train_timesteps = num_train_timesteps
30
+ self.shift = shift
31
+
32
+ sigmas = np.linspace(1.0, 1.0 / num_train_timesteps, num_train_timesteps, dtype=np.float32)
33
+ sigmas = shift * sigmas / (1.0 + (shift - 1.0) * sigmas)
34
+ self.sigmas_np = sigmas
35
+ self.timesteps_np = sigmas * num_train_timesteps
36
+
37
+ self.num_inference_steps: int | None = None
38
+ self._step_index: int | None = None
39
+
40
+ def set_timesteps(self, num_inference_steps: int, custom_timesteps: list[int] | None = None):
41
+ if custom_timesteps is not None:
42
+ timesteps = np.asarray(custom_timesteps, dtype=np.float32)
43
+ sigmas = (timesteps / self.num_train_timesteps).astype(np.float32)
44
+ sigmas = np.append(sigmas, 0.0).astype(np.float32)
45
+ else:
46
+ timesteps = np.linspace(self.num_train_timesteps, 1.0, num_inference_steps, dtype=np.float32)
47
+ sigmas = (timesteps / self.num_train_timesteps).astype(np.float32)
48
+ sigmas = self.shift * sigmas / (1.0 + (self.shift - 1.0) * sigmas)
49
+ sigmas = np.append(sigmas, 0.0).astype(np.float32)
50
+
51
+ self.num_inference_steps = len(timesteps)
52
+ self.timesteps_np = timesteps
53
+ self.sigmas_np = sigmas
54
+ self._step_index = None
55
+
56
+ @property
57
+ def timesteps(self) -> mx.array:
58
+ return mx.array(self.timesteps_np)
59
+
60
+ @property
61
+ def sigmas(self) -> mx.array:
62
+ return mx.array(self.sigmas_np)
63
+
64
+ def _init_step_index(self, timestep_value: float):
65
+ ts = self.timesteps_np
66
+ matches = np.where(np.isclose(ts, timestep_value, atol=1e-3))[0]
67
+ if len(matches) == 0:
68
+ raise ValueError(f"timestep {timestep_value!r} not in scheduler.timesteps")
69
+ self._step_index = int(matches[1] if len(matches) > 1 else matches[0])
70
+
71
+ def step(self, model_output, timestep, sample,
72
+ s_noise=1.0, noise_clip_std=0.0, seed=None):
73
+ if self._step_index is None:
74
+ self._init_step_index(float(timestep))
75
+ idx = self._step_index
76
+
77
+ sigma = float(self.sigmas_np[idx])
78
+ sigma_next = float(self.sigmas_np[idx + 1])
79
+
80
+ sample_f = sample.astype(mx.float32)
81
+ model_output_f = model_output.astype(mx.float32)
82
+
83
+ denoised = sample_f - model_output_f * sigma
84
+
85
+ if idx < self.num_inference_steps:
86
+ if seed is not None:
87
+ key = mx.random.key(seed + idx)
88
+ noise = mx.random.normal(model_output_f.shape, key=key)
89
+ else:
90
+ noise = mx.random.normal(model_output_f.shape)
91
+
92
+ if noise_clip_std > 0:
93
+ std = float(mx.std(noise))
94
+ clip = noise_clip_std * std
95
+ noise = mx.clip(noise, -clip, clip)
96
+
97
+ new_sample = sigma_next * noise * s_noise + (1.0 - sigma_next) * denoised
98
+ else:
99
+ new_sample = denoised
100
+
101
+ self._step_index += 1
102
+ return new_sample.astype(sample.dtype)
scripts/hidream_o1/generate_hidream_o1_mlx.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """HiDream-O1-Image-Dev inference on MLX (T2I, Dev recipe only)."""
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import sys
7
+ import time
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+
12
+ HERE = Path(__file__).parent
13
+ sys.path.insert(0, str(HERE))
14
+
15
+ from pipeline_helpers import (
16
+ PATCH_SIZE, NOISE_SCALE_DEFAULT, T_EPS,
17
+ build_attention_mask, find_closest_resolution, patchify, unpatchify,
18
+ )
19
+
20
+
21
+ def shape_test():
22
+ print("=== HiDream-O1 MLX lab — shape sanity test ===")
23
+
24
+ H, W = 512, 512
25
+ x = np.random.randn(3, H, W).astype(np.float32)
26
+ p = patchify(x)
27
+ expected_patches = (H // PATCH_SIZE) * (W // PATCH_SIZE)
28
+ expected_dim = 3 * PATCH_SIZE * PATCH_SIZE
29
+ assert p.shape == (expected_patches, expected_dim)
30
+ x2 = unpatchify(p, H // PATCH_SIZE, W // PATCH_SIZE)
31
+ assert np.allclose(x, x2)
32
+ print(f" [ok] patchify roundtrip: {x.shape} -> {p.shape} -> {x2.shape}")
33
+
34
+ txt_seq_len = 12
35
+ img_seq_len = expected_patches
36
+ total = txt_seq_len + img_seq_len
37
+ token_types = np.zeros((1, total), dtype=np.int64)
38
+ token_types[0, txt_seq_len - 1: total] = 1
39
+ DTYPE_MIN = -1e9
40
+ mask = build_attention_mask(token_types, DTYPE_MIN)
41
+ assert mask.shape == (1, 1, total, total)
42
+ assert mask[0, 0, 0, 1] == DTYPE_MIN and mask[0, 0, 0, 0] == 0
43
+ assert (mask[0, 0, txt_seq_len + 5] == 0).all()
44
+ print(f" [ok] mask: shape={mask.shape}, text rows causal, gen rows bidirectional")
45
+
46
+ from flow_match import FlashFlowMatchScheduler, DEFAULT_TIMESTEPS
47
+ sched = FlashFlowMatchScheduler(num_train_timesteps=1000, shift=1.0)
48
+ sched.set_timesteps(28, custom_timesteps=DEFAULT_TIMESTEPS)
49
+ assert sched.num_inference_steps == 28
50
+ assert len(sched.sigmas_np) == 29
51
+ diffs = np.diff(sched.sigmas_np)
52
+ assert (diffs <= 1e-6).all()
53
+ print(f" [ok] scheduler: 28 steps, sigmas {sched.sigmas_np[0]:.4f} -> {sched.sigmas_np[-2]:.4f} -> 0")
54
+
55
+ try:
56
+ import mlx.core as mx
57
+ except ImportError:
58
+ print(" [skip] mlx not available")
59
+ else:
60
+ B, N, D = 1, expected_patches, expected_dim
61
+ z = mx.random.normal((B, N, D))
62
+ model_output = mx.random.normal((B, N, D))
63
+ ts0 = float(sched.timesteps_np[0])
64
+ z2 = sched.step(model_output, ts0, z, s_noise=7.5, noise_clip_std=2.5, seed=42)
65
+ assert z2.shape == z.shape
66
+ print(f" [ok] mlx step: z {z.shape} -> z' {z2.shape}, dtype {z2.dtype}")
67
+
68
+ snapped = find_closest_resolution(540, 960)
69
+ assert snapped in [(1440, 2560), (1312, 3104)]
70
+ print(f" [ok] resolution snap 540x960 -> {snapped}")
71
+ print("=== all shape tests passed ===")
72
+
73
+
74
+ def run_inference(args):
75
+ import mlx.core as mx
76
+ try:
77
+ from mlx_vlm import load as mlx_vlm_load
78
+ except ImportError:
79
+ sys.exit("mlx-vlm not installed. uv pip install 'mlx-vlm>=0.3.3'")
80
+ from PIL import Image
81
+ import tqdm
82
+
83
+ from pipeline_helpers import build_t2i_text_sample
84
+ from flow_match import FlashFlowMatchScheduler, DEFAULT_TIMESTEPS
85
+ from hidream_model import HiDreamConfig, build_model, forward_generation, precompute_text_embeds_with_vision
86
+
87
+ print(f"loading model from {args.model_path} ...", flush=True)
88
+ t0 = time.time()
89
+ backbone, processor = mlx_vlm_load(args.model_path)
90
+ print(f" loaded in {time.time() - t0:.1f}s")
91
+
92
+ cfg = HiDreamConfig()
93
+ model = build_model(cfg, backbone)
94
+ custom_path = Path(args.model_path) / "extras" / "custom_heads.safetensors"
95
+ if not custom_path.exists():
96
+ sys.exit(f"missing {custom_path}; rerun the converter")
97
+ custom_weights = mx.load(str(custom_path))
98
+ model.load_weights(list(custom_weights.items()), strict=False)
99
+ print(f" loaded {len(custom_weights)} custom-head tensors")
100
+
101
+ width, height = args.width, args.height
102
+ if not args.no_snap_resolution:
103
+ sw, sh = find_closest_resolution(width, height)
104
+ if (sw, sh) != (width, height):
105
+ print(f" resolution snapped {width}x{height} -> {sw}x{sh} (trained dim)")
106
+ width, height = sw, sh
107
+ # patch-align fallback (HiDream only operates on multiples of PATCH_SIZE)
108
+ width = (width // PATCH_SIZE) * PATCH_SIZE
109
+ height = (height // PATCH_SIZE) * PATCH_SIZE
110
+ print(f" using {width}x{height} ({(width//PATCH_SIZE)*(height//PATCH_SIZE)} patches)")
111
+
112
+ h_patches = height // PATCH_SIZE
113
+ w_patches = width // PATCH_SIZE
114
+
115
+ tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
116
+ for n in ("boi", "bor", "eor", "bot", "tms"):
117
+ if not hasattr(tokenizer, f"{n}_token"):
118
+ setattr(tokenizer, f"{n}_token", f"<|{n}_token|>")
119
+
120
+ refs = list(args.ref_images or [])
121
+ if refs:
122
+ from pipeline_helpers import build_edit_text_sample
123
+ print(f" edit mode: K={len(refs)} reference image(s)")
124
+ sample = build_edit_text_sample(
125
+ args.prompt, refs, height, width, tokenizer, processor, backbone.config,
126
+ )
127
+ else:
128
+ sample = build_t2i_text_sample(args.prompt, height, width, tokenizer, processor, backbone.config)
129
+
130
+ input_ids = mx.array(sample["input_ids"])
131
+ position_ids = mx.array(sample["position_ids"])
132
+ token_types = mx.array(sample["token_types"])
133
+ vinput_mask = sample["vinput_mask"]
134
+
135
+ # Edit-mode extras (None for T2I)
136
+ pixel_values_mx = mx.array(sample["pixel_values"]).astype(mx.bfloat16) if refs else None
137
+ image_grid_thw_mx = mx.array(sample["image_grid_thw"]) if refs else None
138
+ ref_patches_mx = mx.array(sample["ref_patches"]).astype(mx.bfloat16) if refs else None
139
+ tgt_image_len = sample.get("tgt_image_len", (height // PATCH_SIZE) * (width // PATCH_SIZE))
140
+
141
+ DTYPE_MIN = -1e4
142
+ mask4d = mx.array(build_attention_mask(sample["token_types"], DTYPE_MIN)).astype(mx.bfloat16)
143
+
144
+ rng_key = mx.random.key(args.seed + 1)
145
+ noise = args.noise_scale_start * mx.random.normal((1, 3, height, width), key=rng_key)
146
+ noise_np = np.asarray(noise)
147
+ z = mx.array(patchify(noise_np[0])[None]).astype(mx.bfloat16)
148
+
149
+ sched = FlashFlowMatchScheduler(num_train_timesteps=1000, shift=1.0)
150
+ sched.set_timesteps(args.num_inference_steps, custom_timesteps=DEFAULT_TIMESTEPS)
151
+ noise_scale_schedule = np.linspace(args.noise_scale_start, args.noise_scale_end,
152
+ len(sched.timesteps_np))
153
+
154
+ # Precompute MLX-native indices for the patch slice. Two index sets:
155
+ # - vinput_idx: positions where the model gets vinputs (tgt + refs in edit mode, tgt only in T2I)
156
+ # - tgt_idx: positions where the TARGET image patches live (subset of vinput_idx in edit mode)
157
+ vinput_idx = mx.array(np.where(vinput_mask[0])[0].astype(np.int32))
158
+ if refs:
159
+ tgt_mask = sample["vinput_mask_tgt_only"]
160
+ tgt_idx = mx.array(np.where(tgt_mask[0])[0].astype(np.int32))
161
+ else:
162
+ tgt_idx = vinput_idx
163
+
164
+ # Precompute text+vision inputs_embeds — these are constant across denoising
165
+ # steps (only the vinputs / timestep change), so we save 28x the vision work.
166
+ inputs_embeds_pre = precompute_text_embeds_with_vision(
167
+ model, cfg, input_ids,
168
+ pixel_values=pixel_values_mx, image_grid_thw=image_grid_thw_mx,
169
+ )
170
+ mx.eval(inputs_embeds_pre)
171
+ print(f" precomputed inputs_embeds: {inputs_embeds_pre.shape}")
172
+
173
+ t_start = time.time()
174
+ for step_idx, step_t in enumerate(tqdm.tqdm(sched.timesteps_np, desc="generating")):
175
+ # Native MLX scalar — no numpy roundtrip
176
+ t_pixeldit = mx.full([1], 1.0 - float(step_t) / 1000.0, dtype=mx.float32)
177
+ sigma = max(float(step_t) / 1000.0, T_EPS)
178
+
179
+ # Edit mode: vinputs is the target z concatenated with the clean ref patches.
180
+ # The forward_generation embeds + concatenates these to inputs_embeds; the
181
+ # mask routes attention so refs are bidirectional too.
182
+ if refs:
183
+ vinputs = mx.concatenate([z, ref_patches_mx], axis=1)
184
+ else:
185
+ vinputs = z
186
+
187
+ x_pred = forward_generation(
188
+ model, cfg,
189
+ inputs_embeds_with_vision=inputs_embeds_pre,
190
+ position_ids=position_ids,
191
+ vinputs=vinputs,
192
+ timestep=t_pixeldit,
193
+ input_ids=input_ids,
194
+ token_types=token_types,
195
+ attention_mask_4d=mask4d,
196
+ )
197
+ # Slice the target patches only (excludes refs in edit mode).
198
+ gen_patches_mx = mx.take(x_pred, tgt_idx, axis=1).astype(mx.float32)
199
+
200
+ # Optional: clamp x_pred to [q, 1-q] quantile per step. Upstream has this
201
+ # commented out (pipeline.py line 327: `x_pred = clamp_tensor(x_pred, percentage=0.01)`).
202
+ # The patch-grid artifact comes from per-patch outliers — clamping the per-step
203
+ # x_pred range trims the worst extremes that show up as 32-pixel grid lines.
204
+ if args.clamp_x_pred > 0:
205
+ gp_np = np.asarray(gen_patches_mx)
206
+ lo = float(np.quantile(gp_np, args.clamp_x_pred))
207
+ hi = float(np.quantile(gp_np, 1.0 - args.clamp_x_pred))
208
+ gen_patches_mx = mx.clip(gen_patches_mx, lo, hi)
209
+
210
+ if args.diag and step_idx in (0, 1, 13, 27):
211
+ zarr = np.asarray(z.astype(mx.float32))
212
+ gp = np.asarray(gen_patches_mx)
213
+ print(f" [diag step {step_idx}] sigma={sigma:.4f} "
214
+ f"z(mean={zarr.mean():.3f},std={zarr.std():.3f}) "
215
+ f"x_pred(mean={gp.mean():.3f},std={gp.std():.3f},"
216
+ f"min={gp.min():.3f},max={gp.max():.3f})")
217
+
218
+ v = (gen_patches_mx - z.astype(mx.float32)) / sigma
219
+ model_output = -v
220
+ z = sched.step(model_output, float(step_t), z,
221
+ s_noise=float(noise_scale_schedule[step_idx]),
222
+ noise_clip_std=args.noise_clip_std,
223
+ seed=args.seed)
224
+ # Force eval per step so timing is honest (otherwise mlx's lazy eval
225
+ # batches the entire loop into the final save, hiding per-step cost).
226
+ mx.eval(z)
227
+
228
+ elapsed = time.time() - t_start
229
+ print(f" generation: {elapsed:.1f}s ({elapsed / args.num_inference_steps:.2f}s/step)")
230
+
231
+ img = (z + 1) / 2
232
+ img_np = np.asarray(img[0].astype(mx.float32))
233
+ rgb = unpatchify(img_np, h_patches, w_patches)
234
+ arr = np.clip(rgb.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)
235
+ if args.blend_seams > 0:
236
+ arr = _blend_patch_seams(arr, PATCH_SIZE, radius=args.blend_seams)
237
+ Image.fromarray(arr).save(args.output)
238
+ print(f"saved -> {args.output}")
239
+
240
+
241
+ def _blend_patch_seams(rgb: np.ndarray, patch: int, radius: int = 1) -> np.ndarray:
242
+ """Smooth the 32-pixel patch grid seams.
243
+
244
+ HiDream's final_layer2 predicts each patch independently. In flat
245
+ regions adjacent patches don't quite agree on color, so a regular grid
246
+ of seams shows up. We replace each seam row/col with a weighted average
247
+ of its (2*radius+1) neighbours using a triangular kernel — heavier
248
+ smoothing than the previous box average, but still limited to a thin
249
+ band around each seam so non-seam content is preserved.
250
+
251
+ rgb: [H, W, 3] uint8
252
+ radius: 0 = off, 1 = ±1 row blur (subtle), 2-3 = visible at HD/QHD,
253
+ 4+ = noticeable softening of seams
254
+ """
255
+ if radius <= 0:
256
+ return rgb
257
+ out = rgb.astype(np.float32).copy()
258
+ H, W, _ = rgb.shape
259
+
260
+ # Triangular kernel weights for (2*radius+1) rows
261
+ weights = np.array([radius - abs(i - radius) + 1 for i in range(2 * radius + 1)], dtype=np.float32)
262
+ weights = weights / weights.sum()
263
+
264
+ # Horizontal seams (vary by y); rebuild the seam row(s) as a triangular
265
+ # weighted average of (2*radius+1) neighbouring rows.
266
+ for y in range(patch, H, patch):
267
+ for offset in (-1, 0): # blend the rows immediately above and at the seam
268
+ yy = y + offset
269
+ if 0 <= yy < H:
270
+ lo = max(0, yy - radius)
271
+ hi = min(H, yy + radius + 1)
272
+ # Pad weights to match the actually-clipped band
273
+ w_slice = weights[radius - (yy - lo): radius + (hi - yy)]
274
+ w_slice = w_slice / w_slice.sum()
275
+ band = out[lo:hi] # [n, W, 3]
276
+ smoothed = (band * w_slice[:, None, None]).sum(axis=0)
277
+ out[yy] = smoothed
278
+
279
+ # Vertical seams
280
+ for x in range(patch, W, patch):
281
+ for offset in (-1, 0):
282
+ xx = x + offset
283
+ if 0 <= xx < W:
284
+ lo = max(0, xx - radius)
285
+ hi = min(W, xx + radius + 1)
286
+ w_slice = weights[radius - (xx - lo): radius + (hi - xx)]
287
+ w_slice = w_slice / w_slice.sum()
288
+ band = out[:, lo:hi] # [H, n, 3]
289
+ smoothed = (band * w_slice[None, :, None]).sum(axis=1)
290
+ out[:, xx] = smoothed
291
+
292
+ return np.clip(out, 0, 255).astype(np.uint8)
293
+
294
+
295
+ def main(argv=None):
296
+ ap = argparse.ArgumentParser()
297
+ ap.add_argument("--shape-test", action="store_true")
298
+ ap.add_argument("--model-path", default="mlx_models/hidream-o1-dev-q4")
299
+ ap.add_argument("--prompt", default="a small red mushroom on a bed of moss, soft daylight, macro photo")
300
+ ap.add_argument("--output", default="sample_outputs/out.png")
301
+ ap.add_argument("--width", type=int, default=512)
302
+ ap.add_argument("--height", type=int, default=512)
303
+ ap.add_argument("--num-inference-steps", type=int, default=28)
304
+ ap.add_argument("--no-snap-resolution", action="store_true",
305
+ help="Disable snapping to trained PREDEFINED_RESOLUTIONS list (off-spec dims produce visible patch artifacts)")
306
+ ap.add_argument("--diag", action="store_true",
307
+ help="Print stats of z and x_pred at a few key steps")
308
+ ap.add_argument("--blend-seams", type=int, default=0,
309
+ help="Post-process: smooth patch-grid seams with this radius (0 = off, 1-2 typical)")
310
+ ap.add_argument("--clamp-x-pred", type=float, default=0.0,
311
+ help="Per-step quantile clamp on x_pred (0 = off; 0.01 = upstream's commented-out value)")
312
+ ap.add_argument("--ref-images", nargs="*", default=[],
313
+ help="Reference image paths for edit/multi-ref mode (1-3). Empty = pure T2I.")
314
+ ap.add_argument("--noise-scale-start", type=float, default=NOISE_SCALE_DEFAULT)
315
+ ap.add_argument("--noise-scale-end", type=float, default=NOISE_SCALE_DEFAULT)
316
+ ap.add_argument("--noise-clip-std", type=float, default=2.5)
317
+ ap.add_argument("--seed", type=int, default=32)
318
+ args = ap.parse_args(argv)
319
+
320
+ if args.shape_test:
321
+ shape_test()
322
+ return
323
+ run_inference(args)
324
+
325
+
326
+ if __name__ == "__main__":
327
+ main()
scripts/hidream_o1/hidream_model.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Standalone MLX model wrapper for HiDream-O1-Image."""
2
+ from __future__ import annotations
3
+
4
+ import math
5
+ from dataclasses import dataclass
6
+ from typing import Optional
7
+
8
+ import mlx.core as mx
9
+ import mlx.nn as nn
10
+ import numpy as np
11
+
12
+
13
+ class TimestepEmbedder(nn.Module):
14
+ def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
15
+ super().__init__()
16
+ self.frequency_embedding_size = frequency_embedding_size
17
+ self.fc1 = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
18
+ self.fc2 = nn.Linear(hidden_size, hidden_size, bias=True)
19
+
20
+ @staticmethod
21
+ def timestep_embedding(t: mx.array, dim: int, max_period: float = 10000.0) -> mx.array:
22
+ half = dim // 2
23
+ freqs = mx.exp(-math.log(max_period) * mx.arange(0, half, dtype=mx.float32) / half)
24
+ args = t[:, None].astype(mx.float32) * freqs[None]
25
+ emb = mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1)
26
+ if dim % 2:
27
+ emb = mx.concatenate([emb, mx.zeros_like(emb[:, :1])], axis=-1)
28
+ return emb
29
+
30
+ def __call__(self, t: mx.array) -> mx.array:
31
+ t_freq = self.timestep_embedding(t * 1000.0, self.frequency_embedding_size)
32
+ return self.fc2(nn.silu(self.fc1(t_freq.astype(self.fc1.weight.dtype))))
33
+
34
+
35
+ class BottleneckPatchEmbed(nn.Module):
36
+ def __init__(self, patch_size: int = 32, in_chans: int = 3,
37
+ pca_dim: int = 1024, embed_dim: int = 4096):
38
+ super().__init__()
39
+ self.proj1 = nn.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False)
40
+ self.proj2 = nn.Linear(pca_dim, embed_dim, bias=True)
41
+
42
+ def __call__(self, x: mx.array) -> mx.array:
43
+ return self.proj2(self.proj1(x))
44
+
45
+
46
+ class FinalLayer(nn.Module):
47
+ def __init__(self, hidden_size: int, patch_size: int = 32, out_channels: int = 3):
48
+ super().__init__()
49
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
50
+
51
+ def __call__(self, x: mx.array) -> mx.array:
52
+ return self.linear(x)
53
+
54
+
55
+ CUSTOM_HEAD_KEY_MAP = {
56
+ "model.t_embedder1.mlp.0.weight": "t_embedder1.fc1.weight",
57
+ "model.t_embedder1.mlp.0.bias": "t_embedder1.fc1.bias",
58
+ "model.t_embedder1.mlp.2.weight": "t_embedder1.fc2.weight",
59
+ "model.t_embedder1.mlp.2.bias": "t_embedder1.fc2.bias",
60
+ "model.x_embedder.proj1.weight": "x_embedder.proj1.weight",
61
+ "model.x_embedder.proj2.weight": "x_embedder.proj2.weight",
62
+ "model.x_embedder.proj2.bias": "x_embedder.proj2.bias",
63
+ "model.final_layer2.linear.weight": "final_layer2.linear.weight",
64
+ "model.final_layer2.linear.bias": "final_layer2.linear.bias",
65
+ }
66
+
67
+
68
+ @dataclass
69
+ class HiDreamConfig:
70
+ hidden_size: int = 4096
71
+ patch_size: int = 32
72
+ in_channels: int = 3
73
+ bottleneck_dim: int = 1024
74
+ tms_token_id: int = 151673
75
+ image_token_id: int = 151655
76
+ video_token_id: int = 151656
77
+ vision_start_token_id: int = 151652
78
+
79
+
80
+ def build_model(cfg: HiDreamConfig, mlx_vlm_qwen3_vl_model):
81
+ class HiDream(nn.Module):
82
+ def __init__(self):
83
+ super().__init__()
84
+ self.visual = mlx_vlm_qwen3_vl_model.vision_tower
85
+ self.language_model = mlx_vlm_qwen3_vl_model.language_model
86
+ self.t_embedder1 = TimestepEmbedder(cfg.hidden_size)
87
+ self.x_embedder = BottleneckPatchEmbed(
88
+ patch_size=cfg.patch_size, in_chans=cfg.in_channels,
89
+ pca_dim=cfg.bottleneck_dim, embed_dim=cfg.hidden_size,
90
+ )
91
+ self.final_layer2 = FinalLayer(
92
+ hidden_size=cfg.hidden_size,
93
+ patch_size=cfg.patch_size,
94
+ out_channels=cfg.in_channels,
95
+ )
96
+
97
+ return HiDream()
98
+
99
+
100
+ def precompute_text_embeds_with_vision(model, cfg, input_ids, pixel_values=None, image_grid_thw=None):
101
+ """Compute text embeddings + (in edit mode) inject vision features at image_token
102
+ positions. Returns embeds [B, S_text, hidden]. Call once before the denoising
103
+ loop — output is constant across timesteps.
104
+ """
105
+ embed_tokens = model.language_model.model.embed_tokens
106
+ inputs_embeds = embed_tokens(input_ids)
107
+
108
+ if pixel_values is None or image_grid_thw is None:
109
+ return inputs_embeds
110
+
111
+ vt_out = model.visual(pixel_values, image_grid_thw)
112
+ image_features = vt_out[0] if isinstance(vt_out, tuple) else vt_out
113
+ if isinstance(image_features, (list, tuple)):
114
+ image_features = mx.concatenate(image_features, axis=0)
115
+
116
+ # Build a [B, S, H] tensor that has image_features at image_token positions
117
+ # and inputs_embeds everywhere else, via mx.where on a broadcast mask.
118
+ ids_np = np.asarray(input_ids)
119
+ img_positions = np.where(ids_np[0] == cfg.image_token_id)[0]
120
+ if img_positions.shape[0] != image_features.shape[0]:
121
+ raise RuntimeError(
122
+ f"image_features {image_features.shape[0]} != "
123
+ f"image_token_id positions {img_positions.shape[0]} (input_ids was: {ids_np.shape})"
124
+ )
125
+
126
+ B, S, H = inputs_embeds.shape
127
+ # Build aligned-to-S features: zero everywhere except at image positions.
128
+ aligned = np.zeros((B, S, H), dtype=np.float32)
129
+ aligned[0, img_positions] = np.asarray(image_features.astype(mx.float32))
130
+ aligned_mx = mx.array(aligned).astype(inputs_embeds.dtype)
131
+
132
+ # Mask: 1 at image positions, 0 elsewhere
133
+ mask_2d = (ids_np == cfg.image_token_id).astype(np.bool_)
134
+ mask_3d = np.broadcast_to(mask_2d[..., None], (B, S, H))
135
+ mask_mx = mx.array(mask_3d.copy())
136
+
137
+ return mx.where(mask_mx, aligned_mx, inputs_embeds)
138
+
139
+
140
+ def forward_generation(model, cfg, inputs_embeds_with_vision, position_ids, vinputs, timestep,
141
+ input_ids, token_types, attention_mask_4d):
142
+ """Per-step forward. Takes the precomputed text+vision inputs_embeds, the
143
+ fresh-noise vinputs, and the timestep. Returns x_pred [B, S_total, patch_dim].
144
+
145
+ Signature change vs the T2I-only version: pixel_values/image_grid_thw moved
146
+ out (call precompute_text_embeds_with_vision once before the loop). input_ids
147
+ is still needed inside because we look up tms_token positions for t_emb scatter.
148
+ """
149
+ inputs_embeds = inputs_embeds_with_vision
150
+
151
+ t_emb = model.t_embedder1(timestep)
152
+ tms_mask = (input_ids == cfg.tms_token_id)
153
+ tms_mask_3d = mx.broadcast_to(tms_mask[..., None], inputs_embeds.shape)
154
+ t_emb_expanded = mx.broadcast_to(t_emb[:, None, :], inputs_embeds.shape)
155
+ inputs_embeds = mx.where(tms_mask_3d, t_emb_expanded, inputs_embeds)
156
+
157
+ vinputs_embedded = model.x_embedder(vinputs).astype(inputs_embeds.dtype)
158
+ inputs_embeds = mx.concatenate([inputs_embeds, vinputs_embedded], axis=1)
159
+
160
+ text_model = model.language_model.model
161
+ # mlx-vlm Qwen3VLModel.__call__ accepts (inputs, inputs_embeds, mask, cache, position_ids, ...).
162
+ # Pass our 4D additive mask directly; it bypasses the internal causal mask.
163
+ # `inputs` is required positionally but ignored when inputs_embeds is set
164
+ # in mlx-vlm's implementation — pass a placeholder of correct shape.
165
+ placeholder = mx.zeros(inputs_embeds.shape[:2], dtype=mx.int32)
166
+ h = text_model(
167
+ placeholder,
168
+ inputs_embeds=inputs_embeds,
169
+ mask=attention_mask_4d,
170
+ cache=None,
171
+ position_ids=position_ids,
172
+ )
173
+ # Apply final norm. mlx-vlm's Qwen3VLModel applies it internally and returns hidden_states.
174
+ x_pred = model.final_layer2(h)
175
+ return x_pred
scripts/hidream_o1/pipeline_helpers.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Helpers ported from HiDream-O1 models/pipeline.py + models/utils.py."""
2
+ from __future__ import annotations
3
+
4
+ import math
5
+ from typing import Sequence
6
+ import numpy as np
7
+
8
+ PATCH_SIZE = 32
9
+ TIMESTEP_TOKEN_NUM = 1
10
+ NOISE_SCALE_DEFAULT = 7.5
11
+ T_EPS = 0.001
12
+ TMS_TOKEN_ID = 151673 # from qwen3_vl_transformers.py — Qwen3VLModel.tms_token_id
13
+ CONDITION_IMAGE_SIZE = 384 # vision-tower-side size for reference images
14
+
15
+ PREDEFINED_RESOLUTIONS = [
16
+ (2048, 2048),
17
+ (2304, 1728), (1728, 2304),
18
+ (2560, 1440), (1440, 2560),
19
+ (2496, 1664), (1664, 2496),
20
+ (3104, 1312), (1312, 3104),
21
+ (2304, 1792), (1792, 2304),
22
+ ]
23
+
24
+
25
+ def find_closest_resolution(width: int, height: int) -> tuple[int, int]:
26
+ img_ratio = width / height
27
+ best, min_diff = None, float("inf")
28
+ for w, h in PREDEFINED_RESOLUTIONS:
29
+ diff = abs(w / h - img_ratio)
30
+ if diff < min_diff:
31
+ min_diff, best = diff, (w, h)
32
+ return best
33
+
34
+
35
+ def patchify(img_chw: np.ndarray, patch: int = PATCH_SIZE) -> np.ndarray:
36
+ C, H, W = img_chw.shape
37
+ assert H % patch == 0 and W % patch == 0
38
+ x = img_chw.reshape(C, H // patch, patch, W // patch, patch)
39
+ x = np.transpose(x, (1, 3, 0, 2, 4))
40
+ return x.reshape(H // patch * W // patch, C * patch * patch)
41
+
42
+
43
+ def unpatchify(patches_nd, h_patches, w_patches, patch=PATCH_SIZE, channels=3):
44
+ x = patches_nd.reshape(h_patches, w_patches, channels, patch, patch)
45
+ x = np.transpose(x, (2, 0, 3, 1, 4))
46
+ return x.reshape(channels, h_patches * patch, w_patches * patch)
47
+
48
+
49
+ def build_t2i_text_sample(prompt, height, width, tokenizer, processor, model_config):
50
+ image_token_id = model_config.image_token_id
51
+ video_token_id = model_config.video_token_id
52
+ vision_start_token_id = model_config.vision_start_token_id
53
+
54
+ image_len = (height // PATCH_SIZE) * (width // PATCH_SIZE)
55
+ boi_token = getattr(tokenizer, "boi_token", "<|boi_token|>")
56
+ tms_token = getattr(tokenizer, "tms_token", "<|tms_token|>")
57
+
58
+ messages = [{"role": "user", "content": prompt}]
59
+ template_caption = (
60
+ processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
61
+ + boi_token + tms_token * TIMESTEP_TOKEN_NUM
62
+ )
63
+ input_ids = np.asarray(
64
+ tokenizer.encode(template_caption, add_special_tokens=False),
65
+ dtype=np.int64,
66
+ ).reshape(1, -1)
67
+
68
+ image_grid_thw = np.asarray(
69
+ [[1, height // PATCH_SIZE, width // PATCH_SIZE]], dtype=np.int64
70
+ )
71
+
72
+ vision_tokens = np.full((1, image_len), image_token_id, dtype=input_ids.dtype)
73
+ vision_tokens[0, 0] = vision_start_token_id
74
+ input_ids_pad = np.concatenate([input_ids, vision_tokens], axis=-1)
75
+
76
+ position_ids, _ = get_rope_index_fix_point(
77
+ spatial_merge_size=1,
78
+ image_token_id=image_token_id,
79
+ video_token_id=video_token_id,
80
+ vision_start_token_id=vision_start_token_id,
81
+ input_ids=input_ids_pad,
82
+ image_grid_thw=image_grid_thw,
83
+ skip_vision_start_token=[1],
84
+ )
85
+
86
+ txt_seq_len = input_ids.shape[-1]
87
+ all_seq_len = position_ids.shape[-1]
88
+
89
+ token_types = np.zeros((1, all_seq_len), dtype=np.int64)
90
+ bgn = txt_seq_len - TIMESTEP_TOKEN_NUM
91
+ token_types[0, bgn: bgn + image_len + TIMESTEP_TOKEN_NUM] = 1
92
+ # Tag the tms positions distinctly so vinput_mask excludes them — they're
93
+ # for the timestep embedding, not actual image patches.
94
+ token_types[0, txt_seq_len - TIMESTEP_TOKEN_NUM: txt_seq_len] = 3
95
+ vinput_mask = (token_types == 1)
96
+ token_types_bin = (token_types > 0).astype(np.int64)
97
+
98
+ return {
99
+ "input_ids": input_ids,
100
+ "position_ids": position_ids,
101
+ "token_types": token_types_bin,
102
+ "vinput_mask": vinput_mask,
103
+ }
104
+
105
+
106
+ def get_rope_index_fix_point(
107
+ spatial_merge_size, image_token_id, video_token_id, vision_start_token_id,
108
+ input_ids, image_grid_thw=None, video_grid_thw=None, attention_mask=None,
109
+ skip_vision_start_token=None, fix_point=4096,
110
+ ):
111
+ if input_ids is None:
112
+ raise ValueError("input_ids is required")
113
+ if attention_mask is None:
114
+ attention_mask = np.ones_like(input_ids)
115
+
116
+ B, S = input_ids.shape
117
+ position_ids = np.ones((3, B, S), dtype=input_ids.dtype)
118
+
119
+ image_index = 0
120
+ video_index = 0
121
+ mrope_position_deltas: list[int] = []
122
+
123
+ for i in range(B):
124
+ ids_i = input_ids[i][attention_mask[i] == 1]
125
+ vision_start_indices = np.argwhere(ids_i == vision_start_token_id).reshape(-1)
126
+ vision_tokens = ids_i[vision_start_indices + 1] if len(vision_start_indices) else np.array([], dtype=ids_i.dtype)
127
+ image_nums = int((vision_tokens == image_token_id).sum())
128
+ video_nums = int((vision_tokens == video_token_id).sum())
129
+
130
+ toks = ids_i.tolist()
131
+ llm_pos_ids: list[np.ndarray] = []
132
+ st = 0
133
+ remain_images, remain_videos = image_nums, video_nums
134
+ local_fix_point = fix_point
135
+
136
+ for _ in range(image_nums + video_nums):
137
+ ed_image = toks.index(image_token_id, st) if (image_token_id in toks[st:] and remain_images > 0) else len(toks) + 1
138
+ ed_video = toks.index(video_token_id, st) if (video_token_id in toks[st:] and remain_videos > 0) else len(toks) + 1
139
+ if ed_image < ed_video:
140
+ t, h, w = image_grid_thw[image_index]
141
+ image_index += 1
142
+ remain_images -= 1
143
+ ed = ed_image
144
+ else:
145
+ t, h, w = video_grid_thw[video_index]
146
+ video_index += 1
147
+ remain_videos -= 1
148
+ ed = ed_video
149
+
150
+ llm_grid_t = int(t)
151
+ llm_grid_h = int(h) // spatial_merge_size
152
+ llm_grid_w = int(w) // spatial_merge_size
153
+ text_len = ed - st
154
+ text_len -= int(skip_vision_start_token[image_index - 1])
155
+ text_len = max(0, text_len)
156
+
157
+ st_idx = (llm_pos_ids[-1].max() + 1) if llm_pos_ids else 0
158
+ llm_pos_ids.append(np.broadcast_to(np.arange(text_len) + st_idx, (3, text_len)).copy())
159
+
160
+ t_index = np.repeat(np.arange(llm_grid_t), llm_grid_h * llm_grid_w)
161
+ h_index = np.tile(np.repeat(np.arange(llm_grid_h), llm_grid_w), llm_grid_t)
162
+ w_index = np.tile(np.arange(llm_grid_w), llm_grid_t * llm_grid_h)
163
+
164
+ if int(skip_vision_start_token[image_index - 1]):
165
+ if local_fix_point > 0:
166
+ local_fix_point = local_fix_point - st_idx
167
+ llm_pos_ids.append(np.stack([t_index, h_index, w_index]) + local_fix_point + st_idx)
168
+ local_fix_point = 0
169
+ else:
170
+ llm_pos_ids.append(np.stack([t_index, h_index, w_index]) + text_len + st_idx)
171
+
172
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
173
+
174
+ if st < len(toks):
175
+ st_idx = (llm_pos_ids[-1].max() + 1) if llm_pos_ids else 0
176
+ text_len = len(toks) - st
177
+ llm_pos_ids.append(np.broadcast_to(np.arange(text_len) + st_idx, (3, text_len)).copy())
178
+
179
+ llm_positions = np.concatenate(llm_pos_ids, axis=1).reshape(3, -1)
180
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions
181
+ mrope_position_deltas.append(int(llm_positions.max() + 1 - input_ids.shape[1]))
182
+
183
+ deltas = np.asarray(mrope_position_deltas, dtype=np.int64).reshape(-1, 1)
184
+ return position_ids, deltas
185
+
186
+
187
+ def resize_pilimage(pil_image, image_size: int, patch_size: int = PATCH_SIZE, resampler=None):
188
+ """Port of HiDream-O1 utils.py:resize_pilimage.
189
+
190
+ Reduce by 2x box resamples until min dim < 2*image_size, then bicubic-fit
191
+ + center-crop to the largest patch-aligned size that doesn't exceed
192
+ image_size**2 area.
193
+ """
194
+ from PIL import Image
195
+ if resampler is None:
196
+ resampler = Image.BICUBIC
197
+ while min(pil_image.size) >= 2 * image_size:
198
+ pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
199
+
200
+ m = patch_size
201
+ width, height = pil_image.width, pil_image.height
202
+ S_max = image_size * image_size
203
+ scale = math.sqrt(S_max / (width * height))
204
+
205
+ candidates = [
206
+ (round(width * scale) // m * m, round(height * scale) // m * m),
207
+ (round(width * scale) // m * m, math.floor(height * scale) // m * m),
208
+ (math.floor(width * scale) // m * m, round(height * scale) // m * m),
209
+ (math.floor(width * scale) // m * m, math.floor(height * scale) // m * m),
210
+ ]
211
+ candidates = sorted(candidates, key=lambda x: x[0] * x[1], reverse=True)
212
+ new_w, new_h = next((c for c in candidates if c[0] * c[1] <= S_max), candidates[-1])
213
+
214
+ s1 = width / new_w
215
+ s2 = height / new_h
216
+ if s1 < s2:
217
+ pil_image = pil_image.resize([new_w, round(height / s1)], resample=resampler)
218
+ top = (round(height / s1) - new_h) // 2
219
+ pil_image = pil_image.crop((0, top, new_w, top + new_h))
220
+ else:
221
+ pil_image = pil_image.resize([round(width / s2), new_h], resample=resampler)
222
+ left = (round(width / s2) - new_w) // 2
223
+ pil_image = pil_image.crop((left, 0, left + new_w, new_h))
224
+ return pil_image
225
+
226
+
227
+ def calculate_dimensions(max_size: int, ratio: float) -> tuple[int, int]:
228
+ """Port of HiDream-O1 utils.py:calculate_dimensions.
229
+
230
+ Pick (w, h) such that max(w*h) <= max_size**2 and w/h ≈ ratio, both
231
+ multiples of 32 (PATCH_SIZE).
232
+ """
233
+ width = math.sqrt(max_size * max_size * ratio)
234
+ height = width / ratio
235
+ width = int(width / 32) * 32
236
+ height = int(height / 32) * 32
237
+ return width, height
238
+
239
+
240
+ def patchify_ref_image(pil_image, patch: int = PATCH_SIZE) -> np.ndarray:
241
+ """Convert a PIL image (already patch-aligned) into HiDream's diffusion-side
242
+ patches: [N_patches, 3*patch*patch] with float32 in [-1, 1].
243
+
244
+ Mirrors the upstream `TENSOR_TRANSFORM` (ToTensor + Normalize 0.5/0.5).
245
+ """
246
+ arr = np.asarray(pil_image.convert("RGB"), dtype=np.float32) / 255.0 # [H, W, 3] in [0, 1]
247
+ arr = (arr - 0.5) / 0.5 # [-1, 1]
248
+ arr = arr.transpose(2, 0, 1) # [3, H, W]
249
+ return patchify(arr, patch=patch) # [N, 3*p*p]
250
+
251
+
252
+ def build_edit_text_sample(
253
+ prompt: str,
254
+ ref_image_paths: Sequence[str],
255
+ height: int,
256
+ width: int,
257
+ tokenizer,
258
+ processor,
259
+ model_config,
260
+ ) -> dict:
261
+ """Build the unified token sequence + position_ids + masks for image edit
262
+ or multi-reference subject-driven generation.
263
+
264
+ Faithful port of the multi-ref branch of HiDream-O1 pipeline.py
265
+ generate_image. Single-reference (K=1) is the well-tested path.
266
+
267
+ Returns:
268
+ input_ids [1, txt_seq_len]
269
+ position_ids [3, 1, total_seq_len]
270
+ token_types [1, total_seq_len] (0=AR, 1=tgt+tms, 2=ref)
271
+ vinput_mask [1, total_seq_len] (True where diffusion patches go)
272
+ vinput_mask_tgt_only [1, total_seq_len] (True ONLY for the tgt span; for slicing the prediction)
273
+ pixel_values [N_vision_patches, vision_patch_dim] (vision tower input)
274
+ image_grid_thw [K, 3] (vision tower grid for refs)
275
+ ref_patches [1, sum(N_ref_patches), 3*32*32] (clean ref patches for vinputs cat)
276
+ tgt_image_len int (number of target patches)
277
+ """
278
+ from PIL import Image
279
+
280
+ image_token_id = model_config.image_token_id
281
+ video_token_id = model_config.video_token_id
282
+ vision_start_token_id = model_config.vision_start_token_id
283
+ spatial_merge_size = model_config.vision_config.spatial_merge_size
284
+
285
+ ref_pils = [Image.open(p).convert("RGB") for p in ref_image_paths]
286
+ K = len(ref_pils)
287
+
288
+ if K == 1:
289
+ max_size = max(height, width)
290
+ elif K == 2:
291
+ max_size = max(height, width) * 48 // 64
292
+ elif K <= 4:
293
+ max_size = max(height, width) // 2
294
+ elif K <= 8:
295
+ max_size = max(height, width) * 24 // 64
296
+ else:
297
+ max_size = max(height, width) // 4
298
+
299
+ ref_pils_resized: list = []
300
+ ref_patch_lists: list = []
301
+ for pil in ref_pils:
302
+ pil_r = resize_pilimage(pil, max_size, PATCH_SIZE)
303
+ ref_pils_resized.append(pil_r)
304
+ ref_patch_lists.append(patchify_ref_image(pil_r))
305
+
306
+ ref_image_lens = [arr.shape[0] for arr in ref_patch_lists]
307
+ total_ref_len = sum(ref_image_lens)
308
+ ref_patches = np.concatenate(ref_patch_lists, axis=0)[None] # [1, sum(N), 3*32*32]
309
+
310
+ tgt_image_len = (height // PATCH_SIZE) * (width // PATCH_SIZE)
311
+
312
+ if K <= 4:
313
+ cond_img_size = CONDITION_IMAGE_SIZE
314
+ elif K <= 8:
315
+ cond_img_size = CONDITION_IMAGE_SIZE * 48 // 64
316
+ else:
317
+ cond_img_size = CONDITION_IMAGE_SIZE // 2
318
+
319
+ ref_pils_vlm = []
320
+ for pil_r in ref_pils_resized:
321
+ cw, ch = calculate_dimensions(cond_img_size, pil_r.width / pil_r.height)
322
+ ref_pils_vlm.append(pil_r.resize((cw, ch), resample=Image.LANCZOS))
323
+
324
+ image_grid_thw_tgt = np.asarray([[1, height // PATCH_SIZE, width // PATCH_SIZE]], dtype=np.int64)
325
+ image_grid_thw_ref = np.zeros((K, 3), dtype=np.int64)
326
+ for i, pil_r in enumerate(ref_pils_resized):
327
+ rw, rh = pil_r.size
328
+ image_grid_thw_ref[i] = [1, rh // PATCH_SIZE, rw // PATCH_SIZE]
329
+
330
+ boi_token = getattr(tokenizer, "boi_token", "<|boi_token|>")
331
+ tms_token = getattr(tokenizer, "tms_token", "<|tms_token|>")
332
+
333
+ content = [{"type": "image"} for _ in range(K)]
334
+ content.append({"type": "text", "text": prompt})
335
+ messages = [{"role": "user", "content": content}]
336
+ template_caption = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
337
+ proc = processor(text=[template_caption], images=ref_pils_vlm, padding="longest", return_tensors="pt")
338
+
339
+ input_ids_2 = np.asarray(
340
+ tokenizer.encode(boi_token + tms_token * TIMESTEP_TOKEN_NUM, add_special_tokens=False),
341
+ dtype=np.int64,
342
+ ).reshape(1, -1)
343
+ proc_input_ids = np.asarray(proc.input_ids, dtype=np.int64)
344
+ input_ids = np.concatenate([proc_input_ids, input_ids_2], axis=-1)
345
+
346
+ igthw_cond = np.asarray(proc.image_grid_thw, dtype=np.int64).copy()
347
+ for i in range(K):
348
+ igthw_cond[i, 1] //= spatial_merge_size
349
+ igthw_cond[i, 2] //= spatial_merge_size
350
+ igthw_all = np.concatenate([igthw_cond, image_grid_thw_tgt, image_grid_thw_ref], axis=0)
351
+
352
+ # Build the per-image vision-token spans appended after the text:
353
+ # tgt span (tgt_image_len tokens, first slot is vision_start)
354
+ # then for each ref: span of ref_image_lens[i] tokens, first slot vision_start
355
+ vt_pieces = []
356
+ vt_tgt = np.full((1, tgt_image_len), image_token_id, dtype=input_ids.dtype)
357
+ vt_tgt[0, 0] = vision_start_token_id
358
+ vt_pieces.append(vt_tgt)
359
+ for rl in ref_image_lens:
360
+ vt_ref = np.full((1, rl), image_token_id, dtype=input_ids.dtype)
361
+ vt_ref[0, 0] = vision_start_token_id
362
+ vt_pieces.append(vt_ref)
363
+ vision_tokens = np.concatenate(vt_pieces, axis=1)
364
+ input_ids_pad = np.concatenate([input_ids, vision_tokens], axis=-1)
365
+
366
+ position_ids, _ = get_rope_index_fix_point(
367
+ spatial_merge_size=1,
368
+ image_token_id=image_token_id,
369
+ video_token_id=video_token_id,
370
+ vision_start_token_id=vision_start_token_id,
371
+ input_ids=input_ids_pad,
372
+ image_grid_thw=igthw_all,
373
+ video_grid_thw=None,
374
+ attention_mask=None,
375
+ skip_vision_start_token=[0] * K + [1] + [1] * K,
376
+ )
377
+
378
+ txt_seq_len = input_ids.shape[-1]
379
+ all_seq_len = position_ids.shape[-1]
380
+
381
+ token_types_raw = np.zeros((1, all_seq_len), dtype=np.int64)
382
+ bgn = txt_seq_len - TIMESTEP_TOKEN_NUM
383
+ end = bgn + tgt_image_len + TIMESTEP_TOKEN_NUM
384
+ token_types_raw[0, bgn:end] = 1 # tgt span (and tms inside it)
385
+ token_types_raw[0, end: end + total_ref_len] = 2 # ref spans
386
+ token_types_raw[0, txt_seq_len - TIMESTEP_TOKEN_NUM: txt_seq_len] = 3 # tms
387
+
388
+ vinput_mask = np.logical_or(token_types_raw == 1, token_types_raw == 2)
389
+ vinput_mask_tgt_only = (token_types_raw == 1) # excludes tms (=3) and refs (=2)
390
+ token_types_bin = (token_types_raw > 0).astype(np.int64)
391
+
392
+ # Pixel values from the processor are pre-flattened patches of vision-tower size.
393
+ # Shape (after np conversion) is [num_vision_patches, vision_patch_dim].
394
+ pixel_values_np = np.asarray(proc.pixel_values, dtype=np.float32)
395
+ image_grid_thw_for_visual = np.asarray(proc.image_grid_thw, dtype=np.int64)
396
+
397
+ return {
398
+ "input_ids": input_ids,
399
+ "position_ids": position_ids,
400
+ "token_types": token_types_bin,
401
+ "vinput_mask": vinput_mask,
402
+ "vinput_mask_tgt_only": vinput_mask_tgt_only,
403
+ "pixel_values": pixel_values_np,
404
+ "image_grid_thw": image_grid_thw_for_visual,
405
+ "ref_patches": ref_patches,
406
+ "tgt_image_len": tgt_image_len,
407
+ }
408
+
409
+
410
+ def build_attention_mask(token_types_bin: np.ndarray, dtype_min: float) -> np.ndarray:
411
+ """text rows causal, gen rows bidirectional. Returns [B, 1, S, S] additive."""
412
+ B, S = token_types_bin.shape
413
+ mask = np.full((B, 1, S, S), dtype_min, dtype=np.float32)
414
+ causal_2d = np.triu(np.full((S, S), dtype_min, dtype=np.float32), k=1)
415
+ for b in range(B):
416
+ m = causal_2d.copy()
417
+ gen = token_types_bin[b].astype(bool)
418
+ m[gen, :] = 0.0
419
+ mask[b, 0] = m
420
+ return mask
scripts/hidream_o1/postprocess.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Standalone post-process: take an existing HiDream PNG and smooth the
3
+ 32-pixel patch grid seams. NO model load — just numpy + PIL.
4
+
5
+ Usage:
6
+ postprocess.py <input.png> <output.png> [--radius N] [--strength F]
7
+
8
+ Strategy:
9
+ For each seam line (x, y multiples of PATCH_SIZE), apply a 1D gaussian
10
+ blur perpendicular to the seam, blended with the original by --strength.
11
+ The blur kernel is symmetric, so flat regions get more smoothing than
12
+ sharp edges (which the gaussian's centre weight preserves).
13
+
14
+ --radius blur radius in pixels (default 3)
15
+ --strength blend weight 0-1 (default 0.7 = 70% blurred + 30% original)
16
+ """
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import sys
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+ from PIL import Image
25
+
26
+ PATCH_SIZE = 32
27
+
28
+
29
+ def gaussian_kernel_1d(radius: int) -> np.ndarray:
30
+ """Build a normalised 1D gaussian kernel with sigma=radius/2."""
31
+ sigma = radius / 2.0
32
+ x = np.arange(-radius, radius + 1, dtype=np.float32)
33
+ k = np.exp(-0.5 * (x / sigma) ** 2)
34
+ return k / k.sum()
35
+
36
+
37
+ def smooth_seams(rgb: np.ndarray, radius: int = 3, strength: float = 0.7) -> np.ndarray:
38
+ """Smooth horizontal+vertical patch seams via local gaussian blur.
39
+
40
+ The blur is applied to the SEAM rows/cols only, then alpha-blended back
41
+ by `strength`. Non-seam pixels are untouched.
42
+ """
43
+ out = rgb.astype(np.float32).copy()
44
+ H, W, C = rgb.shape
45
+ kernel = gaussian_kernel_1d(radius) # length 2*radius+1
46
+
47
+ # --- Horizontal seams (rows at y in {patch, 2*patch, ...}) ---
48
+ # We smooth the 2 rows on each side of each seam (4 rows total per seam).
49
+ for y in range(PATCH_SIZE, H, PATCH_SIZE):
50
+ for offset in (-2, -1, 0, 1):
51
+ yy = y + offset
52
+ if not (0 <= yy < H):
53
+ continue
54
+ lo = max(0, yy - radius)
55
+ hi = min(H, yy + radius + 1)
56
+ k_lo = radius - (yy - lo)
57
+ k_hi = radius + (hi - yy)
58
+ k = kernel[k_lo:k_hi]
59
+ k = k / k.sum()
60
+ band = out[lo:hi] # [n, W, C]
61
+ blurred = (band * k[:, None, None]).sum(axis=0)
62
+ out[yy] = (1 - strength) * out[yy] + strength * blurred
63
+
64
+ # --- Vertical seams (cols at x in {patch, 2*patch, ...}) ---
65
+ for x in range(PATCH_SIZE, W, PATCH_SIZE):
66
+ for offset in (-2, -1, 0, 1):
67
+ xx = x + offset
68
+ if not (0 <= xx < W):
69
+ continue
70
+ lo = max(0, xx - radius)
71
+ hi = min(W, xx + radius + 1)
72
+ k_lo = radius - (xx - lo)
73
+ k_hi = radius + (hi - xx)
74
+ k = kernel[k_lo:k_hi]
75
+ k = k / k.sum()
76
+ band = out[:, lo:hi] # [H, n, C]
77
+ blurred = (band * k[None, :, None]).sum(axis=1)
78
+ out[:, xx] = (1 - strength) * out[:, xx] + strength * blurred
79
+
80
+ return np.clip(out, 0, 255).astype(np.uint8)
81
+
82
+
83
+ def main():
84
+ ap = argparse.ArgumentParser()
85
+ ap.add_argument("input")
86
+ ap.add_argument("output")
87
+ ap.add_argument("--radius", type=int, default=3)
88
+ ap.add_argument("--strength", type=float, default=0.7)
89
+ args = ap.parse_args()
90
+
91
+ inp = Path(args.input)
92
+ if not inp.exists():
93
+ sys.exit(f"input not found: {inp}")
94
+
95
+ rgb = np.array(Image.open(inp).convert("RGB"))
96
+ H, W = rgb.shape[:2]
97
+ print(f"{inp.name}: {W}x{H}, {(W // PATCH_SIZE) - 1} vertical + {(H // PATCH_SIZE) - 1} horizontal seams")
98
+ print(f"smoothing with radius={args.radius}, strength={args.strength}...")
99
+
100
+ out = smooth_seams(rgb, radius=args.radius, strength=args.strength)
101
+ Image.fromarray(out).save(args.output)
102
+ print(f"saved -> {args.output}")
103
+
104
+
105
+ if __name__ == "__main__":
106
+ main()
scripts/hidream_o1/realism_batch.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Anti-AI realism batch — film stocks, documentary photographers, natural light,
3
+ # skin texture cues. BF16 weights (no quantization).
4
+ set -euo pipefail
5
+
6
+ LAB="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
7
+ PY="$LAB/.venv/bin/python"
8
+ MODEL="$LAB/mlx_models/hidream-o1-dev-bf16"
9
+ OUT="$LAB/sample_outputs/showcase_realism"
10
+ mkdir -p "$OUT"
11
+
12
+ run() {
13
+ local name="$1" w="$2" h="$3" prompt="$4" seed="${5:-42}"
14
+ echo "=== $name ${w}x${h} (seed=$seed) ==="
15
+ cd "$LAB" && "$PY" scripts/hidream_o1/generate_hidream_o1_mlx.py \
16
+ --model-path "$MODEL" \
17
+ --prompt "$prompt" \
18
+ --width "$w" --height "$h" \
19
+ --output "$OUT/$name.png" \
20
+ --seed "$seed" 2>&1 | grep -E "loaded|using|generation:|saved" | tail -3
21
+ echo ""
22
+ }
23
+
24
+ # Verticals (1440x2560)
25
+ run "01_barista_morning" 1440 2560 \
26
+ "candid documentary photograph of a tired thirties barista pulling an espresso shot at 6am, natural skin with visible pores and faint under-eye shadows, slight grease on her apron, hair half-loose from her bun, warm overhead pendant lamp lighting only, shot on Kodak Portra 400 film with visible grain, William Eggleston color palette, no retouching, real and lived-in" \
27
+ 111
28
+
29
+ run "02_workshop_oldman" 1440 2560 \
30
+ "documentary portrait of a weathered seventy year old man in a cluttered bike workshop, fixing a vintage racing bicycle, deeply lined hands streaked with grease, faded blue work shirt, reading glasses low on his nose, a half-drunk mug of coffee on the bench, natural overcast daylight from the open garage door, shot on Kodak Vision3 250D 35mm cinema film, Mary Ellen Mark aesthetic" \
31
+ 222
32
+
33
+ run "03_kitchen_morning" 1440 2560 \
34
+ "candid morning photo of a woman in her late thirties at a wooden kitchen table holding a chipped ceramic coffee mug with both hands, no makeup, hair messy from sleep, freckles and faint laugh lines visible, wearing an oversized grey sweater, soft diffuse light from a north-facing window beside her, slight steam rising from the mug, half-eaten toast on a plate, lived-in apartment in soft focus, Saul Leiter colour mood, Cinestill 800T film grain" \
35
+ 333
36
+
37
+ # Wides (3104x1312)
38
+ run "04_bar_friends" 3104 1312 \
39
+ "ultrawide naturalistic photo of two male friends in their forties slumped in a worn leather booth at a dim Brooklyn dive bar around 1am, half-finished beers and a bowl of stale peanuts on the table between them, one mid-laugh wearing a faded Carhartt jacket, the other listening with a tired smile in a wrinkled flannel, neither looking at the camera, single tungsten bulb above the booth as the only light source, shot on Cinestill 800T film with halation around the bulb, Wim Wenders mood, deep shadows, no airbrushing" \
40
+ 444
41
+
42
+ run "05_construction_lunch" 3104 1312 \
43
+ "ultrawide documentary photo of three construction workers sitting on an unfinished concrete floor of a high-rise during their lunch break, sunburnt necks, dust on their boots and arms, eating from foil-wrapped sandwiches and thermos cups, the city skyline visible through the open building structure behind them, harsh midday sun casting hard shadows, Sebastião Salgado documentary aesthetic, shot on a Leica with Kodak Tri-X black and white film, raw and dignified" \
44
+ 555
45
+
46
+ run "06_painter_studio" 3104 1312 \
47
+ "ultrawide editorial photo of a female painter in her fifties standing in her cluttered Brooklyn warehouse studio, paint smeared on her overalls and forearms, holding a long brush in her right hand and a rag in her left, looking off-frame in thought, a half-finished large abstract canvas leaning behind her, north-facing factory windows providing cool diffuse light, wooden floor stained with decades of dropped pigment, Annie Leibovitz Vanity Fair aesthetic, shot on Hasselblad medium format with natural skin tone retention" \
48
+ 666
49
+
50
+ echo "=== batch complete ==="
51
+ ls -la "$OUT"
scripts/hidream_o1/showcase_batch.sh ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Showcase battery: diverse prompts to characterise HiDream-O1-Image-Dev Q8.
3
+ # Sequential. Each generates a single 1024x1024 PNG.
4
+ set -euo pipefail
5
+
6
+ LAB="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
7
+ PY="$LAB/.venv/bin/python"
8
+ # Quant comes from $1 (default q6 — sweet spot). Pass q8 to use Q8 for safety-margin runs.
9
+ QUANT="${1:-q6}"
10
+ MODEL="$LAB/mlx_models/hidream-o1-dev-${QUANT}"
11
+ OUT="$LAB/sample_outputs/showcase_${QUANT}"
12
+ mkdir -p "$OUT"
13
+ echo "showcase quant=${QUANT}, model=${MODEL}, out=${OUT}"
14
+
15
+ run() {
16
+ local name="$1" prompt="$2" seed="${3:-42}"
17
+ echo "=== $name (seed=$seed) ==="
18
+ cd "$LAB" && /usr/bin/time -l "$PY" scripts/hidream_o1/generate_hidream_o1_mlx.py \
19
+ --model-path "$MODEL" \
20
+ --prompt "$prompt" \
21
+ --width 1024 --height 1024 \
22
+ --output "$OUT/$name.png" \
23
+ --seed "$seed" 2>&1 | grep -E "loaded|using|generation:|saved|maximum resident" | tail -6
24
+ echo ""
25
+ }
26
+
27
+ run "01_portrait_photo" "studio photo of an elderly Japanese tea master with a wise smile, holding a ceramic teacup, gentle natural light, shallow depth of field, sharp focus on eyes, 85mm lens" 8
28
+ run "02_anime" "anime girl with pink hair sitting on the rooftop of a Tokyo skyscraper at dusk, neon city lights below, cherry blossom petals floating, soft watercolor style" 19
29
+ run "03_macro_nature" "extreme macro photo of a single dewdrop on a spiderweb at dawn, tiny rainbow refractions, blurred leaf background, ultra sharp focus" 31
30
+ run "04_architecture" "interior of a futuristic library, towering bookshelves, holographic displays, warm golden light streaming through stained glass windows, wide angle" 5
31
+ run "05_surreal" "a giant blue whale floating in the clouds above a vast desert landscape, magical realism, oil painting style, golden hour" 27
32
+ run "06_food_flatlay" "overhead flat lay of a rustic italian breakfast, golden croissants, espresso cup, fresh berries, marble surface, soft morning light, food photography" 53
33
+ run "07_action_cinematic" "samurai warrior mid leap with katana drawn, cherry blossoms swirling around him, mountain backdrop at sunset, dynamic action, cinematic film still" 71
34
+ run "08_fantasy_creature" "majestic dragon perched on a crystal mountain peak, iridescent scales reflecting aurora borealis, snow swirling around, dramatic dramatic lighting, fantasy art" 88
35
+ run "09_wildlife" "close-up portrait of a snow leopard staring directly at the camera, falling snow flakes, mountain background, national geographic style, ultra sharp" 17
36
+ run "10_text_render" "vintage diner neon sign reading BLOOM CAFE in glowing pink letters at night, retro americana 1950s style, rainy street reflection, cinematic" 64
37
+
38
+ echo "=== showcase complete ==="
39
+ ls -la "$OUT"