| # UNet + CFM training hyperparameters (used by train_cfm_unet.py --config) | |
| sigma: 0.0 | |
| # Image shape as [C, H, W] (channels, height, width) | |
| dim: [3, 32, 32] | |
| lr: 1.0e-4 | |
| weight_decay: 0.0 | |
| # NeuralODE sampling and visualization settings | |
| save_ep: 30 | |
| inference_steps: 100 | |
| vis_batch_size: 8 | |
| # UNet architecture settings (torchcfm UNetModelWrapper) | |
| num_res_blocks: 2 | |
| num_channels: 128 | |
| channel_mult: [1, 2, 2, 2] | |
| num_heads: 4 | |
| num_head_channels: 64 | |
| attention_resolutions: "16" | |
| dropout: 0.1 | |