asdf98 commited on
Commit
ce0086e
Β·
verified Β·
1 Parent(s): cb1106e

Update README with Colab notebook link

Browse files
Files changed (1) hide show
  1. README.md +75 -153
README.md CHANGED
@@ -24,6 +24,20 @@ pipeline_tag: text-to-image
24
  <img src="https://img.shields.io/badge/License-Apache%202.0-orange" alt="license">
25
  </p>
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  ## 🎯 Why IRIS?
28
 
29
  Current image generation models face critical limitations:
@@ -68,67 +82,43 @@ Image ──→ HaarDWT ──→ WaveletVAE ──→ zβ‚€ [CΓ—H/16Γ—W/16]
68
  ### πŸ”¬ Key Innovations
69
 
70
  #### 1. GRFM (Gated Recurrent Fourier Mixer) β€” Novel Token Mixing
71
- A novel token mixing mechanism that fuses three complementary pathways:
72
 
73
  - **Fourier Global Pathway** (O(N log N)): `RFFT2 β†’ Block-diagonal MLP β†’ SoftShrink β†’ IRFFT2`
74
- - Captures global textures and patterns via frequency-domain processing
75
- - Soft-shrinkage enforces sparsity (images are sparse in frequency domain)
76
-
77
- - **Gated Linear Recurrence** (O(N)): Bidirectional RG-LRU scan
78
- - `h_t = a_t βŠ™ h_{t-1} + √(1 - a_tΒ²) βŠ™ (i_t βŠ™ x_t)`
79
- - Captures sequential dependencies with O(1) state per position
80
-
81
- - **Manhattan Spatial Gate**: Per-head learnable spatial decay
82
- - `D_{nm} = Ξ³_head^(|x_n-x_m| + |y_n-y_m|)`
83
- - Provides 2D inductive bias with multi-scale receptive fields
84
-
85
- The three pathways are merged via **learned adaptive gating**:
86
  ```
87
  output = gate Γ— x_fourier + (1 - gate) Γ— x_recurrent + Ξ± Γ— x_spatial
88
  ```
89
 
90
  #### 2. Recurrent Depth Core (Huginn paradigm, novel for images)
91
- - The core denoising block uses **shared weights** across all iterations
92
- - A 4-layer core block iterated 8Γ— = 32 effective layers from just 4 layers of parameters
93
- - **Budget-adaptive inference**: 4 iterations for mobile speed, 16 for maximum quality
94
- - Iteration-aware conditioning via adaLN: the model learns different behavior at each depth
95
 
96
  #### 3. Wavelet-Frequency Latent Space
97
- - Haar DWT preprocesses images before VAE encoding (lossless, invertible)
98
- - Latent space preserves frequency structure (LL=structure, LH/HL/HH=details)
99
- - 16Γ— total spatial compression with wavelet transform
100
 
101
  #### 4. Dual-Axis Recurrence (Novel)
102
- - Recurrence over **noise schedule** (diffusion steps, outer loop)
103
- - Recurrence over **computational depth** (core iterations, inner loop)
104
- - New paradigm: both axes share the same network, with different conditioning
105
 
106
  ## πŸ“Š Model Variants
107
 
108
- | Variant | Generator Params | Total System | Memory (fp16) | Mobile Fit |
109
- |---------|-----------------|-------------|---------------|------------|
110
- | **IRIS-Tiny** | 19M | ~60M | 545 MB | βœ… Ultra-mobile |
111
- | **IRIS-Small** | 47M | ~88M | 597 MB | βœ… Mobile |
112
- | **IRIS-Base** | 135M | ~175M | 760 MB | βœ… Consumer GPU |
113
-
114
- ### Effective Capacity via Recurrent Depth
115
-
116
- | Model | Unique Params | r=4 iterations | r=8 | r=12 | r=16 |
117
- |-------|--------------|----------------|-----|------|------|
118
- | IRIS-Small (48M) | 48M | ~143M effective | ~270M effective | ~397M effective | ~524M effective |
119
-
120
- **48M parameters behave like 270-524M** depending on iteration budget!
121
 
122
  ## πŸ”§ Quick Start
123
 
