# IRIS: Iterative Refinement Image Synthesizer

A mobile-first image generation architecture based on recent research (2025-2026). **17/17 tests pass.**

## Colab Quick Start (Free-Tier T4, Just Run It)

**One script, real dataset, real training, ~20 minutes total.**

1. Open [Google Colab](https://colab.research.google.com)
2. Set the runtime to **T4 GPU** (Runtime → Change runtime type → T4 GPU)
3. Create a new cell and paste:

   ```python
   !wget -q https://huggingface.co/asdf98/iris-image-gen/resolve/main/colab_train_iris.py
   %run colab_train_iris.py
   ```

**What happens:**

- Installs dependencies (~30 s)
- Downloads the IRIS code, the DC-AE encoder, and the text encoder (~2 min)
- Encodes 833 Pokemon images to latents (~2 min on T4)
- Trains IRIS-Small (40M params) for 3000 steps (~15 min on T4)
- Generates sample images and plots the loss curve
- Saves a checkpoint to `./iris_checkpoints/`

**Colab free tier (2025):** T4 GPU (16 GB VRAM), ~12.7 GB system RAM, PyTorch 2.5+

### VRAM Budget (Colab T4, 16 GB)

| Phase | Component | VRAM |
|-------|-----------|------|
| Encoding | DC-AE (fp16) | ~2.4 GB |
| Encoding | Text encoder | ~0.35 GB |
| Training | IRIS-Small weights (40M, fp32) | ~0.16 GB |
| Training | Optimizer states | ~0.48 GB |
| Training | Batch + activations (BS=16, R=3, gradient checkpointing) | ~2-4 GB |
| **Peak** | **Encoding phase** | **~3 GB** |
| **Peak** | **Training phase** | **~5 GB** |

The encoders are freed before training starts, so the two phases never overlap and the ~5 GB training peak leaves ample headroom on the 16 GB card.
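The training rows of the VRAM budget can be reproduced with simple byte counting. The sketch below is a back-of-envelope check, not part of the IRIS code: it assumes fp32 weights and gradients plus an Adam-style optimizer that holds two fp32 moment buffers per parameter (the optimizer is not named in this README), and it copies the 2-4 GB activation range from the table instead of deriving it. `training_vram_gb` is a hypothetical helper.

```python
# Back-of-envelope VRAM estimate for IRIS-Small training (hypothetical helper).
# Assumptions, not confirmed by the IRIS code: fp32 weights and gradients, and an
# Adam-style optimizer that keeps two fp32 moment buffers per parameter.

BYTES_FP32 = 4
GB = 1e9  # decimal gigabytes, matching the rounding used in the table


def training_vram_gb(num_params: float, activation_gb: float) -> dict[str, float]:
    """Estimate per-component training memory in GB."""
    weights = num_params * BYTES_FP32 / GB
    grads = num_params * BYTES_FP32 / GB
    moments = 2 * num_params * BYTES_FP32 / GB  # first + second moment buffers
    return {
        "weights": weights,
        "grads_plus_optimizer": grads + moments,
        "activations": activation_gb,
        "total": weights + grads + moments + activation_gb,
    }


low = training_vram_gb(40.0e6, activation_gb=2.0)   # lower end of activation range
high = training_vram_gb(40.0e6, activation_gb=4.0)  # upper end of activation range
print(f"weights: {low['weights']:.2f} GB")                         # weights: 0.16 GB
print(f"grads + optimizer: {low['grads_plus_optimizer']:.2f} GB")  # grads + optimizer: 0.48 GB
print(f"training total: {low['total']:.2f}-{high['total']:.2f} GB")  # training total: 2.64-4.64 GB
```

Under these assumptions the ~0.48 GB "optimizer states" row equals three fp32 copies of the weights (gradients plus two moment buffers), and the 2.64-4.64 GB total is consistent with the ~5 GB training peak.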
### Configs for Different Hardware

| Hardware | Config | How to use |
|----------|--------|------------|
| Colab Free (T4 16GB) | `iris-small` | Default: just run the script |
| Colab Pro (A100 40GB) | `iris-medium` | In the script, use `get_model_config("iris-medium")` with `text_dim=384` |
| Kaggle (P100 16GB) | `iris-small` | Same as Colab free tier |
| Local RTX 3090 (24GB) | `iris-base` | Use the `iris-base` config with BS=32 |

### Dependencies (All pip-installable, no special builds)

```
torch>=2.0              # preinstalled in Colab
diffusers>=0.32.0       # for AutoencoderDC
sentence-transformers   # for text encoding (all-MiniLM-L6-v2, 87MB)
datasets                # for HF dataset loading
accelerate              # diffusers dependency
huggingface_hub         # for downloading IRIS code
```

No `flash-attn`, no `triton`, no `apex`, and no custom CUDA kernels. Pure PyTorch.

---

## Model Variants

| Config | Params | Tokens | Patch Size | Target Device | FP16 Weights |
|--------|--------|--------|------------|---------------|--------------|
| `iris-tiny` | 10.3M | 16 | 4 | Any phone | 21 MB |
| `iris-small` | 40.0M | 16 | 4 | Modern phone / Colab | 80 MB |
| `iris-base` | 53.4M | 64 | 2 | Phone/tablet | 107 MB |
| `iris-medium` | 181.2M | 64 | 2 | Desktop/cloud | 362 MB |
| `iris-large` | 430.9M | 64 | 2 | Cloud | 862 MB |

Token counts correspond to a 16×16 DC-AE latent (a 512×512 image): patch size 4 gives (16/4)² = 16 tokens and patch size 2 gives (16/2)² = 64 tokens. FP16 weight size is params × 2 bytes.

## Architecture

IRIS combines five components:

1. **PDE-SSM spatial mixing**: an O(N log N) Fourier-domain PDE operator that mixes tokens natively in 2D. [PDE-SSM-DiT](https://arxiv.org/abs/2603.13663)
2. **Weight-shared refinement**: 6 blocks applied for R iterations, reusing the same weights on every iteration. [GRN](https://arxiv.org/abs/2604.13030)
3. **Structured latent canvas**: DC-AE latents with channel masking. [DC-AE 1.5](https://arxiv.org/abs/2508.00413)
4. **Tiny decoder**: a ~0.1M-parameter PixelShuffle decoder. [SnapGen](https://arxiv.org/abs/2412.09619)
5. **MQA + 2D RoPE + QK-RMSNorm**: mobile-optimized attention (multi-query attention, 2D rotary position embeddings, and RMSNorm on queries and keys).
   [SnapGen++](https://arxiv.org/abs/2601.08303)

## Python API

```python
import torch
from iris import IRIS, get_model_config, flow_matching_loss, euler_sample

# Create model (with text projection for 384-dim MiniLM embeddings)
model = IRIS(**get_model_config("iris-small"), text_dim=384)

# Training step (random tensors stand in for real data)
z_0 = torch.randn(4, 32, 16, 16) * 2.5  # DC-AE latents: (B, 32, 16, 16)
text_emb = torch.randn(4, 1, 384)       # MiniLM text embeddings: (B, 1, 384)
losses = flow_matching_loss(model, z_0, text_emb, num_iterations=3)
losses["loss"].backward()

# Sampling
model.eval()
noise = torch.randn(1, 32, 16, 16)
with torch.no_grad():
    z_pred = euler_sample(model, noise, text_emb[:1], num_steps=20, num_iterations=3)
    image = model.decode_latent(z_pred)  # (1, 3, 512, 512)
```

## Files

```
colab_train_iris.py       # <-- ONE-CLICK COLAB TRAINING SCRIPT
iris/
  __init__.py             # Public API
  model.py                # IRIS model (Patchify, Unpatchify, TinyDecoder, IRIS)
  core.py                 # RefinementCore (weight-shared block loop)
  pde_ssm.py              # SpectralConv2d, TokenDifferential, PDESSMBlock
  blocks.py               # MQA, RoPE2D, UIB-FFN, TimestepEmbed, IterationEmbed
  flow_matching.py        # Rectified flow loss, Euler sampler
  configs.py              # 5 model configurations (tiny → large)
  train.py                # Training utilities (dataset, scheduler)
  train_production.py     # CLI training script
  test_all.py             # 17-test suite
```

## License

MIT