diff --git "a/baseline/shortcut_mybase_secondary.txt" "b/baseline/shortcut_mybase_secondary.txt" new file mode 100644--- /dev/null +++ "b/baseline/shortcut_mybase_secondary.txt" @@ -0,0 +1,2411 @@ +Using devices [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)] +Device count 4 +Global device count 4 +Global Batch: 256 +Node Batch: 256 +Device Batch: 64 +Loading dataset +Loading dataset +DiT: Input of shape (1, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (1, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (1, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (1, 768) dtype float32 + + DiT Summary  +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ path  ┃ module  ┃ inputs  ┃ outputs  ┃ params  ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ │ DiT │ - float32[1,32,32,4] │ bfloat16[1,32,32,4] │ │ +│ │ │ - float32[1] │ │ │ +│ │ │ - float32[1] │ │ │ +│ │ │ - int32[1] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ PatchEmbed_0 │ PatchEmbed │ float32[1,32,32,4] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ PatchEmbed_0/Conv_0 │ Conv │ float32[1,32,32,4] │ bfloat16[1,16,16,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[2,2,4,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 13,056 (52.2 KB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_0 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_0/Dense_0 │ Dense │ bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[256,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 197,376 (789.5 KB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_0/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_1 │ TimestepEmbedder │ float32[1] │ float32[1,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_1/Dense_0 │ Dense │ bfloat16[1,256] │ bfloat16[1,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[256,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 197,376 (789.5 KB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ TimestepEmbedder_1/Dense_1 │ Dense │ bfloat16[1,768] │ float32[1,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ LabelEmbedder_0 │ LabelEmbedder │ int32[1] │ bfloat16[1,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────���───────────────────────┼──────────────────────────────┤ +│ LabelEmbedder_0/Embed_0 │ Embed │ int32[1] │ bfloat16[1,768] │ embedding: float32[1001,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 768,768 (3.1 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────��───────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼─���────────────────────────────┤ +│ DiTBlock_0/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_0/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────���───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_1/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ ��� │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_2/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_3/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_4/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_5/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼──────────────────────��┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_6/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_7/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼─────────���─────────────┼──────────────────────────────┤ +│ DiTBlock_8/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_8/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────���───────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_9/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_10/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11 │ DiTBlock │ - bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,4608] │ bias: float32[4608] │ +│ │ │ │ │ kernel: float32[768,4608] │ +│ │ │ │ │ │ +│ │ │ │ │ 3,543,552 (14.2 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/Dense_2 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/Dense_3 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/Dense_4 │ Dense │ float32[1,256,768] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[768,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 590,592 (2.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/LayerNorm_1 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/MlpBlock_0 │ MlpBlock │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/MlpBlock_0/Dense_0 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,3072] │ bias: float32[3072] │ +│ │ │ │ │ kernel: float32[768,3072] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,362,368 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/MlpBlock_0/Dropout_0 │ Dropout │ bfloat16[1,256,3072] │ bfloat16[1,256,3072] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/MlpBlock_0/Dense_1 │ Dense │ bfloat16[1,256,3072] │ bfloat16[1,256,768] │ bias: float32[768] │ +│ │ │ │ │ kernel: float32[3072,768] │ +│ │ │ │ │ │ +│ │ │ │ │ 2,360,064 (9.4 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ DiTBlock_11/MlpBlock_0/Dropout_1 │ Dropout │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ FinalLayer_0 │ FinalLayer │ - bfloat16[1,256,768] │ bfloat16[1,256,16] │ │ +│ │ │ - float32[1,768] │ │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ FinalLayer_0/Dense_0 │ Dense │ float32[1,768] │ bfloat16[1,1536] │ bias: float32[1536] │ +│ │ │ │ │ kernel: float32[768,1536] │ +│ │ │ │ │ │ +│ │ │ │ │ 1,181,184 (4.7 MB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ FinalLayer_0/LayerNorm_0 │ LayerNorm │ bfloat16[1,256,768] │ bfloat16[1,256,768] │ │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ FinalLayer_0/Dense_1 │ Dense │ bfloat16[1,256,768] │ bfloat16[1,256,16] │ bias: float32[16] │ +│ │ │ │ │ kernel: float32[768,16] │ +│ │ │ │ │ │ +│ │ │ │ │ 12,304 (49.2 KB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│ Embed_0 │ Embed │ int32[1] │ float32[1,1] │ embedding: float32[256,1] │ +│ │ │ │ │ │ +│ │ │ │ │ 256 (1.0 KB) │ +├──────────────────────────────────┼──────────────────┼───────────────────────┼───────────────────────┼──────────────────────────────┤ +│   │   │   │  Total │ 131,091,728 (524.4 MB)  │ +└──────────────────────────────────┴──────────────────┴───────────────────────┴───────────────────────┴──────────────────────────────┘ +  + Total Parameters: 131,091,728 (524.4 MB)  + + +DiT: Input of shape (1, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (1, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (1, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (1, 768) dtype float32 +Loaded checkpoint from 975303 seconds ago. + + parameter shapes: +('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (768,) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) +('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_0', 'Dense_0', 'bias'): (4608,) +('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (768,) +('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (768,) +('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (768,) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_1', 'Dense_0', 'bias'): (4608,) +('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_1', 'bias'): (768,) +('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (768,) +('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (768,) +('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (768,) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_2', 'Dense_0', 'bias'): (4608,) +('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_1', 'bias'): (768,) +('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (768,) +('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (768,) +('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (768,) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_3', 'Dense_0', 'bias'): (4608,) +('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_1', 'bias'): (768,) +('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (768,) +('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (768,) +('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (768,) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_4', 'Dense_0', 'bias'): (4608,) +('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_1', 'bias'): (768,) +('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (768,) +('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (768,) +('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (768,) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_5', 'Dense_0', 'bias'): (4608,) +('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_1', 'bias'): (768,) +('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (768,) +('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (768,) +('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (768,) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_6', 'Dense_0', 'bias'): (4608,) +('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_1', 'bias'): (768,) +('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (768,) +('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (768,) +('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (768,) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_7', 'Dense_0', 'bias'): (4608,) +('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_1', 'bias'): (768,) +('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (768,) +('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (768,) +('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (768,) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_8', 'Dense_0', 'bias'): (4608,) +('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_1', 'bias'): (768,) +('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (768,) +('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (768,) +('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (768,) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_9', 'Dense_0', 'bias'): (4608,) +('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_1', 'bias'): (768,) +('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (768,) +('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (768,) +('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (768,) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_10', 'Dense_0', 'bias'): (4608,) +('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_1', 'bias'): (768,) +('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (768,) +('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (768,) +('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (768,) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_11', 'Dense_0', 'bias'): (4608,) +('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_1', 'bias'): (768,) +('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (768,) +('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (768,) +('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (768,) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) +('FinalLayer_0', 'Dense_0', 'bias'): (1536,) +('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) +('FinalLayer_0', 'Dense_1', 'bias'): (16,) +('Embed_0', 'embedding'): (256, 1) + + parameter shapes: +('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('Embed_0', 'embedding'): (1, 256, 1) +('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) +('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) +('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) +('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) +('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) + + parameter shapes: +('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('Embed_0', 'embedding'): (1, 256, 1) +('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) +('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) +('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) +('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) +('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) + + parameter shapes: +('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('Embed_0', 'embedding'): (1, 256, 1) +('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) +('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) +('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) +('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) +('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) + + parameter shapes: +('DiTBlock_0', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_0', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_0', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_1', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_1', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_1', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_1', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_10', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_10', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_10', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_10', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_11', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_11', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_11', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_11', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_2', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_2', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_2', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_2', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_3', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_3', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_3', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_3', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_4', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_4', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_4', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_4', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_5', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_5', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_5', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_5', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_6', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_6', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_6', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_6', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_7', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_7', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_7', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_7', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_8', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_8', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_8', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_8', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('DiTBlock_9', 'Dense_0', 'bias'): (1, 4608) +('DiTBlock_9', 'Dense_0', 'kernel'): (1, 768, 4608) +('DiTBlock_9', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_1', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_2', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_3', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (1, 768) +('DiTBlock_9', 'Dense_4', 'kernel'): (1, 768, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (1, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (1, 768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (1, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (1, 3072, 768) +('Embed_0', 'embedding'): (1, 256, 1) +('FinalLayer_0', 'Dense_0', 'bias'): (1, 1536) +('FinalLayer_0', 'Dense_0', 'kernel'): (1, 768, 1536) +('FinalLayer_0', 'Dense_1', 'bias'): (1, 16) +('FinalLayer_0', 'Dense_1', 'kernel'): (1, 768, 16) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1, 1001, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (1, 768) +('PatchEmbed_0', 'Conv_0', 'kernel'): (1, 2, 2, 4, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (1, 768, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (1, 256, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (1, 768) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (1, 768, 768) + + parameter shapes: +('DiTBlock_0', 'Dense_0', 'bias'): (4608,) +('DiTBlock_0', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_0', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_2', 'bias'): (768,) +('DiTBlock_0', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_3', 'bias'): (768,) +('DiTBlock_0', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_0', 'Dense_4', 'bias'): (768,) +('DiTBlock_0', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_0', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_0', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_1', 'Dense_0', 'bias'): (4608,) +('DiTBlock_1', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_1', 'Dense_1', 'bias'): (768,) +('DiTBlock_1', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_2', 'bias'): (768,) +('DiTBlock_1', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_3', 'bias'): (768,) +('DiTBlock_1', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_1', 'Dense_4', 'bias'): (768,) +('DiTBlock_1', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_1', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_1', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_10', 'Dense_0', 'bias'): (4608,) +('DiTBlock_10', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_10', 'Dense_1', 'bias'): (768,) +('DiTBlock_10', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_2', 'bias'): (768,) +('DiTBlock_10', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_3', 'bias'): (768,) +('DiTBlock_10', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_10', 'Dense_4', 'bias'): (768,) +('DiTBlock_10', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_10', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_10', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_11', 'Dense_0', 'bias'): (4608,) +('DiTBlock_11', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_11', 'Dense_1', 'bias'): (768,) +('DiTBlock_11', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_2', 'bias'): (768,) +('DiTBlock_11', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_3', 'bias'): (768,) +('DiTBlock_11', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_11', 'Dense_4', 'bias'): (768,) +('DiTBlock_11', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_11', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_11', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_2', 'Dense_0', 'bias'): (4608,) +('DiTBlock_2', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_2', 'Dense_1', 'bias'): (768,) +('DiTBlock_2', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_2', 'bias'): (768,) +('DiTBlock_2', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_3', 'bias'): (768,) +('DiTBlock_2', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_2', 'Dense_4', 'bias'): (768,) +('DiTBlock_2', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_2', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_2', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_3', 'Dense_0', 'bias'): (4608,) +('DiTBlock_3', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_3', 'Dense_1', 'bias'): (768,) +('DiTBlock_3', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_2', 'bias'): (768,) +('DiTBlock_3', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_3', 'bias'): (768,) +('DiTBlock_3', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_3', 'Dense_4', 'bias'): (768,) +('DiTBlock_3', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_3', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_3', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_4', 'Dense_0', 'bias'): (4608,) +('DiTBlock_4', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_4', 'Dense_1', 'bias'): (768,) +('DiTBlock_4', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_2', 'bias'): (768,) +('DiTBlock_4', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_3', 'bias'): (768,) +('DiTBlock_4', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_4', 'Dense_4', 'bias'): (768,) +('DiTBlock_4', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_4', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_4', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_5', 'Dense_0', 'bias'): (4608,) +('DiTBlock_5', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_5', 'Dense_1', 'bias'): (768,) +('DiTBlock_5', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_2', 'bias'): (768,) +('DiTBlock_5', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_3', 'bias'): (768,) +('DiTBlock_5', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_5', 'Dense_4', 'bias'): (768,) +('DiTBlock_5', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_5', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_5', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_6', 'Dense_0', 'bias'): (4608,) +('DiTBlock_6', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_6', 'Dense_1', 'bias'): (768,) +('DiTBlock_6', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_2', 'bias'): (768,) +('DiTBlock_6', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_3', 'bias'): (768,) +('DiTBlock_6', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_6', 'Dense_4', 'bias'): (768,) +('DiTBlock_6', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_6', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_6', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_7', 'Dense_0', 'bias'): (4608,) +('DiTBlock_7', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_7', 'Dense_1', 'bias'): (768,) +('DiTBlock_7', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_2', 'bias'): (768,) +('DiTBlock_7', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_3', 'bias'): (768,) +('DiTBlock_7', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_7', 'Dense_4', 'bias'): (768,) +('DiTBlock_7', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_7', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_7', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_8', 'Dense_0', 'bias'): (4608,) +('DiTBlock_8', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_8', 'Dense_1', 'bias'): (768,) +('DiTBlock_8', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_2', 'bias'): (768,) +('DiTBlock_8', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_3', 'bias'): (768,) +('DiTBlock_8', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_8', 'Dense_4', 'bias'): (768,) +('DiTBlock_8', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_8', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_8', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('DiTBlock_9', 'Dense_0', 'bias'): (4608,) +('DiTBlock_9', 'Dense_0', 'kernel'): (768, 4608) +('DiTBlock_9', 'Dense_1', 'bias'): (768,) +('DiTBlock_9', 'Dense_1', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_2', 'bias'): (768,) +('DiTBlock_9', 'Dense_2', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_3', 'bias'): (768,) +('DiTBlock_9', 'Dense_3', 'kernel'): (768, 768) +('DiTBlock_9', 'Dense_4', 'bias'): (768,) +('DiTBlock_9', 'Dense_4', 'kernel'): (768, 768) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'bias'): (3072,) +('DiTBlock_9', 'MlpBlock_0', 'Dense_0', 'kernel'): (768, 3072) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'bias'): (768,) +('DiTBlock_9', 'MlpBlock_0', 'Dense_1', 'kernel'): (3072, 768) +('Embed_0', 'embedding'): (256, 1) +('FinalLayer_0', 'Dense_0', 'bias'): (1536,) +('FinalLayer_0', 'Dense_0', 'kernel'): (768, 1536) +('FinalLayer_0', 'Dense_1', 'bias'): (16,) +('FinalLayer_0', 'Dense_1', 'kernel'): (768, 16) +('LabelEmbedder_0', 'Embed_0', 'embedding'): (1001, 768) +('PatchEmbed_0', 'Conv_0', 'bias'): (768,) +('PatchEmbed_0', 'Conv_0', 'kernel'): (2, 2, 4, 768) +('TimestepEmbedder_0', 'Dense_0', 'bias'): (768,) +('TimestepEmbedder_0', 'Dense_0', 'kernel'): (256, 768) +('TimestepEmbedder_0', 'Dense_1', 'bias'): (768,) +('TimestepEmbedder_0', 'Dense_1', 'kernel'): (768, 768) +('TimestepEmbedder_1', 'Dense_0', 'bias'): (768,) +('TimestepEmbedder_1', 'Dense_0', 'kernel'): (256, 768) +('TimestepEmbedder_1', 'Dense_1', 'bias'): (768,) +('TimestepEmbedder_1', 'Dense_1', 'kernel'): (768, 768) +┌────────────────────────────────────────────────┐ +│ │ +│ │ +│ │ +│ │ +│ TPU 0,1,2,3 │ +│ │ +│ │ +│ │ +│ │ +└────────────────────────────────────────────────┘ +┌─────────────────────────────────────────────────────────────────────────┐ +│ │ +│ │ +│ │ +│ │ +│ TPU 0,1,2,3 │ +│ │ +│ │ +│ │ +│ │ +└─────────────────────────────────────────────────────────────────────────┘ +Calc FID for CFG 1.0 and denoise_timesteps 128 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 25.99567413330078 +Calc FID for CFG 1.0 and denoise_timesteps 64 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 19.025453567504883 +Calc FID for CFG 1.0 and denoise_timesteps 32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 20.295011520385742 +Calc FID for CFG 1.0 and denoise_timesteps 16 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 22.538732528686523 +Calc FID for CFG 1.0 and denoise_timesteps 8 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 25.382877349853516 +Calc FID for CFG 1.0 and denoise_timesteps 4 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 31.228776931762695 +Calc FID for CFG 1.0 and denoise_timesteps 2 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 42.37808609008789 +Calc FID for CFG 1.0 and denoise_timesteps 1 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 39.091678619384766 +Calc FID for CFG 1.25 and denoise_timesteps 128 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 14.789426803588867 +Calc FID for CFG 1.25 and denoise_timesteps 64 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 11.829596519470215 +Calc FID for CFG 1.25 and denoise_timesteps 32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 12.960319519042969 +Calc FID for CFG 1.25 and denoise_timesteps 16 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 14.657230377197266 +Calc FID for CFG 1.25 and denoise_timesteps 8 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 17.24106216430664 +Calc FID for CFG 1.25 and denoise_timesteps 4 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 22.685482025146484 +Calc FID for CFG 1.25 and denoise_timesteps 2 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 33.70756530761719 +Calc FID for CFG 1.25 and denoise_timesteps 1 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 30.63207244873047 +Calc FID for CFG 1.75 and denoise_timesteps 128 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 8.49964714050293 +Calc FID for CFG 1.75 and denoise_timesteps 64 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 9.636996269226074 +Calc FID for CFG 1.75 and denoise_timesteps 32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 10.191990852355957 +Calc FID for CFG 1.75 and denoise_timesteps 16 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 10.939922332763672 +Calc FID for CFG 1.75 and denoise_timesteps 8 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 12.511274337768555 +Calc FID for CFG 1.75 and denoise_timesteps 4 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 17.25038719177246 +Calc FID for CFG 1.75 and denoise_timesteps 2 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 28.857677459716797 +Calc FID for CFG 1.75 and denoise_timesteps 1 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 32.443660736083984 +Calc FID for CFG 2.0 and denoise_timesteps 128 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 8.623676300048828 +Calc FID for CFG 2.0 and denoise_timesteps 64 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 10.539774894714355 +Calc FID for CFG 2.0 and denoise_timesteps 32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 10.816583633422852 +Calc FID for CFG 2.0 and denoise_timesteps 16 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 11.117993354797363 +Calc FID for CFG 2.0 and denoise_timesteps 8 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 12.251923561096191 +Calc FID for CFG 2.0 and denoise_timesteps 4 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 16.70867156982422 +Calc FID for CFG 2.0 and denoise_timesteps 2 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 28.657535552978516 +Calc FID for CFG 2.0 and denoise_timesteps 1 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 37.30630111694336 +Calc FID for CFG 2.25 and denoise_timesteps 128 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 9.414321899414062 +Calc FID for CFG 2.25 and denoise_timesteps 64 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 11.669325828552246 +Calc FID for CFG 2.25 and denoise_timesteps 32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 11.743907928466797 +Calc FID for CFG 2.25 and denoise_timesteps 16 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 11.700990676879883 +Calc FID for CFG 2.25 and denoise_timesteps 8 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 12.487598419189453 +Calc FID for CFG 2.25 and denoise_timesteps 4 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 16.715736389160156 +Calc FID for CFG 2.25 and denoise_timesteps 2 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 29.09388542175293 +Calc FID for CFG 2.25 and denoise_timesteps 1 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 43.063232421875 +Calc FID for CFG 2.5 and denoise_timesteps 128 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 10.502670288085938 +Calc FID for CFG 2.5 and denoise_timesteps 64 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 12.845773696899414 +Calc FID for CFG 2.5 and denoise_timesteps 32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 12.717670440673828 +Calc FID for CFG 2.5 and denoise_timesteps 16 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 12.473001480102539 +Calc FID for CFG 2.5 and denoise_timesteps 8 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 12.964001655578613 +Calc FID for CFG 2.5 and denoise_timesteps 4 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 17.037921905517578 +Calc FID for CFG 2.5 and denoise_timesteps 2 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 29.91904067993164 +Calc FID for CFG 2.5 and denoise_timesteps 1 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 48.83190155029297 +Calc FID for CFG 2.75 and denoise_timesteps 128 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 11.668487548828125 +Calc FID for CFG 2.75 and denoise_timesteps 64 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 13.983850479125977 +Calc FID for CFG 2.75 and denoise_timesteps 32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 13.729737281799316 +Calc FID for CFG 2.75 and denoise_timesteps 16 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 13.269363403320312 +Calc FID for CFG 2.75 and denoise_timesteps 8 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 13.49327564239502 +Calc FID for CFG 2.75 and denoise_timesteps 4 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 17.548564910888672 +Calc FID for CFG 2.75 and denoise_timesteps 2 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 31.028226852416992 +Calc FID for CFG 2.75 and denoise_timesteps 1 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 54.34938049316406 +Calc FID for CFG 3.0 and denoise_timesteps 128 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 12.788758277893066 +Calc FID for CFG 3.0 and denoise_timesteps 64 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 14.978116989135742 +Calc FID for CFG 3.0 and denoise_timesteps 32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 14.65621566772461 +Calc FID for CFG 3.0 and denoise_timesteps 16 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 14.00898551940918 +Calc FID for CFG 3.0 and denoise_timesteps 8 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 14.094219207763672 +Calc FID for CFG 3.0 and denoise_timesteps 4 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 18.206756591796875 +Calc FID for CFG 3.0 and denoise_timesteps 2 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 32.331520080566406 +Calc FID for CFG 3.0 and denoise_timesteps 1 +DiT: Input of shape (256, 32, 32, 4) dtype float32 +DiT: After patch embed, shape is (256, 256, 768) dtype bfloat16 +DiT: Patch Embed of shape (256, 256, 768) dtype bfloat16 +DiT: Conditioning of shape (256, 768) dtype float32 +FID is 59.46251678466797