v3: patch_size=4 (64 tokens), 2 core layers, iters [2,3,4], ~16min total training
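The "64 tokens" in the title follows from the config change below: a 32×32 latent split into 4×4 patches gives (32/4)² = 64 patch tokens, versus the 256 tokens the cell comment attributes to the previous setup (which corresponds to 2×2 patches on the same grid). A minimal sketch of that arithmetic, under the assumption that the token count is derived this way; the helper below is illustrative only, and the notebook itself reads the value from config.num_patches:

    def num_patches(latent_spatial: int, patch_size: int) -> int:
        # Patch tokens for a square latent grid: (side / patch_size) squared.
        assert latent_spatial % patch_size == 0
        return (latent_spatial // patch_size) ** 2

    print(num_patches(32, 4))  # 64  -> the token count in this commit
    print(num_patches(32, 2))  # 256 -> the previous count noted in the cell comment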
IRIS_Training_Notebook.ipynb (changed)
@@ -282,6 +282,7 @@
  "metadata": {},
  "source": [
  "# Create IRIS-Tiny (best for free-tier)\n",
+ "# patch_size=4 reduces tokens from 256 to 64 \u2192 4\u00d7 faster training\n",
  "config = IRISConfig(\n",
  "    latent_channels=4,   # SD-VAE standard\n",
  "    latent_spatial=32,   # 256px / 8\n",
@@ -290,17 +291,17 @@
  "    head_dim=64,\n",
  "    ffn_ratio=2.667,\n",
  "    num_prelude_blocks=1,\n",
- "    num_core_layers=
+ "    num_core_layers=2,   # 2 layers (speed vs quality tradeoff for demo)\n",
  "    num_coda_blocks=1,\n",
- "    default_iterations=
+ "    default_iterations=4,\n",
  "    max_iterations=16,\n",
  "    fourier_num_blocks=6,\n",
  "    sparsity_threshold=0.01,\n",
  "    recurrence_dim=192,\n",
- "    manhattan_window=
+ "    manhattan_window=8,\n",
  "    text_dim=768,\n",
  "    max_text_tokens=77,\n",
- "    patch_size=
+ "    patch_size=4,   # 4\u00d7 larger patches \u2192 64 tokens instead of 256\n",
  ")\n",
  "\n",
  "iris = IRIS(config).to(device)\n",
@@ -309,7 +310,7 @@
  "\n",
  "print(f\"IRIS Generator: {gen_params:,} params ({gen_params*2/1024/1024:.1f} MB fp16)\")\n",
  "print(f\" Core (shared): {core_params:,} ({core_params/gen_params*100:.1f}%)\")\n",
- "print(f\"
+ "print(f\" Tokens: {config.num_patches} (from {config.latent_spatial}\u00d7{config.latent_spatial} latent, patch_size={config.patch_size})\")\n",
  "print(f\" Input: [B, 4, 32, 32] latent \u2192 Output: [B, 4, 32, 32] velocity\")"
  ],
  "outputs": [],
@@ -359,7 +360,7 @@
  "\n",
  "print(f\"Training for {EPOCHS} epochs ({total_steps} optimizer steps)\")\n",
  "print(f\"Batch: {BATCH_SIZE} \u00d7 {GRAD_ACCUM} accum = {BATCH_SIZE*GRAD_ACCUM} effective\")\n",
- "print(f\"Iterations per step: random from [
+ "print(f\"Iterations per step: random from [2, 3, 4]\")\n",
  "print()\n",
  "\n",
  "# \u2500\u2500\u2500 Training Loop \u2500\u2500\u2500\n",
@@ -379,7 +380,7 @@
  "    text_emb = text_emb.to(device, non_blocking=True)\n",
  "\n",
  "    with torch.amp.autocast('cuda', dtype=torch.float16):\n",
- "        r = [
+ "        r = [2, 3, 4][torch.randint(0, 3, (1,)).item()]\n",
  "        result = iris.train_step_latent(z_0, text_emb, num_iterations=r)\n",
  "        loss = result[\"loss\"] / GRAD_ACCUM\n",
  "\n",
@@ -431,7 +432,7 @@
  "\n",
  "iris.eval()\n",
  "fig, axes = plt.subplots(len(prompts), 4, figsize=(16, len(prompts)*4))\n",
- "iter_counts = [2,
+ "iter_counts = [2, 3, 4, 6]\n",
  "\n",
  "for row, prompt in enumerate(prompts):\n",
  "    text_emb = encode_text([prompt])\n",