124
  ```python
125
  from iris_model import create_iris_small
 
126
 
127
- # Create model
128
  model = create_iris_small()
129
-
130
- # Generate with text conditioning
131
- import torch
132
  text_tokens = torch.randn(1, 77, 768) # Replace with CLIP-L/14 embeddings
133
 
134
  # Fast mobile inference (4 iterations, 4 steps)
@@ -136,149 +126,81 @@ images = model.generate(text_tokens, num_steps=4, num_iterations=4)
136
 
137
  # Quality inference (8 iterations, 4 steps)
138
  images = model.generate(text_tokens, num_steps=4, num_iterations=8)
139
-
140
- # Training step (rectified flow)
141
- images_input = torch.randn(1, 3, 512, 512)
142
- result = model.train_step(images_input, text_tokens)
143
- print(f"Loss: {result['loss'].item():.4f}")
144
  ```
145
 
146
  ## πŸ“ Mathematical Foundations
147
 
148
  ### Rectified Flow Training
149
  ```
150
- z_t = (1-t)Β·zβ‚€ + tΒ·Ξ΅ (linear interpolation)
151
- v_target = Ξ΅ - zβ‚€ (constant velocity field)
152
- L = w(t) Β· ||v_ΞΈ(z_t, t, c) - v_target||Β²
153
- w(t) = t/(1-t) (SNR reweighting)
154
- t ~ Logit-Normal(0, 1) (concentrate on hard timesteps)
155
  ```
156
 
157
- ### GRFM: Fourier Pathway
158
  ```
159
- x_freq = RFFT2(x, dim=(H,W)) # O(N log N) via FFT
160
- x_freq = BlockDiagMLP(x_freq) # Block-diagonal complex-valued MLP
161
- x_freq = SoftShrink(x_freq, Ξ») # Sparsity: S_Ξ»(x) = sign(x)Β·max(|x|-Ξ», 0)
162
- x_out = IRFFT2(x_freq) # Back to spatial domain
163
- ```
164
-
165
- ### GRFM: RG-LRU Gated Recurrence Pathway
166
- ```
167
- a_t = Οƒ(Ξ›)^(cΒ·Οƒ(W_aΒ·x_t)) # Data-dependent decay (c=8)
168
- i_t = Οƒ(W_xΒ·x_t) # Input gate
169
- h_t = a_t βŠ™ h_{t-1} + √(1-a_tΒ²) βŠ™ (i_t βŠ™ x_t) # Variance-preserving recurrence
170
- ```
171
-
172
- ### GRFM: Manhattan Spatial Decay Pathway
173
- ```
174
- D_{nm} = Ξ³_head^(|row_n - row_m| + |col_n - col_m|) # Manhattan distance matrix
175
- γ_head ∈ (0, 1), learned per attention head # Multi-scale receptive fields
176
  ```
177
 
178
  ## πŸ‹οΈ Training Recipe
179
 
180
- ### 5-Stage Pipeline
181
-
182
- | Stage | Data | Objective | Est. Cost |
183
- |-------|------|-----------|-----------|
184
- | 1. VAE | ImageNet + CC3M | Reconstruction + KL + Wavelet frequency loss | 20 GPU-hrs |
185
- | 2. Class-Cond | ImageNet 256px | Rectified Flow velocity matching | 100 GPU-hrs |
186
- | 3. Text-Image | CC3M/CC12M (VLM-recaptioned) | RF + cross-attention on CLIP text | 200 GPU-hrs |
187
- | 4. Aesthetic | JourneyDB + curated LAION | Fine-tune with high-aesthetic data | 50 GPU-hrs |
188
- | 5. Distill | Self-distillation | Consistency distillation β†’ 1-4 steps | 30 GPU-hrs |
189
-
190
- **Total: ~400 A100 GPU-hours (~$1,600)**
191
 
192
- ### Key Training Tricks (sourced from literature)
193
- - **Logit-normal timestep sampling** (SD3): focuses compute on hard intermediate timesteps
194
- - **adaLN-Zero initialization**: zero-init output gates for stable residual learning start
195
- - **Random iteration sampling**: during training, randomly sample r ∈ {4,6,8,10,12} for robustness
196
- - **Long skip connections** (Diffusion-RWKV): connect shallow features to output for gradient flow
197
- - **QK-normalization** (SANA-Sprint): prevents attention collapse at scale
198
- - **3-stage training decomposition** (PixArt-Ξ±): pixel priors β†’ text alignment β†’ aesthetics
199
-
200
- ## πŸ”„ Extensions for Image Editing
201
-
202
- The iterative core naturally supports editing tasks:
203
-
204
- - **Inpainting**: Mask latent tokens, condition core iterations on unmasked context
205
- - **Super-Resolution**: Encode low-res via WaveletVAE, condition generation on LL subband
206
- - **Prompt-based Editing**: SDEdit-style partial denoising with modified text conditioning
207
- - **ControlNet**: Lightweight adapter in Prelude for spatial control signals (edges, depth, pose)
208
-
209
- ### Adaptive Quality β€” Same Model, Different Budgets
210
- ```python
211
- # 🏎️ Ultra-fast mobile (4 core iterations Γ— 1 step = 4 total NFE)
212
- images = model.generate(text, num_steps=1, num_iterations=4)
213
-
214
- # πŸ“± Balanced mobile (4 iterations Γ— 4 steps = 16 NFE)
215
- images = model.generate(text, num_steps=4, num_iterations=4)
216
-
217
- # πŸ–₯️ Quality desktop (8 iterations Γ— 4 steps = 32 NFE)
218
- images = model.generate(text, num_steps=4, num_iterations=8)
219
-
220
- # 🎨 Maximum quality (16 iterations Γ— 8 steps = 128 NFE)
221
- images = model.generate(text, num_steps=8, num_iterations=16)
222
- ```
223
 
