diff --git "a/1e-6_sampling/logEma.txt" "b/1e-6_sampling/logEma.txt" new file mode 100644--- /dev/null +++ "b/1e-6_sampling/logEma.txt" @@ -0,0 +1,2445 @@ +Using devices [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] +Device count 4 +Global device count 4 +Global Batch: 512 +Node Batch: 512 +Device Batch: 128 +/tmp/tmp3xcbnh3c +Loading dataset +Loading dataset +creating model +beta1: 0.9 +beta2: 0.999 +bootstrap_cfg: 1 +bootstrap_dt_bias: 0 +bootstrap_ema: 1 +bootstrap_every: 8 +cfg_scale: 1.5 +class_dropout_prob: 0.1 +denoise_timesteps: 128 +depth: 12 +dropout: 0.0 +dt_sampling: uniform +hidden_size: 768 +lr: 0.0001 +mlp_ratio: 4 +num_classes: 1000 +num_heads: 12 +patch_size: 2 +sharding: dp +t_sampling: discrete-dt +target_update_rate: 0.999 +train_type: naive +use_cosine: 0 +use_ema: 0 +use_stable_vae: 1 +warmup: 0 +weight_decay: 0.1 + +Total devices TPU_0(process=0,(0,0,0,0)) +Initializing encoder. +Incoming encoder shape (1, 256, 256, 3) +Encoder layer (1, 256, 256, 128) +doing downsample +Encoder layer (1, 128, 128, 128) +doing downsample +Encoder layer (1, 64, 64, 256) +doing downsample +Encoder layer (1, 32, 32, 512) +Encoder layer (1, 32, 32, 512) +Encoder layer final (1, 32, 32, 512) +Encoder layer final (1, 32, 32, 512) +Final embeddings are size (1, 32, 32, 8) +After quant (1, 32, 32, 4) +encode finished +Decoder incoming shape (1, 32, 32, 4) +Decoder input (1, 32, 32, 512) +Mid Block Decoder layer (1, 32, 32, 512) +Mid Block Decoder layer (1, 32, 32, 512) +Decoder layer (1, 64, 64, 512) +Decoder layer (1, 128, 128, 512) +Decoder layer (1, 256, 256, 256) +Decoder layer (1, 256, 256, 128) +Total num of VQVAE parameters: 67565323 +Disc shape (1, 128, 128, 128) +Disc shape (1, 64, 64, 256) +Disc shape (1, 32, 32, 512) +Disc shape (1, 16, 16, 512) +Disc shape (1, 8, 8, 512) +Disc shape (1, 4, 4, 512) +Total num of Discriminator parameters: 23998017 +Loaded checkpoint from 19146940 seconds ago. +Loaded model with step 447001 +┌──────────────────────────────────────────────────────────────────────────────┐ +│ TPU 0 │ +├──────────────────────────────────────────────────────────────────────────────┤ +│ TPU 1 │ +├──────────────────────────────────────────────────────────────────────────────┤ +│ TPU 2 │ +├──────────────────────────────────────────────────────────────────────────────┤ +│ TPU 3 │ +└──────────────────────────────────────────────────────────────────────────────┘ +returning model +model done +Input to vae (4, 1, 256, 256, 3) +encode image shape (1, 256, 256, 3) +Initializing encoder. +Incoming encoder shape (1, 256, 256, 3) +Encoder layer (1, 256, 256, 128) +doing downsample +Encoder layer (1, 128, 128, 128) +doing downsample +Encoder layer (1, 64, 64, 256) +doing downsample +Encoder layer (1, 32, 32, 512) +Encoder layer (1, 32, 32, 512) +Encoder layer final (1, 32, 32, 512) +Encoder layer final (1, 32, 32, 512) +Final embeddings are size (1, 32, 32, 8) +After quant (1, 32, 32, 4) +output example shape (4, 1, 32, 32, 4) +Test data shape (4, 256, 256, 3) +x shape (4, 1, 256, 256, 3) +encoded shape (4, 1, 32, 32, 4) +z_vectors shape (1, 32, 32, 4) +Decoder incoming shape (1, 32, 32, 4) +Decoder input (1, 32, 32, 512) +Mid Block Decoder layer (1, 32, 32, 512) +Mid Block Decoder layer (1, 32, 32, 512) +Decoder layer (1, 64, 64, 512) +Decoder layer (1, 128, 128, 512) +Decoder layer (1, 256, 256, 256) +Decoder layer (1, 256, 256, 128) +image shape (4, 1, 256, 256, 3) +decoded img shape (256, 256, 3) +obs shape (4, 32, 32, 4) +DiT: Input of shape (4, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (1, 768) dtype float32 + + DiT Summary  +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ path  ┃ module  ┃ inputs  ┃ outputs  ┃ params  ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ │ DiT │ - float32[4,32,32,4] │ bfloat16[4,32,32,4] │ │ +│ │ │ - float32[1] │ │ │ +│ │ │ - float32[1] │ │ │ +│ │ │ - int32[1] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ PatchEmbed_0 │ PatchEmbed │ float32[4,32,32,4] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ PatchEmbed_0/Conv_0 │ Conv │ float32[4,32,32,4] │ bfloat16[4,16,16,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[2,2,4,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 13,056 (52.2 KB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_0 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_0/Dense_0 │ Dense │ bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[256,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 197,376 (789.5 KB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼──────���────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_0/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_1 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_1/Dense_0 │ Dense │ bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[256,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 197,376 (789.5 KB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_1/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ LabelEmbedder_0 │ LabelEmbedder │ int32[1] │ bfloat16[1,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ LabelEmbedder_0/Embed_0 │ Embed │ int32[1] │ bfloat16[1,768] │ embedding: float32[1001,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 768,768 (3.1 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────��───────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼────────��─────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼──────���────────────────┼──────────────────────────────┤ +│ DiTBlock_1/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──���───────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼───────��──────────────────────┤ +│ DiTBlock_5/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├───────────────────────────��──────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────��───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├───────────────────────────��──────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼──────────────���────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11 │ DiTBlock │ - bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/Dense_2 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/Dense_3 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/Dense_4 │ Dense │ float32[4,256,768] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/LayerNorm_1 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/MlpBlock_0 │ MlpBlock │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/MlpBlock_0/Dense_0 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[4,256,3072] │ bfloat16[4,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/MlpBlock_0/Dense_1 │ Dense │ bfloat16[4,256,3072] │ bfloat16[4,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ FinalLayer_0 │ FinalLayer │ - bfloat16[4,256,768] │ bfloat16[4,256,16] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ FinalLayer_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,1536] │ bias: float32[1536] │ +│ │ │ │ │ kernel: float32[768,1536] │ +│ │ │ │ │ │ +│ │ │ │ │ 1,181,184 (4.7 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ FinalLayer_0/LayerNorm_0 │ LayerNorm │ bfloat16[4,256,768] │ bfloat16[4,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼─────────────────��────────────┤ +│ FinalLayer_0/Dense_1 │ Dense │ bfloat16[4,256,768] │ bfloat16[4,256,16] │ bias: float32[16] │ +│ │ │ │ │ kernel: float32[768,16] │ +│ │ │ │ │ │ +│ │ │ │ │ 12,304 (49.2 KB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ Embed_0 │ Embed │ int32[1] │ float32[1,1] │ embedding: float32[256,1] │ +│ │ │ │ │ │ +│ │ │ │ │ 256 (1.0 KB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│   │   │   │  Total │ 131,091,728 (524.4 MB)  │ +└──────────────────────────────────┴──────────────────┴───────────────────────┴───────────────────────┴──────────────────────────────┘ +  + Total Parameters: 131,091,728 (524.4 MB)  + + +DiT: Input of shape (4, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (4, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (4, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (1, 768) dtype float32 +Loaded checkpoint from 873780 seconds ago. + + parameter shapes: +('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (768,) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) +('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_0', 'Dense_0', 'bias'): (4608,) +('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (768,) +('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (768,) +('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (768,) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_1', 'Dense_0', 'bias'): (4608,) +('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_1', 'bias'): (768,) +('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (768,) +('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (768,) +('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (768,) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_2', 'Dense_0', 'bias'): (4608,) +('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_1', 'bias'): (768,) +('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (768,) +('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (768,) +('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (768,) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_3', 'Dense_0', 'bias'): (4608,) +('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_1', 'bias'): (768,) +('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (768,) +('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (768,) +('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (768,) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_4', 'Dense_0', 'bias'): (4608,) +('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_1', 'bias'): (768,) +('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (768,) +('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (768,) +('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (768,) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_5', 'Dense_0', 'bias'): (4608,) +('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_1', 'bias'): (768,) +('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (768,) +('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (768,) +('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (768,) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_6', 'Dense_0', 'bias'): (4608,) +('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_1', 'bias'): (768,) +('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (768,) +('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (768,) +('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (768,) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_7', 'Dense_0', 'bias'): (4608,) +('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_1', 'bias'): (768,) +('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (768,) +('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (768,) +('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (768,) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_8', 'Dense_0', 'bias'): (4608,) +('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_1', 'bias'): (768,) +('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (768,) +('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (768,) +('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (768,) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_9', 'Dense_0', 'bias'): (4608,) +('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_1', 'bias'): (768,) +('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (768,) +('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (768,) +('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (768,) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_10', 'Dense_0', 'bias'): (4608,) +('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_1', 'bias'): (768,) +('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (768,) +('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (768,) +('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (768,) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_11', 'Dense_0', 'bias'): (4608,) +('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_1', 'bias'): (768,) +('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (768,) +('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (768,) +('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (768,) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) +('FinalLayer_0', 'Dense_0', 'bias'): (1536,) +('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) +('FinalLayer_0', 'Dense_1', 'bias'): (16,) +('Embed_0', 'embedding'): (256, 1) +regular + + parameter shapes: +('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('Embed_0', 'embedding'): (1, 256, 1) +('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) +('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) +('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) +('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) +('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) +flat + + parameter shapes: +('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('Embed_0', 'embedding'): (1, 256, 1) +('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) +('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) +('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) +('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) +('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) +flat ema + + parameter shapes: +('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('Embed_0', 'embedding'): (1, 256, 1) +('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) +('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) +('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) +('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) +('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) + + parameter shapes: +('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('Embed_0', 'embedding'): (1, 256, 1) +('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) +('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) +('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) +('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) +('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) + + parameter shapes: +('DiTBlock_0', 'Dense_0', 'bias'): (4608,) +('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (768,) +('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (768,) +('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (768,) +('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_1', 'Dense_0', 'bias'): (4608,) +('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_1', 'Dense_1', 'bias'): (768,) +('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (768,) +('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (768,) +('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (768,) +('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_10', 'Dense_0', 'bias'): (4608,) +('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_10', 'Dense_1', 'bias'): (768,) +('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (768,) +('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (768,) +('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (768,) +('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_11', 'Dense_0', 'bias'): (4608,) +('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_11', 'Dense_1', 'bias'): (768,) +('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (768,) +('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (768,) +('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (768,) +('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_2', 'Dense_0', 'bias'): (4608,) +('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_2', 'Dense_1', 'bias'): (768,) +('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (768,) +('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (768,) +('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (768,) +('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_3', 'Dense_0', 'bias'): (4608,) +('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_3', 'Dense_1', 'bias'): (768,) +('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (768,) +('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (768,) +('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (768,) +('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_4', 'Dense_0', 'bias'): (4608,) +('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_4', 'Dense_1', 'bias'): (768,) +('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (768,) +('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (768,) +('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (768,) +('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_5', 'Dense_0', 'bias'): (4608,) +('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_5', 'Dense_1', 'bias'): (768,) +('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (768,) +('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (768,) +('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (768,) +('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_6', 'Dense_0', 'bias'): (4608,) +('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_6', 'Dense_1', 'bias'): (768,) +('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (768,) +('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (768,) +('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (768,) +('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_7', 'Dense_0', 'bias'): (4608,) +('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_7', 'Dense_1', 'bias'): (768,) +('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (768,) +('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (768,) +('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (768,) +('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_8', 'Dense_0', 'bias'): (4608,) +('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_8', 'Dense_1', 'bias'): (768,) +('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (768,) +('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (768,) +('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (768,) +('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_9', 'Dense_0', 'bias'): (4608,) +('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_9', 'Dense_1', 'bias'): (768,) +('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (768,) +('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (768,) +('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (768,) +('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('Embed_0', 'embedding'): (256, 1) +('FinalLayer_0', 'Dense_0', 'bias'): (1536,) +('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) +('FinalLayer_0', 'Dense_1', 'bias'): (16,) +('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (768,) +('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) +┌────────────────────────────────────────────────┐ +│ │ +│ │ +│ │ +│ �� +│ TPU 0,1,2,3 │ +│ │ +│ │ +│ │ +│ │ +└────────────────────────────────────────────────┘ +┌─────────────────────────────────────────────────────────────────────────┐ +│ │ +│ │ +│ │ +│ │ +│ TPU 0,1,2,3 │ +│ │ +│ │ +│ │ +│ │ +└─────────────────────────────────────────────────────────────────────────┘ +doing the else +(512, 256, 256, 3) +encode image shape (128, 256, 256, 3) +Initializing encoder. +Incoming encoder shape (128, 256, 256, 3) +Encoder layer (128, 256, 256, 128) +doing downsample +Encoder layer (128, 128, 128, 128) +doing downsample +Encoder layer (128, 64, 64, 256) +doing downsample +Encoder layer (128, 32, 32, 512) +Encoder layer (128, 32, 32, 512) +Encoder layer final (128, 32, 32, 512) +Encoder layer final (128, 32, 32, 512) +Final embeddings are size (128, 32, 32, 8) +After quant (128, 32, 32, 4) +Calc FID for CFG 1.0 and denoise_timesteps 128 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +z_vectors shape (128, 32, 32, 4) +Decoder incoming shape (128, 32, 32, 4) +Decoder input (128, 32, 32, 512) +Mid Block Decoder layer (128, 32, 32, 512) +Mid Block Decoder layer (128, 32, 32, 512) +Decoder layer (128, 64, 64, 512) +Decoder layer (128, 128, 128, 512) +Decoder layer (128, 256, 256, 256) +Decoder layer (128, 256, 256, 128) +FID is 27.245651245117188 +(512, 256, 256, 3) +Calc FID for CFG 1.0 and denoise_timesteps 64 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 27.795284271240234 +(512, 256, 256, 3) +Calc FID for CFG 1.0 and denoise_timesteps 32 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 29.345104217529297 +(512, 256, 256, 3) +Calc FID for CFG 1.0 and denoise_timesteps 16 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 34.199710845947266 +(512, 256, 256, 3) +Calc FID for CFG 1.0 and denoise_timesteps 8 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 49.71440887451172 +(512, 256, 256, 3) +Calc FID for CFG 1.0 and denoise_timesteps 4 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 110.30722045898438 +(512, 256, 256, 3) +Calc FID for CFG 1.0 and denoise_timesteps 2 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 253.87759399414062 +(512, 256, 256, 3) +Calc FID for CFG 1.0 and denoise_timesteps 1 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 327.1197509765625 +(512, 256, 256, 3) +Calc FID for CFG 1.25 and denoise_timesteps 128 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 15.681865692138672 +(512, 256, 256, 3) +Calc FID for CFG 1.25 and denoise_timesteps 64 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 16.033981323242188 +(512, 256, 256, 3) +Calc FID for CFG 1.25 and denoise_timesteps 32 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 17.150859832763672 +(512, 256, 256, 3) +Calc FID for CFG 1.25 and denoise_timesteps 16 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 20.734865188598633 +(512, 256, 256, 3) +Calc FID for CFG 1.25 and denoise_timesteps 8 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 33.42859649658203 +(512, 256, 256, 3) +Calc FID for CFG 1.25 and denoise_timesteps 4 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 87.03020477294922 +(512, 256, 256, 3) +Calc FID for CFG 1.25 and denoise_timesteps 2 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 244.01852416992188 +(512, 256, 256, 3) +Calc FID for CFG 1.25 and denoise_timesteps 1 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 311.54217529296875 +(512, 256, 256, 3) +Calc FID for CFG 1.5 and denoise_timesteps 128 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 9.86760425567627 +(512, 256, 256, 3) +Calc FID for CFG 1.5 and denoise_timesteps 64 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 10.090585708618164 +(512, 256, 256, 3) +Calc FID for CFG 1.5 and denoise_timesteps 32 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 10.822160720825195 +(512, 256, 256, 3) +Calc FID for CFG 1.5 and denoise_timesteps 16 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 13.285333633422852 +(512, 256, 256, 3) +Calc FID for CFG 1.5 and denoise_timesteps 8 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 22.932682037353516 +(512, 256, 256, 3) +Calc FID for CFG 1.5 and denoise_timesteps 4 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 68.92096710205078 +(512, 256, 256, 3) +Calc FID for CFG 1.5 and denoise_timesteps 2 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 235.88409423828125 +(512, 256, 256, 3) +Calc FID for CFG 1.5 and denoise_timesteps 1 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 300.24005126953125 +(512, 256, 256, 3) +Calc FID for CFG 1.75 and denoise_timesteps 128 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 7.5391693115234375 +(512, 256, 256, 3) +Calc FID for CFG 1.75 and denoise_timesteps 64 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 7.657982349395752 +(512, 256, 256, 3) +Calc FID for CFG 1.75 and denoise_timesteps 32 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 8.088539123535156 +(512, 256, 256, 3) +Calc FID for CFG 1.75 and denoise_timesteps 16 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 9.734247207641602 +(512, 256, 256, 3) +Calc FID for CFG 1.75 and denoise_timesteps 8 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 16.740436553955078 +(512, 256, 256, 3) +Calc FID for CFG 1.75 and denoise_timesteps 4 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 55.433101654052734 +(512, 256, 256, 3) +Calc FID for CFG 1.75 and denoise_timesteps 2 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 228.83120727539062 +(512, 256, 256, 3) +Calc FID for CFG 1.75 and denoise_timesteps 1 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 291.853759765625 +(512, 256, 256, 3) +Calc FID for CFG 2.0 and denoise_timesteps 128 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 7.187783718109131 +(512, 256, 256, 3) +Calc FID for CFG 2.0 and denoise_timesteps 64 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 7.233066558837891 +(512, 256, 256, 3) +Calc FID for CFG 2.0 and denoise_timesteps 32 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 7.4409308433532715 +(512, 256, 256, 3) +Calc FID for CFG 2.0 and denoise_timesteps 16 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 8.426179885864258 +(512, 256, 256, 3) +Calc FID for CFG 2.0 and denoise_timesteps 8 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 13.42213249206543 +(512, 256, 256, 3) +Calc FID for CFG 2.0 and denoise_timesteps 4 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 45.36845397949219 +(512, 256, 256, 3) +Calc FID for CFG 2.0 and denoise_timesteps 2 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 222.56890869140625 +(512, 256, 256, 3) +Calc FID for CFG 2.0 and denoise_timesteps 1 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 285.29339599609375 +(512, 256, 256, 3) +Calc FID for CFG 2.25 and denoise_timesteps 128 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 7.834066867828369 +(512, 256, 256, 3) +Calc FID for CFG 2.25 and denoise_timesteps 64 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 7.815821647644043 +(512, 256, 256, 3) +Calc FID for CFG 2.25 and denoise_timesteps 32 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 7.896724700927734 +(512, 256, 256, 3) +Calc FID for CFG 2.25 and denoise_timesteps 16 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 8.43143081665039 +(512, 256, 256, 3) +Calc FID for CFG 2.25 and denoise_timesteps 8 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 11.801475524902344 +(512, 256, 256, 3) +Calc FID for CFG 2.25 and denoise_timesteps 4 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 38.00835418701172 +(512, 256, 256, 3) +Calc FID for CFG 2.25 and denoise_timesteps 2 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 217.1334991455078 +(512, 256, 256, 3) +Calc FID for CFG 2.25 and denoise_timesteps 1 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 279.98748779296875 +(512, 256, 256, 3) +Calc FID for CFG 2.5 and denoise_timesteps 128 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 8.856758117675781 +(512, 256, 256, 3) +Calc FID for CFG 2.5 and denoise_timesteps 64 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 8.807829856872559 +(512, 256, 256, 3) +Calc FID for CFG 2.5 and denoise_timesteps 32 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 8.81913948059082 +(512, 256, 256, 3) +Calc FID for CFG 2.5 and denoise_timesteps 16 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 9.074731826782227 +(512, 256, 256, 3) +Calc FID for CFG 2.5 and denoise_timesteps 8 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 11.30735969543457 +(512, 256, 256, 3) +Calc FID for CFG 2.5 and denoise_timesteps 4 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 32.55115509033203 +(512, 256, 256, 3) +Calc FID for CFG 2.5 and denoise_timesteps 2 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 212.34420776367188 +(512, 256, 256, 3) +Calc FID for CFG 2.5 and denoise_timesteps 1 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 275.5500793457031 +(512, 256, 256, 3) +Calc FID for CFG 2.75 and denoise_timesteps 128 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 10.09544563293457 +(512, 256, 256, 3) +Calc FID for CFG 2.75 and denoise_timesteps 64 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 10.022919654846191 +(512, 256, 256, 3) +Calc FID for CFG 2.75 and denoise_timesteps 32 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 9.957592010498047 +(512, 256, 256, 3) +Calc FID for CFG 2.75 and denoise_timesteps 16 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 10.00850772857666 +(512, 256, 256, 3) +Calc FID for CFG 2.75 and denoise_timesteps 8 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 11.42867374420166 +(512, 256, 256, 3) +Calc FID for CFG 2.75 and denoise_timesteps 4 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 28.6097469329834 +(512, 256, 256, 3) +Calc FID for CFG 2.75 and denoise_timesteps 2 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 208.177490234375 +(512, 256, 256, 3) +Calc FID for CFG 2.75 and denoise_timesteps 1 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 271.6966552734375 +(512, 256, 256, 3) +Calc FID for CFG 3.0 and denoise_timesteps 128 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 11.357705116271973 +(512, 256, 256, 3) +Calc FID for CFG 3.0 and denoise_timesteps 64 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 11.280121803283691 +(512, 256, 256, 3) +Calc FID for CFG 3.0 and denoise_timesteps 32 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 11.18954849243164 +(512, 256, 256, 3) +Calc FID for CFG 3.0 and denoise_timesteps 16 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 11.083646774291992 +(512, 256, 256, 3) +Calc FID for CFG 3.0 and denoise_timesteps 8 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 11.861461639404297 +(512, 256, 256, 3) +Calc FID for CFG 3.0 and denoise_timesteps 4 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 25.789085388183594 +(512, 256, 256, 3) +Calc FID for CFG 3.0 and denoise_timesteps 2 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 204.48948669433594 +(512, 256, 256, 3) +Calc FID for CFG 3.0 and denoise_timesteps 1 +DiT: Input of shape (512, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (512, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (512, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (512, 768) dtype float32 +FID is 268.224853515625 +wandb: +wandb: 🚀 View run shortcut_imagenet256 at: https://wandb.ai/daniel-z-kaplan/shortcut/runs/shortcut_imagenet256_20250826_115637_345353_10 +wandb: Find logs at: ../../../tmp/tmp3xcbnh3c/wandb/run-20250826_115637-shortcut_imagenet256_20250826_115637_345353_10/logs