| """Configuration dataclass for LuminaRS.""" | |
| from dataclasses import dataclass | |
@dataclass
class LuminaRSConfig:
    """Hyperparameters for LuminaRS, grouped in one dataclass.

    Every field has a default, so ``LuminaRSConfig()`` yields the stock
    configuration and individual values can be overridden by keyword,
    e.g. ``LuminaRSConfig(n_recurse=3)``.
    """

    # --- Latent dims ---
    latent_dim: int = 16  # VAE latent channels
    # 32x32 latent = 1024x1024 px.
    # NOTE(review): that figure implies a 32x spatial factor between latent
    # and pixel space; confirm it matches the VAE named in `vae_pretrained`.
    latent_h: int = 32
    latent_w: int = 32

    # --- UNet channels at each scale ---
    channels: tuple = (64, 128, 256, 256, 384)

    # --- Bottleneck depth ---
    n_bottleneck: int = 4

    # --- Iterative refinement depth (TRM-style shared-weight recursion) ---
    n_recurse: int = 2

    # --- Time embedding ---
    t_embed_dim: int = 256

    # --- Text conditioning ---
    text_embed_dim: int = 768
    max_text_len: int = 77

    # --- Training ---
    # Presumably a stochastic-depth (drop-path) rate; verify at the use site.
    drop_path: float = 0.05

    # --- VAE (frozen) ---
    vae_pretrained: str = "madebyollin/sdxl-vae-fp16-fix"

    # --- Text encoder (frozen) ---
    clip_pretrained: str = "openai/clip-vit-large-patch14"

    # --- Flow matching inference steps ---
    n_flow_steps: int = 12