224
  ## πŸ“š Research Foundations
225
 
226
- IRIS draws inspiration from and synthesizes ideas across multiple domains:
227
-
228
- | Concept | Source Paper | How IRIS Uses It |
229
- |---------|-------------|-----------------|
230
- | Recurrent Depth | Huginn (2502.05171) | Prelude-Core-Coda shared-weight architecture |
231
- | Fourier Mixing | AFNO (2111.13587) | Block-diagonal FFT pathway in GRFM |
232
- | Gated Recurrence | Griffin RG-LRU (2402.19427) | Bidirectional scan pathway in GRFM |
233
- | Manhattan Decay | RMT (2309.11523) | Spatial inductive bias pathway in GRFM |
234
- | Wavelet Diffusion | WaveDiff (2211.16152) | Haar DWT preprocessing + frequency-aware latent |
235
- | Rectified Flow | RF (2209.03003), SD3 (2403.03206) | Straight ODE trajectories, logit-normal sampling |
236
- | Consistency Models | CM (2303.01469) | 1-4 step generation via self-consistency |
237
- | adaLN-Zero | DiT (2212.09748) | Stable conditioning via zero-initialized gates |
238
- | Efficient Training | PixArt-Ξ± (2310.00426) | 3-stage training decomposition, adaLN-single |
239
- | Mobile Diffusion | SnapGen (2412.09619) | Depthwise separable convolutions, tiny VAE decoder |
240
- | Bidirectional scan | Diffusion-RWKV (2404.04478) | Long skip connections, multi-direction scanning |
241
- | State Space Vision | VSSD (2407.18559) | Non-causal state-space design inspiration |
242
- | Mamba SSM | Mamba-2/SSD (2405.21060) | Selective state-space duality principles |
243
- | Extended LSTM | xLSTM/mLSTM (2405.04517) | Matrix memory concept for spatial features |
244
- | Frequency diffusion | DCTdiff (2412.15032) | Perceptual alignment via frequency-domain generation |
245
-
246
- ## πŸ“„ Files in this Repository
247
 
248
  | File | Description |
249
  |------|-------------|
250
- | `iris_model.py` | Complete architecture implementation (~1200 lines) |
251
- | `train_iris.py` | Full training pipeline (all 5 stages) |
252
- | `test_iris.py` | Comprehensive validation test suite (9 tests) |
253
- | `ARCHITECTURE.md` | Detailed architecture specification with math |
 
254
 
255
  ## βœ… Verified Properties
256
 
257
- All verified via automated test suite:
258
-
259
- - βœ… Haar DWT/IDWT roundtrip is lossless (error < 1e-5)
260
- - βœ… WaveletVAE encodes 256Γ—256β†’16Γ—16 latent (48Γ— compression)
261
- - βœ… GRFM forward/backward pass correct, all gradients flow
262
- - βœ… Generator handles variable iteration counts (2, 4, 8)
263
- - βœ… Full training step produces valid loss with gradients
264
- - βœ… End-to-end generation pipeline produces correctly-shaped output
265
- - βœ… Different iteration counts produce different outputs (adaptive compute)
266
- - βœ… IRIS-Tiny fits in 545 MB total inference memory (< 3GB βœ…)
267
- - βœ… IRIS-Small fits in 597 MB total inference memory (< 3GB βœ…)
268
- - βœ… 16Γ— iteration gives 10.9Γ— effective capacity from same params
269
 
270
  ## πŸ“œ License
271
 
272
- Apache 2.0 β€” Free for both research and commercial use.
273
-
274
- ## Citation
275
 
