asdf98 committed on
Commit
65af9bd
·
verified ·
1 Parent(s): eb07d9d

v3: patch_size=4 (64 tokens), 2 core layers, iters [2,3,4], ~16min total training

Browse files
Files changed (1) hide show
  1. IRIS_Training_Notebook.ipynb +9 -8
IRIS_Training_Notebook.ipynb CHANGED
@@ -282,6 +282,7 @@
282
  "metadata": {},
283
  "source": [
284
  "# Create IRIS-Tiny (best for free-tier)\n",
 
285
  "config = IRISConfig(\n",
286
  " latent_channels=4, # SD-VAE standard\n",
287
  " latent_spatial=32, # 256px / 8\n",
@@ -290,17 +291,17 @@
290
  " head_dim=64,\n",
291
  " ffn_ratio=2.667,\n",
292
  " num_prelude_blocks=1,\n",
293
- " num_core_layers=3,\n",
294
  " num_coda_blocks=1,\n",
295
- " default_iterations=6,\n",
296
  " max_iterations=16,\n",
297
  " fourier_num_blocks=6,\n",
298
  " sparsity_threshold=0.01,\n",
299
  " recurrence_dim=192,\n",
300
- " manhattan_window=12,\n",
301
  " text_dim=768,\n",
302
  " max_text_tokens=77,\n",
303
- " patch_size=2,\n",
304
  ")\n",
305
  "\n",
306
  "iris = IRIS(config).to(device)\n",
@@ -309,7 +310,7 @@
309
  "\n",
310
  "print(f\"IRIS Generator: {gen_params:,} params ({gen_params*2/1024/1024:.1f} MB fp16)\")\n",
311
  "print(f\" Core (shared): {core_params:,} ({core_params/gen_params*100:.1f}%)\")\n",
312
- "print(f\" Effective @r=6: ~{gen_params + 5*core_params:,} effective params\")\n",
313
  "print(f\" Input: [B, 4, 32, 32] latent \u2192 Output: [B, 4, 32, 32] velocity\")"
314
  ],
315
  "outputs": [],
@@ -359,7 +360,7 @@
359
  "\n",
360
  "print(f\"Training for {EPOCHS} epochs ({total_steps} optimizer steps)\")\n",
361
  "print(f\"Batch: {BATCH_SIZE} \u00d7 {GRAD_ACCUM} accum = {BATCH_SIZE*GRAD_ACCUM} effective\")\n",
362
- "print(f\"Iterations per step: random from [3, 4, 5]\")\n",
363
  "print()\n",
364
  "\n",
365
  "# \u2500\u2500\u2500 Training Loop \u2500\u2500\u2500\n",
@@ -379,7 +380,7 @@
379
  " text_emb = text_emb.to(device, non_blocking=True)\n",
380
  "\n",
381
  " with torch.amp.autocast('cuda', dtype=torch.float16):\n",
382
- " r = [3, 4, 5][torch.randint(0, 3, (1,)).item()]\n",
383
  " result = iris.train_step_latent(z_0, text_emb, num_iterations=r)\n",
384
  " loss = result[\"loss\"] / GRAD_ACCUM\n",
385
  "\n",
@@ -431,7 +432,7 @@
431
  "\n",
432
  "iris.eval()\n",
433
  "fig, axes = plt.subplots(len(prompts), 4, figsize=(16, len(prompts)*4))\n",
434
- "iter_counts = [2, 4, 6, 8]\n",
435
  "\n",
436
  "for row, prompt in enumerate(prompts):\n",
437
  " text_emb = encode_text([prompt])\n",
 
282
  "metadata": {},
283
  "source": [
284
  "# Create IRIS-Tiny (best for free-tier)\n",
285
+ "# patch_size=4 reduces tokens from 256 to 64 \u2192 4\u00d7 faster training\n",
286
  "config = IRISConfig(\n",
287
  " latent_channels=4, # SD-VAE standard\n",
288
  " latent_spatial=32, # 256px / 8\n",
 
291
  " head_dim=64,\n",
292
  " ffn_ratio=2.667,\n",
293
  " num_prelude_blocks=1,\n",
294
+ " num_core_layers=2, # 2 layers (speed vs quality tradeoff for demo)\n",
295
  " num_coda_blocks=1,\n",
296
+ " default_iterations=4,\n",
297
  " max_iterations=16,\n",
298
  " fourier_num_blocks=6,\n",
299
  " sparsity_threshold=0.01,\n",
300
  " recurrence_dim=192,\n",
301
+ " manhattan_window=8,\n",
302
  " text_dim=768,\n",
303
  " max_text_tokens=77,\n",
304
+ " patch_size=4, # 4\u00d7 larger patches \u2192 64 tokens instead of 256\n",
305
  ")\n",
306
  "\n",
307
  "iris = IRIS(config).to(device)\n",
 
310
  "\n",
311
  "print(f\"IRIS Generator: {gen_params:,} params ({gen_params*2/1024/1024:.1f} MB fp16)\")\n",
312
  "print(f\" Core (shared): {core_params:,} ({core_params/gen_params*100:.1f}%)\")\n",
313
+ "print(f\" Tokens: {config.num_patches} (from {config.latent_spatial}\u00d7{config.latent_spatial} latent, patch_size={config.patch_size})\")\n",
314
  "print(f\" Input: [B, 4, 32, 32] latent \u2192 Output: [B, 4, 32, 32] velocity\")"
315
  ],
316
  "outputs": [],
 
360
  "\n",
361
  "print(f\"Training for {EPOCHS} epochs ({total_steps} optimizer steps)\")\n",
362
  "print(f\"Batch: {BATCH_SIZE} \u00d7 {GRAD_ACCUM} accum = {BATCH_SIZE*GRAD_ACCUM} effective\")\n",
363
+ "print(f\"Iterations per step: random from [2, 3, 4]\")\n",
364
  "print()\n",
365
  "\n",
366
  "# \u2500\u2500\u2500 Training Loop \u2500\u2500\u2500\n",
 
380
  " text_emb = text_emb.to(device, non_blocking=True)\n",
381
  "\n",
382
  " with torch.amp.autocast('cuda', dtype=torch.float16):\n",
383
+ " r = [2, 3, 4][torch.randint(0, 3, (1,)).item()]\n",
384
  " result = iris.train_step_latent(z_0, text_emb, num_iterations=r)\n",
385
  " loss = result[\"loss\"] / GRAD_ACCUM\n",
386
  "\n",
 
432
  "\n",
433
  "iris.eval()\n",
434
  "fig, axes = plt.subplots(len(prompts), 4, figsize=(16, len(prompts)*4))\n",
435
+ "iter_counts = [2, 3, 4, 6]\n",
436
  "\n",
437
  "for row, prompt in enumerate(prompts):\n",
438
  " text_emb = encode_text([prompt])\n",