"""IRIS model configurations, keyed by preset name.

Every preset shares ``latent_channels=32``, ``max_iterations=8`` and
``ffn_expansion=2``. The presets differ in width (``dim``), depth
(``num_blocks``), ``patch_size`` and ``num_heads``. ``dim / num_heads``
is 64 in every preset.

Each preset is reported as tested. Besides the hyperparameters, each
entry carries three informational keys:

* ``params_M`` -- parameter count in millions (the figure quoted in
  ``description``).
* ``tokens`` -- token count quoted in ``description``. This is presumably
  the transformer sequence length: patch_size 4 gives 16 and patch_size 2
  gives 64, which is consistent with a 16x16 latent grid. This is an
  assumption; confirm it against the model code.
* ``description`` -- a human-readable summary with hardware guidance.

``get_model_config`` removes these three keys before returning a preset.
"""


CONFIGS = {
    "iris-tiny": {
        "latent_channels": 32,
        "dim": 256,
        "patch_size": 4,
        "num_blocks": 6,
        "num_heads": 4,
        "max_iterations": 8,
        "ffn_expansion": 2,
        "params_M": 10.3,
        "tokens": 16,
        "description": "Ultra-mobile, 10M params, 16 tokens, trains on Colab free tier",
    },
    "iris-small": {
        "latent_channels": 32,
        "dim": 512,
        "patch_size": 4,
        "num_blocks": 6,
        "num_heads": 8,
        "max_iterations": 8,
        "ffn_expansion": 2,
        "params_M": 40.0,
        "tokens": 16,
        "description": "Mobile, 40M params, 16 tokens, trains on 16GB GPU",
    },
    "iris-base": {
        "latent_channels": 32,
        "dim": 512,
        "patch_size": 2,
        "num_blocks": 8,
        "num_heads": 8,
        "max_iterations": 8,
        "ffn_expansion": 2,
        "params_M": 53.4,
        "tokens": 64,
        "description": "Base quality, 53M params, 64 tokens, trains on 16GB GPU",
    },
    "iris-medium": {
        "latent_channels": 32,
        "dim": 768,
        "patch_size": 2,
        "num_blocks": 12,
        "num_heads": 12,
        "max_iterations": 8,
        "ffn_expansion": 2,
        "params_M": 181.2,
        "tokens": 64,
        "description": "Full quality, 181M params, 64 tokens, needs 24GB GPU",
    },
    "iris-large": {
        "latent_channels": 32,
        "dim": 1024,
        "patch_size": 2,
        "num_blocks": 16,
        "num_heads": 16,
        "max_iterations": 8,
        "ffn_expansion": 2,
        "params_M": 430.9,
        "tokens": 64,
        "description": "Maximum quality, 431M params, 64 tokens, needs 40GB+ GPU",
    },
}
|
|
|
|
def get_model_config(name: str) -> dict:
    """Return the hyperparameters of the preset called ``name``.

    The result is a new dict built from ``CONFIGS[name]``. The
    informational keys ``params_M``, ``tokens`` and ``description`` are
    dropped. Assigning to or deleting keys of the returned dict therefore
    leaves ``CONFIGS`` untouched. Presumably the remaining keys are
    intended as model-constructor arguments; verify this against callers.

    Args:
        name: A key of ``CONFIGS``, for example ``"iris-base"``.

    Returns:
        A dict holding every remaining key of the preset, in the preset's
        original key order.

    Raises:
        ValueError: If ``name`` is not a known preset. The message lists
            the available preset names.
    """
    preset = CONFIGS.get(name)
    if preset is None:
        raise ValueError(f"Unknown config: {name}. Available: {list(CONFIGS)}")
    # These entries describe the preset; they are not hyperparameters.
    metadata_keys = {"params_M", "tokens", "description"}
    return {key: value for key, value in preset.items() if key not in metadata_keys}
|
|