276
  ```bibtex
277
  @misc{iris2026,
278
  title={IRIS: Iterative Recurrent Image Synthesis for Mobile-First Image Generation},
279
  year={2026},
280
- note={Novel architecture combining Gated Recurrent Fourier Mixing,
281
- Recurrent Depth, and Wavelet-Frequency Latent Space for efficient
282
- text-to-image generation under 3GB RAM}
283
  }
284
  ```
 
24
  <img src="https://img.shields.io/badge/License-Apache%202.0-orange" alt="license">
25
  </p>
26
 
27
+ ## πŸš€ Train It Now!
28
+
29
+ **[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/)** ← Download `IRIS_Training_Notebook.ipynb` from this repo and upload to Colab!
30
+
31
+ **Quick start**: Download [`IRIS_Training_Notebook.ipynb`](./IRIS_Training_Notebook.ipynb), open it in Colab (or Kaggle), enable GPU, and run all cells. Trains end-to-end in ~2-3 hours on a free T4.
32
+
33
+ The notebook includes:
34
+ - πŸ“¦ Auto-downloads architecture code from this repo
35
+ - 🎨 Trains on Pokémon BLIP Captions dataset (833 image-caption pairs)
36
+ - πŸ”¬ Stage 1: Wavelet VAE training with frequency-aware loss
37
+ - ⚑ Stage 2: Rectified Flow generator training with CLIP conditioning
38
+ - πŸ“Š Visualizations: reconstructions, generated samples, loss curves, GRFM internals
39
+ - πŸ’Ύ Checkpoint saving for continued training
40
+
41
  ## 🎯 Why IRIS?
42
 
43
  Current image generation models face critical limitations:
 
82
  ### πŸ”¬ Key Innovations
83
 
84
  #### 1. GRFM (Gated Recurrent Fourier Mixer) β€” Novel Token Mixing
85
+ Three complementary pathways fused via learned adaptive gating:
86
 
87
  - **Fourier Global Pathway** (O(N log N)): `RFFT2 β†’ Block-diagonal MLP β†’ SoftShrink β†’ IRFFT2`
88
+ - **Gated Linear Recurrence** (O(N)): Bidirectional RG-LRU scan with variance-preserving updates
89
+ - **Manhattan Spatial Gate**: Per-head learnable spatial decay `D_{nm} = Ξ³^Manhattan(n,m)`
90
+
 
 
 
 
 
 
 
 
 
91
  ```
92
  output = gate Γ— x_fourier + (1 - gate) Γ— x_recurrent + Ξ± Γ— x_spatial
93
  ```
94
 
95
  #### 2. Recurrent Depth Core (Huginn paradigm, novel for images)
96
+ - Shared-weight core block iterated 4-16Γ— (same model, adaptive quality!)
97
+ - 4-layer block Γ— 8 iterations = 32 effective layers from just 4 layers of params
98
+ - **48M unique params β†’ 270-524M effective capacity**
 
99
 
100
  #### 3. Wavelet-Frequency Latent Space
101
+ - Haar DWT preprocessing preserves frequency structure in latent space
102
+ - 16Γ— total spatial compression (lossless wavelet + learned VAE)
 
103
 
104
  #### 4. Dual-Axis Recurrence (Novel)
105
+ - Recurrence over noise schedule (diffusion) AND computational depth (core iterations)
 
 
106
 
107
  ## πŸ“Š Model Variants
108
 
109
+ | Variant | Generator Params | Total Memory (fp16) | Mobile Fit |
110
+ |---------|-----------------|---------------------|------------|
111
+ | **IRIS-Tiny** | 19M | 545 MB | βœ… Ultra-mobile |
112
+ | **IRIS-Small** | 47M | 597 MB | βœ… Mobile |
113
+ | **IRIS-Base** | 135M | 760 MB | βœ… Consumer GPU |
 
 
 
 
 
 
 
 
114
 
115
  ## πŸ”§ Quick Start
116
 
117
  ```python
118
  from iris_model import create_iris_small
119
+ import torch
120
 
 
121
  model = create_iris_small()
 
 
 
122
  text_tokens = torch.randn(1, 77, 768) # Replace with CLIP-L/14 embeddings
123
 
124
  # Fast mobile inference (4 iterations, 4 steps)
 
126
 
127
  # Quality inference (8 iterations, 4 steps)
128
  images = model.generate(text_tokens, num_steps=4, num_iterations=8)
 
 
 
 
 
129
  ```
130
 
131
  ## πŸ“ Mathematical Foundations
132
 
133
  ### Rectified Flow Training
