MVA_GenAI / unet_cifar.yaml
haiphamcse's picture
Upload folder using huggingface_hub
f729117 verified
# UNet + CFM training hyperparameters (used by train_cfm_unet.py --config)
# NOTE(review): the consuming script is not part of this file; the per-key notes
# below that are marked "presumably" / "confirm" are assumptions to verify there.
#
# --- Flow matching / optimisation ---
# presumably the noise level (sigma) of the CFM probability path; 0.0 would mean
# no added noise — confirm against train_cfm_unet.py
sigma: 0.0
# Image shape C, H, W
# (3 channels, 32 x 32 pixels)
dim: [3, 32, 32]
# presumably the optimizer learning rate; written with a decimal point and signed
# exponent so that YAML 1.1 loaders resolve it as a float rather than a string
lr: 1.0e-4
# presumably the optimizer weight decay (0.0 = disabled) — confirm
weight_decay: 0.0
# --- NeuralODE visualization / sampling ---
# presumably the interval, in epochs, at which checkpoints/samples are written — confirm
save_ep: 30
# presumably the number of ODE integration steps used when sampling — confirm
inference_steps: 100
# presumably the number of images generated per visualization batch — confirm
vis_batch_size: 8
# --- UNet (torchcfm UNetModelWrapper) ---
# The keys below match UNetModelWrapper keyword-argument names; presumably they are
# forwarded to it unchanged — confirm in train_cfm_unet.py.
num_res_blocks: 2
num_channels: 128
# One multiplier per resolution level (four levels here).
channel_mult: [1, 2, 2, 2]
# NOTE(review): both num_heads and num_head_channels are set; check which one the
# wrapper actually honours when num_head_channels is not -1.
num_heads: 4
num_head_channels: 64
# Intentionally a quoted string, not an integer. Presumably the wrapper parses it
# as a comma-separated list of resolutions — keep the quotes unless confirmed otherwise.
attention_resolutions: "16"
dropout: 0.1