{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# šŸŽ¬ BokehFlow Training Notebook\n", "## ~90s download → train from disk. No 429 errors.\n", "\n", "| Subset | Download | Disk | Train time/epoch (T4) |\n", "|--------|----------|------|-----------------------|\n", "| 200 scenes | ~30s | ~320 MB | ~3 min |\n", "| **500 scenes** | **~80s** | **~800 MB** | **~7 min** |\n", "| All 3958 | ~10 min | ~4.5 GB | ~45 min |\n", "\n", "**Just run all cells. Default = 500 scenes.**" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title Step 0: Install\n", "!pip install -q torch torchvision Pillow huggingface_hub tqdm aiohttp nest_asyncio" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title Step 1: Download BokehFlow model code\n", "from huggingface_hub import hf_hub_download\n", "hf_hub_download(repo_id='asdf98/BokehFlow', filename='bokehflow.py', local_dir='.')\n", "print('āœ“ BokehFlow ready')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title Step 2: Config — change max_scenes to control download size\n", "CONFIG = {\n", " # Model\n", " 'variant': 'nano', # 'nano'=583K, 'small'=3.1M, 'base'=12M\n", " \n", " # Data — controls download size\n", " 'max_scenes': 500, # 200=~30s download, 500=~80s, None=all ~10min\n", " 'target_fstop': 2.0, # Which bokeh level to train on\n", " 'crop_size': 256,\n", " 'data_dir': '/tmp/realbokeh',\n", " \n", " # Training\n", " 'batch_size': 4, # 4 for T4 16GB, 8 for A100\n", " 'num_epochs': 10,\n", " 'lr': 3e-4,\n", " 'weight_decay': 0.05,\n", " 'max_grad_norm': 1.0,\n", " 'num_workers': 2,\n", " 'output_dir': './checkpoints',\n", "}\n", "\n", "import torch, os, time, random, json\n", "NUM_GPUS = torch.cuda.device_count()\n", "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(f'Device: {DEVICE}' + (f' ({torch.cuda.get_device_name(0)})' if 
torch.cuda.is_available() else ''))\n", "if NUM_GPUS > 1:\n", " CONFIG['num_workers'] = 4\n", " CONFIG['batch_size'] = 8\n", " print(f'Multi-GPU: {NUM_GPUS} GPUs')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title Step 3: Download data — ~80s for 500 scenes, cached on re-run\n", "import asyncio, aiohttp\n", "import nest_asyncio; nest_asyncio.apply()\n", "from pathlib import Path\n", "from huggingface_hub import snapshot_download\n", "\n", "HF_BASE = 'https://huggingface.co/datasets/timseizinger/RealBokeh_3MP/resolve/main'\n", "DATA = Path(CONFIG['data_dir'])\n", "\n", "# ---- Phase 1: Fetch metadata async (3-5s) ----\n", "print('Phase 1/2: Fetching metadata...')\n", "t0 = time.time()\n", "\n", "async def _fetch_metas(concurrency=30):\n", " sem = asyncio.Semaphore(concurrency)\n", " conn = aiohttp.TCPConnector(limit=concurrency)\n", " async def fetch(session, i):\n", " async with sem:\n", " try:\n", " async with session.get(f'{HF_BASE}/train/metadata/{i}.json') as r:\n", " if r.status == 200: return await r.json(content_type=None)\n", " except Exception: pass # best-effort: a bare except here would also swallow asyncio.CancelledError\n", " return None\n", " async with aiohttp.ClientSession(connector=conn) as s:\n", " return await asyncio.gather(*[fetch(s, i) for i in range(1, 3961)])\n", "\n", "all_metas = [m for m in asyncio.run(_fetch_metas()) if m]\n", "print(f' {len(all_metas)} scenes indexed in {time.time()-t0:.1f}s')\n", "\n", "# ---- Build pairs + download patterns ----\n", "scene_pairs = [] # (meta, gt_rel_path)\n", "for m in all_metas:\n", " for tp, av in zip(m['target_images'], m['target_avs']):\n", " if abs(av - CONFIG['target_fstop']) < 0.05:\n", " scene_pairs.append((m, tp))\n", " break\n", "\n", "random.shuffle(scene_pairs)\n", "if CONFIG['max_scenes']:\n", " scene_pairs = scene_pairs[:CONFIG['max_scenes']]\n", "\n", "# Build exact file list for snapshot_download\n", "allow_patterns = []\n", "training_pairs = []\n", "for m, gt_rel in scene_pairs:\n", " inp_rel = m['source_image'] # 
e.g. 'in/1_f22.JPG'\n", " allow_patterns.append(f'train/{inp_rel}')\n", " allow_patterns.append(f'train/{gt_rel}')\n", " training_pairs.append({\n", " 'input_rel': inp_rel,\n", " 'gt_rel': gt_rel,\n", " 'f_number': CONFIG['target_fstop'],\n", " 'focal_mm': float(m.get('focal_length', 50)),\n", " 'focus_m': float(m.get('focus_plane_distance', 2.0)),\n", " })\n", "\n", "print(f' {len(training_pairs)} pairs → {len(allow_patterns)} files to download')\n", "\n", "# ---- Phase 2: Download via snapshot_download (uses HF optimized transfer, no 429) ----\n", "print(f'\\nPhase 2/2: Downloading images (skip if cached)...')\n", "t0 = time.time()\n", "snapshot_download(\n", " 'timseizinger/RealBokeh_3MP',\n", " repo_type='dataset',\n", " local_dir=str(DATA),\n", " allow_patterns=allow_patterns,\n", ")\n", "dt = time.time() - t0\n", "\n", "# Verify — walk the download tree once, not twice\n", "jpg_files = list((DATA/'train').rglob('*.JPG'))\n", "n_files = len(jpg_files)\n", "total_mb = sum(f.stat().st_size for f in jpg_files) / 1e6\n", "print(f'\\nāœ“ {n_files} files ({total_mb:.0f} MB) ready in {dt:.0f}s')\n", "if dt < 2:\n", " print(' (cached from previous run)')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title Step 4: Create DataLoader (reads from disk — fast)\n", "from torch.utils.data import Dataset, DataLoader\n", "from torchvision import transforms\n", "from PIL import Image\n", "\n", "class RealBokehDisk(Dataset):\n", " def __init__(self, pairs, data_dir, crop_size=256):\n", " self.root = Path(data_dir) / 'train'\n", " self.cs = crop_size\n", " self.to_tensor = transforms.ToTensor()\n", " # Verify once: keep only pairs whose input and GT files both exist on disk\n", " self.pairs = [p for p in pairs if (self.root/p['input_rel']).exists() and (self.root/p['gt_rel']).exists()]\n", " print(f' Dataset: {len(self.pairs)}/{len(pairs)} pairs verified on disk')\n", "\n", " def __len__(self): return 
len(self.pairs)\n", "\n", " def __getitem__(self, idx):\n", " # Returns one sample: pixel-aligned (input, target) crops plus lens-metadata scalars.\n", " p = self.pairs[idx]\n", " inp = Image.open(self.root / p['input_rel']).convert('RGB')\n", " gt = Image.open(self.root / p['gt_rel']).convert('RGB')\n", " cs = self.cs\n", " w, h = inp.size\n", " if w >= cs and h >= cs:\n", " # One random crop window, applied to BOTH images so the pair stays aligned\n", " x, y = random.randint(0, w-cs), random.randint(0, h-cs)\n", " inp = inp.crop((x, y, x+cs, y+cs))\n", " gt = gt.crop((x, y, x+cs, y+cs))\n", " else:\n", " # Image smaller than the crop in some dim: resize to a square (may distort aspect ratio)\n", " inp = inp.resize((cs, cs), Image.LANCZOS)\n", " gt = gt.resize((cs, cs), Image.LANCZOS)\n", " if random.random() > 0.5:\n", " # Horizontal-flip augmentation, identical for input and GT\n", " inp = inp.transpose(Image.FLIP_LEFT_RIGHT)\n", " gt = gt.transpose(Image.FLIP_LEFT_RIGHT)\n", " return {\n", " 'input': self.to_tensor(inp),\n", " 'target': self.to_tensor(gt),\n", " 'f_number': torch.tensor(p['f_number'], dtype=torch.float32),\n", " 'focal_length_mm': torch.tensor(p['focal_mm'], dtype=torch.float32),\n", " 'focus_distance_m': torch.tensor(p['focus_m'], dtype=torch.float32),\n", " }\n", "\n", "train_ds = RealBokehDisk(training_pairs, CONFIG['data_dir'], CONFIG['crop_size'])\n", "# persistent_workers keeps worker processes alive across epochs (requires num_workers > 0)\n", "train_loader = DataLoader(\n", " train_ds, batch_size=CONFIG['batch_size'], shuffle=True,\n", " num_workers=CONFIG['num_workers'], pin_memory=True,\n", " drop_last=True, persistent_workers=True,\n", ")\n", "print(f'āœ“ {len(train_loader)} batches/epoch')\n", "\n", "# Sanity check\n", "b = next(iter(train_loader))\n", "print(f' input={b[\"input\"].shape} target={b[\"target\"].shape}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title Step 5: Create model\n", "from bokehflow import BokehFlow, BokehFlowConfig, BokehFlowLoss\n", "\n", "config = BokehFlowConfig(variant=CONFIG['variant'])\n", "model = BokehFlow(config)\n", "if NUM_GPUS > 1:\n", " model = torch.nn.DataParallel(model)\n", "model = model.to(DEVICE)\n", "print(f'āœ“ BokehFlow-{CONFIG[\"variant\"].capitalize()}: {sum(p.numel() for p in model.parameters()):,} params')" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title Step 6: Train!\n", "from tqdm.auto import tqdm\n", "\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])\n", "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['num_epochs']*len(train_loader))\n", "criterion = BokehFlowLoss(lambda_depth=0.5)\n", "os.makedirs(CONFIG['output_dir'], exist_ok=True)\n", "\n", "print(f'Training: {CONFIG[\"num_epochs\"]} epochs Ɨ {len(train_loader)} batches\\n')\n", "\n", "for epoch in range(CONFIG['num_epochs']):\n", " model.train()\n", " total_loss = 0.0\n", " t0 = time.time()\n", " pbar = tqdm(train_loader, desc=f'Ep {epoch+1}/{CONFIG[\"num_epochs\"]}')\n", "\n", " for batch in pbar:\n", " inp = batch['input'].to(DEVICE)\n", " tgt = batch['target'].to(DEVICE)\n", " f_num = batch['f_number'].to(DEVICE)\n", " focal = batch['focal_length_mm'].to(DEVICE)\n", " focus = batch['focus_distance_m'].to(DEVICE)\n", "\n", " out = model(inp, f_num, focal, focus)\n", " loss = criterion(out, {'bokeh_gt': tgt})['total']\n", "\n", " optimizer.zero_grad()\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])\n", " optimizer.step()\n", " scheduler.step()\n", "\n", " total_loss += loss.item()\n", " pbar.set_postfix(loss=f'{loss.item():.4f}', lr=f'{scheduler.get_last_lr()[0]:.1e}')\n", "\n", " avg = total_loss / len(train_loader)\n", " dt = time.time() - t0\n", " print(f' loss={avg:.4f} time={dt:.0f}s')\n", "\n", " state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()\n", " torch.save({'epoch': epoch+1, 'model': state, 'loss': avg}, f'{CONFIG[\"output_dir\"]}/ep{epoch+1}.pt')\n", "\n", "print('\\nāœ“ Done!')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#@title Step 7: Visualize\n", "import matplotlib.pyplot as plt\n", "model.eval()\n", "s = train_ds[0]\n", "with 
torch.no_grad():\n", " o = model(s['input'].unsqueeze(0).to(DEVICE), s['f_number'].unsqueeze(0).to(DEVICE),\n", " s['focal_length_mm'].unsqueeze(0).to(DEVICE), s['focus_distance_m'].unsqueeze(0).to(DEVICE))\n", "fig,ax = plt.subplots(1,3,figsize=(15,5))\n", "ax[0].imshow(s['input'].permute(1,2,0).cpu()); ax[0].set_title('Input f/22')\n", "ax[1].imshow(o['bokeh'][0].permute(1,2,0).cpu().clamp(0,1)); ax[1].set_title('BokehFlow')\n", "# GT title follows the configured target f-stop instead of a hardcoded f/2.0\n", "ax[2].imshow(s['target'].permute(1,2,0).cpu()); ax[2].set_title(f'GT f/{CONFIG[\"target_fstop\"]}')\n", "for a in ax: a.axis('off')\n", "plt.tight_layout(); plt.savefig('result.png'); plt.show()" ] } ], "metadata": { "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"name": "python", "version": "3.10.0"}, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 4 }