134
  ```
135
+ z_t = (1-t)Β·zβ‚€ + tΒ·Ξ΅, v_target = Ξ΅ - zβ‚€
136
+ L = w(t) Β· ||v_ΞΈ(z_t, t, c) - v_target||Β², w(t) = t/(1-t)
137
+ t ~ Logit-Normal(0, 1)
 
 
138
  ```
139
 
140
+ ### GRFM Pathways
141
  ```
142
+ Fourier: RFFT2 β†’ BlockDiagMLP β†’ SoftShrink(Ξ») β†’ IRFFT2 [O(N log N)]
143
+ Recurrence: h_t = a_tβŠ™h_{t-1} + √(1-a_tΒ²)βŠ™(i_tβŠ™x_t) [O(N)]
144
+ Spatial: D_{nm} = Ξ³^(|row_n-row_m| + |col_n-col_m|) [O(NΓ—window)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  ```
146
 
147
  ## πŸ‹οΈ Training Recipe
148
 
149
+ | Stage | Data | Est. Cost |
150
+ |-------|------|-----------|
151
+ | 1. VAE | ImageNet + CC3M | 20 GPU-hrs |
152
+ | 2. Class-Cond | ImageNet 256px | 100 GPU-hrs |
153
+ | 3. Text-Image | CC3M/CC12M | 200 GPU-hrs |
154
+ | 4. Aesthetic | JourneyDB | 50 GPU-hrs |
155
+ | 5. Distill | Self-distill | 30 GPU-hrs |
 
 
 
 
156
 
157
+ **Total: ~400 A100 GPU-hours (~$1,600)** | Stages 1-2 run on free Colab T4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  ## πŸ“š Research Foundations
160
 
161
+ | Concept | Source | How Used |
162
+ |---------|--------|----------|
163
+ | Recurrent Depth | Huginn (2502.05171) | Prelude-Core-Coda |
164
+ | Fourier Mixing | AFNO (2111.13587) | GRFM pathway |
165
+ | Gated Recurrence | Griffin RG-LRU (2402.19427) | GRFM pathway |
166
+ | Manhattan Decay | RMT (2309.11523) | GRFM pathway |
167
+ | Wavelet Diffusion | WaveDiff (2211.16152) | Latent space |
168
+ | Rectified Flow | RF (2209.03003), SD3 | Training objective |
169
+ | Consistency Models | CM (2303.01469) | Distillation |
170
+ | adaLN-Zero | DiT (2212.09748) | Conditioning |
171
+ | Efficient Training | PixArt-Ξ± (2310.00426) | Training recipe |
172
+ | Mobile Design | SnapGen (2412.09619) | DWSConv, tiny VAE |
173
+
174
+ ## πŸ“„ Files
 
 
 
 
 
 
 
175
 
176
  | File | Description |
177
  |------|-------------|
178
+ | **`IRIS_Training_Notebook.ipynb`** | πŸ”₯ **Complete Colab/Kaggle training notebook** |
179
+ | `iris_model.py` | Architecture implementation (~1200 lines) |
180
+ | `train_iris.py` | CLI training pipeline (all 5 stages) |
181
+ | `test_iris.py` | Validation test suite (9 tests, all passing) |
182
+ | `ARCHITECTURE.md` | Detailed math specification |
183
 
184
  ## βœ… Verified Properties
185
 
186
+ - βœ… Haar DWT/IDWT roundtrip lossless (error < 1e-5)
187
+ - βœ… WaveletVAE: 256Γ—256β†’16Γ—16 latent (48Γ— compression)
188
+ - βœ… GRFM forward/backward correct, all gradients flow
189
+ - βœ… Variable iteration counts work (adaptive compute)
190
+ - βœ… Full training step with rectified flow loss
191
+ - βœ… End-to-end generation pipeline
192
+ - βœ… IRIS-Tiny: **545 MB** total inference (< 3GB βœ…)
193
+ - βœ… IRIS-Small: **597 MB** total inference (< 3GB βœ…)
194
+ - βœ… 16Γ— iteration gives **10.9Γ—** effective capacity
 
 
 
195
 
196
  ## πŸ“œ License
197
 
198
+ Apache 2.0
 
 
199
 
200
  ```bibtex
201
  @misc{iris2026,
202
  title={IRIS: Iterative Recurrent Image Synthesis for Mobile-First Image Generation},
203
  year={2026},
204
+ note={Novel architecture: GRFM + Recurrent Depth + Wavelet Latent Space}
 
 
205
  }
206
  ```