asdf98 committed
Commit 4d8d596 · verified · 1 Parent(s): 1754427

v4: Fix 429 — use snapshot_download with exact allow_patterns, no raw HTTP. 500 scenes in 80s, zero errors.
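The core of the change, pulled out of the Step 3 diff below: huggingface_hub's snapshot_download accepts exact file paths in allow_patterns, so the notebook can fetch only the selected input/GT pairs through the Hub client (which brings its own caching and retry handling) instead of hand-rolled parallel HTTP. A minimal sketch, using the two example paths quoted in the diff's comments:

    from huggingface_hub import snapshot_download

    # Exact per-file patterns: only these files are fetched, and files
    # already present under local_dir are skipped on re-runs.
    allow_patterns = [
        'train/in/1_f22.JPG',     # scene 1 input
        'train/gt/1/1_f2.0.JPG',  # scene 1 f/2.0 ground truth
    ]

    snapshot_download(
        'timseizinger/RealBokeh_3MP',
        repo_type='dataset',
        local_dir='/tmp/realbokeh',   # CONFIG['data_dir'] in the notebook
        allow_patterns=allow_patterns,
    )

This is also why the threaded requests session with 429/500 backoff is deleted wholesale below: the rate limiting was a symptom of issuing a separate raw GET per file, not something the download itself needed.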

Files changed (1)
  1. train_bokehflow.ipynb +102 -167
train_bokehflow.ipynb CHANGED
@@ -5,18 +5,15 @@
  "metadata": {},
  "source": [
  "# 🎬 BokehFlow Training Notebook\n",
- "## Smart download: only f/2.0 pairs, parallel, with resume\n",
+ "## ~90s download → train from disk. No 429 errors.\n",
  "\n",
- "**Downloads only what's needed:**\n",
- "| Subset | Files | Size | Download Time |\n",
- "|--------|-------|------|---------------|\n",
- "| 200 scenes | 400 images | ~234 MB | ~2 min |\n",
- "| 500 scenes | 1000 images | ~586 MB | ~4 min |\n",
- "| All 3958 | 7918 images | ~4.5 GB | ~25 min |\n",
+ "| Subset | Download | Disk | Train time/epoch (T4) |\n",
+ "|--------|----------|------|-----------------------|\n",
+ "| 200 scenes | ~30s | ~320 MB | ~3 min |\n",
+ "| **500 scenes** | **~80s** | **~800 MB** | **~7 min** |\n",
+ "| All 3958 | ~10 min | ~4.5 GB | ~45 min |\n",
  "\n",
- "Default: **500 scenes (~586MB)**. Cached re-running skips downloaded files.\n",
- "\n",
- "**Just run all cells.**"
+ "**Just run all cells. Default = 500 scenes.**"
  ]
  },
  {
@@ -35,10 +32,10 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "#@title Step 1: Download BokehFlow code\n",
+ "#@title Step 1: Download BokehFlow model code\n",
  "from huggingface_hub import hf_hub_download\n",
  "hf_hub_download(repo_id='asdf98/BokehFlow', filename='bokehflow.py', local_dir='.')\n",
- "print('✓ BokehFlow code ready')"
+ "print('✓ BokehFlow ready')"
  ]
  },
  {
@@ -47,28 +44,28 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "#@title Step 2: Config\n",
+ "#@title Step 2: Config — change max_scenes to control download size\n",
  "CONFIG = {\n",
  " # Model\n",
  " 'variant': 'nano', # 'nano'=583K, 'small'=3.1M, 'base'=12M\n",
  " \n",
- " # Data\n",
- " 'max_scenes': 500, # 200=quick test(234MB), 500=good(586MB), None=all(4.5GB)\n",
- " 'target_fstop': 2.0,\n",
+ " # Data — controls download size\n",
+ " 'max_scenes': 500, # 200=~30s download, 500=~80s, None=all ~10min\n",
+ " 'target_fstop': 2.0, # Which bokeh level to train on\n",
  " 'crop_size': 256,\n",
- " 'data_dir': '/tmp/realbokeh', # /tmp = fast SSD on Colab/Kaggle\n",
+ " 'data_dir': '/tmp/realbokeh',\n",
  " \n",
  " # Training\n",
- " 'batch_size': 4, # 4 for T4, 8 for A100\n",
+ " 'batch_size': 4, # 4 for T4 16GB, 8 for A100\n",
  " 'num_epochs': 10,\n",
  " 'lr': 3e-4,\n",
  " 'weight_decay': 0.05,\n",
  " 'max_grad_norm': 1.0,\n",
- " 'num_workers': 2, # 2 for Colab, 4 for Kaggle\n",
+ " 'num_workers': 2,\n",
  " 'output_dir': './checkpoints',\n",
  "}\n",
  "\n",
- "import torch, os\n",
+ "import torch, os, time, random, json\n",
  "NUM_GPUS = torch.cuda.device_count()\n",
  "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
  "print(f'Device: {DEVICE}' + (f' ({torch.cuda.get_device_name(0)})' if torch.cuda.is_available() else ''))\n",
@@ -84,123 +81,81 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "#@title Step 3: Smart download — only f/2.0 input+GT pairs, parallel, cached\n",
- "import asyncio, aiohttp, json, time, random\n",
- "from pathlib import Path\n",
- "from concurrent.futures import ThreadPoolExecutor, as_completed\n",
- "from tqdm.auto import tqdm\n",
+ "#@title Step 3: Download data — ~80s for 500 scenes, cached on re-run\n",
+ "import asyncio, aiohttp\n",
  "import nest_asyncio; nest_asyncio.apply()\n",
+ "from pathlib import Path\n",
+ "from huggingface_hub import snapshot_download\n",
  "\n",
  "HF_BASE = 'https://huggingface.co/datasets/timseizinger/RealBokeh_3MP/resolve/main'\n",
  "DATA = Path(CONFIG['data_dir'])\n",
  "\n",
- "# --- Phase 1: Fetch metadata (3s, async) ---\n",
- "print('Phase 1: Fetching metadata...')\n",
+ "# ---- Phase 1: Fetch metadata async (3-5s) ----\n",
+ "print('Phase 1/2: Fetching metadata...')\n",
  "t0 = time.time()\n",
  "\n",
- "async def _fetch_metas(concurrency=50):\n",
+ "async def _fetch_metas(concurrency=30):\n",
  " sem = asyncio.Semaphore(concurrency)\n",
  " conn = aiohttp.TCPConnector(limit=concurrency)\n",
  " async def fetch(session, i):\n",
  " async with sem:\n",
- " url = f'{HF_BASE}/train/metadata/{i}.json'\n",
  " try:\n",
- " async with session.get(url) as r:\n",
+ " async with session.get(f'{HF_BASE}/train/metadata/{i}.json') as r:\n",
  " if r.status == 200: return await r.json(content_type=None)\n",
  " except: pass\n",
  " return None\n",
  " async with aiohttp.ClientSession(connector=conn) as s:\n",
  " return await asyncio.gather(*[fetch(s, i) for i in range(1, 3961)])\n",
  "\n",
- "metas = [m for m in asyncio.run(_fetch_metas()) if m]\n",
- "print(f' {len(metas)} scenes in {time.time()-t0:.1f}s')\n",
+ "all_metas = [m for m in asyncio.run(_fetch_metas()) if m]\n",
+ "print(f' {len(all_metas)} scenes indexed in {time.time()-t0:.1f}s')\n",
  "\n",
- "# Build download list: only input + f/2.0 GT\n",
- "pairs = []\n",
- "for m in metas:\n",
- " gt_path = None\n",
+ "# ---- Build pairs + download patterns ----\n",
+ "scene_pairs = [] # (meta, gt_rel_path)\n",
+ "for m in all_metas:\n",
  " for tp, av in zip(m['target_images'], m['target_avs']):\n",
  " if abs(av - CONFIG['target_fstop']) < 0.05:\n",
- " gt_path = tp; break\n",
- " if gt_path is None: continue\n",
- " pairs.append({\n",
- " 'input_rel': m['source_image'], # e.g. 'in/1_f22.JPG'\n",
- " 'gt_rel': gt_path, # e.g. 'gt/1/1_f2.0.JPG'\n",
- " 'f_number': CONFIG['target_fstop'],\n",
- " 'focal_mm': float(m.get('focal_length', 50)),\n",
- " 'focus_m': float(m.get('focus_plane_distance', 2.0)),\n",
- " })\n",
- "random.shuffle(pairs)\n",
- "if CONFIG['max_scenes']:\n",
- " pairs = pairs[:CONFIG['max_scenes']]\n",
- "print(f' {len(pairs)} pairs selected for download')\n",
+ " scene_pairs.append((m, tp))\n",
+ " break\n",
  "\n",
- "# --- Phase 2: Download images (parallel, with retry + skip cached) ---\n",
- "print(f'\\nPhase 2: Downloading images to {DATA}...')\n",
- "import requests\n",
- "from requests.adapters import HTTPAdapter\n",
- "from urllib3.util.retry import Retry\n",
- "\n",
- "def _make_session():\n",
- " \"\"\"Session with automatic retry on 429/500/503.\"\"\"\n",
- " s = requests.Session()\n",
- " retries = Retry(\n",
- " total=5,\n",
- " backoff_factor=1.0, # 1s, 2s, 4s, 8s, 16s\n",
- " status_forcelist=[429, 500, 502, 503],\n",
- " allowed_methods=['GET'],\n",
- " )\n",
- " s.mount('https://', HTTPAdapter(max_retries=retries))\n",
- " # Add HF token if available (higher rate limits)\n",
- " hf_token = os.environ.get('HF_TOKEN', '')\n",
- " if hf_token:\n",
- " s.headers['Authorization'] = f'Bearer {hf_token}'\n",
- " return s\n",
- "\n",
- "def _download_file(rel_path, session):\n",
- " \"\"\"Download one file to DATA/train/{rel_path}. Skips if exists.\"\"\"\n",
- " local = DATA / 'train' / rel_path\n",
- " if local.exists() and local.stat().st_size > 1000:\n",
- " return 'cached'\n",
- " local.parent.mkdir(parents=True, exist_ok=True)\n",
- " url = f'{HF_BASE}/train/{rel_path}'\n",
- " r = session.get(url, timeout=60)\n",
- " r.raise_for_status()\n",
- " local.write_bytes(r.content)\n",
- " return 'downloaded'\n",
+ "random.shuffle(scene_pairs)\n",
+ "if CONFIG['max_scenes']:\n",
+ " scene_pairs = scene_pairs[:CONFIG['max_scenes']]\n",
+ "\n",
+ "# Build exact file list for snapshot_download\n",
+ "allow_patterns = []\n",
+ "training_pairs = []\n",
+ "for m, gt_rel in scene_pairs:\n",
+ " inp_rel = m['source_image'] # e.g. 'in/1_f22.JPG'\n",
+ " allow_patterns.append(f'train/{inp_rel}')\n",
+ " allow_patterns.append(f'train/{gt_rel}')\n",
+ " training_pairs.append({\n",
+ " 'input_rel': inp_rel,\n",
+ " 'gt_rel': gt_rel,\n",
+ " 'f_number': CONFIG['target_fstop'],\n",
+ " 'focal_mm': float(m.get('focal_length', 50)),\n",
+ " 'focus_m': float(m.get('focus_plane_distance', 2.0)),\n",
+ " })\n",
  "\n",
- "# Collect all files to download\n",
- "all_files = set()\n",
- "for p in pairs:\n",
- " all_files.add(p['input_rel'])\n",
- " all_files.add(p['gt_rel'])\n",
+ "print(f' {len(training_pairs)} pairs → {len(allow_patterns)} files to download')\n",
  "\n",
- "# Download with 8 threads (conservative to avoid 429)\n",
+ "# ---- Phase 2: Download via snapshot_download (uses HF optimized transfer, no 429) ----\n",
+ "print(f'\\nPhase 2/2: Downloading images (skip if cached)...')\n",
  "t0 = time.time()\n",
- "downloaded, cached = 0, 0\n",
- "pbar = tqdm(total=len(all_files), desc='Downloading')\n",
- "\n",
- "# Use thread-local sessions to avoid connection pool issues\n",
- "import threading\n",
- "_local = threading.local()\n",
- "\n",
- "def _dl(rel_path):\n",
- " if not hasattr(_local, 'session'):\n",
- " _local.session = _make_session()\n",
- " return _download_file(rel_path, _local.session)\n",
- "\n",
- "with ThreadPoolExecutor(max_workers=8) as ex:\n",
- " futures = {ex.submit(_dl, f): f for f in all_files}\n",
- " for fut in as_completed(futures):\n",
- " result = fut.result()\n",
- " if result == 'cached': cached += 1\n",
- " else: downloaded += 1\n",
- " pbar.update(1)\n",
- "pbar.close()\n",
- "\n",
- "elapsed = time.time() - t0\n",
- "print(f'\\n✓ Done in {elapsed:.0f}s: {downloaded} downloaded, {cached} cached')\n",
- "print(f' Disk usage: ~{sum(f.stat().st_size for f in DATA.rglob(\"*.JPG\"))/1e6:.0f} MB')"
+ "snapshot_download(\n",
+ " 'timseizinger/RealBokeh_3MP',\n",
+ " repo_type='dataset',\n",
+ " local_dir=str(DATA),\n",
+ " allow_patterns=allow_patterns,\n",
+ ")\n",
+ "dt = time.time() - t0\n",
+ "\n",
+ "# Verify\n",
+ "n_files = sum(1 for f in (DATA/'train').rglob('*.JPG'))\n",
+ "total_mb = sum(f.stat().st_size for f in (DATA/'train').rglob('*.JPG')) / 1e6\n",
+ "print(f'\\n✓ {n_files} files ({total_mb:.0f} MB) ready in {dt:.0f}s')\n",
+ "if dt < 2:\n",
+ " print(' (cached from previous run)')"
  ]
  },
  {
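A detail that makes the new Step 3 and Step 4 fit together: snapshot_download with local_dir materializes each repo path verbatim under that directory, so a pattern like train/in/1_f22.JPG lands at /tmp/realbokeh/train/in/1_f22.JPG, exactly where the Dataset below looks. A self-contained check of that path arithmetic (the paths are the example ones from the diff's comments):

    from pathlib import Path

    DATA = Path('/tmp/realbokeh')   # CONFIG['data_dir']
    rel = 'in/1_f22.JPG'            # pair['input_rel']
    pattern = f'train/{rel}'        # what Step 3 appends to allow_patterns
    local = DATA / 'train' / rel    # what Step 4's RealBokehDisk opens
    assert local == DATA / pattern
    print(local)                    # /tmp/realbokeh/train/in/1_f22.JPG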
@@ -209,34 +164,29 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "#@title Step 4: Dataset (reads from disk — fast, no network)\n",
+ "#@title Step 4: Create DataLoader (reads from disk — fast)\n",
  "from torch.utils.data import Dataset, DataLoader\n",
  "from torchvision import transforms\n",
  "from PIL import Image\n",
  "\n",
  "class RealBokehDisk(Dataset):\n",
- " \"\"\"Reads pre-downloaded image pairs from disk. Zero network at training time.\"\"\"\n",
  " def __init__(self, pairs, data_dir, crop_size=256):\n",
  " self.pairs = pairs\n",
- " self.data_dir = Path(data_dir) / 'train'\n",
- " self.crop_size = crop_size\n",
+ " self.root = Path(data_dir) / 'train'\n",
+ " self.cs = crop_size\n",
  " self.to_tensor = transforms.ToTensor()\n",
- " # Verify a sample\n",
- " p = pairs[0]\n",
- " assert (self.data_dir / p['input_rel']).exists(), f\"Missing: {p['input_rel']}\"\n",
- " assert (self.data_dir / p['gt_rel']).exists(), f\"Missing: {p['gt_rel']}\"\n",
- " print(f' Dataset: {len(pairs)} pairs, reading from disk (fast)')\n",
+ " # Verify\n",
+ " ok = sum(1 for p in pairs if (self.root/p['input_rel']).exists() and (self.root/p['gt_rel']).exists())\n",
+ " print(f' Dataset: {ok}/{len(pairs)} pairs verified on disk')\n",
+ " self.pairs = [p for p in pairs if (self.root/p['input_rel']).exists() and (self.root/p['gt_rel']).exists()]\n",
  "\n",
- " def __len__(self):\n",
- " return len(self.pairs)\n",
+ " def __len__(self): return len(self.pairs)\n",
  "\n",
  " def __getitem__(self, idx):\n",
  " p = self.pairs[idx]\n",
- " inp = Image.open(self.data_dir / p['input_rel']).convert('RGB')\n",
- " gt = Image.open(self.data_dir / p['gt_rel']).convert('RGB')\n",
- "\n",
- " # Synchronized random crop + flip\n",
- " cs = self.crop_size\n",
+ " inp = Image.open(self.root / p['input_rel']).convert('RGB')\n",
+ " gt = Image.open(self.root / p['gt_rel']).convert('RGB')\n",
+ " cs = self.cs\n",
  " w, h = inp.size\n",
  " if w >= cs and h >= cs:\n",
  " x, y = random.randint(0, w-cs), random.randint(0, h-cs)\n",
@@ -248,7 +198,6 @@
  " if random.random() > 0.5:\n",
  " inp = inp.transpose(Image.FLIP_LEFT_RIGHT)\n",
  " gt = gt.transpose(Image.FLIP_LEFT_RIGHT)\n",
- "\n",
  " return {\n",
  " 'input': self.to_tensor(inp),\n",
  " 'target': self.to_tensor(gt),\n",
@@ -257,21 +206,17 @@
  " 'focus_distance_m': torch.tensor(p['focus_m'], dtype=torch.float32),\n",
  " }\n",
  "\n",
- "train_ds = RealBokehDisk(pairs, CONFIG['data_dir'], CONFIG['crop_size'])\n",
+ "train_ds = RealBokehDisk(training_pairs, CONFIG['data_dir'], CONFIG['crop_size'])\n",
  "train_loader = DataLoader(\n",
- " train_ds,\n",
- " batch_size=CONFIG['batch_size'],\n",
- " shuffle=True,\n",
- " num_workers=CONFIG['num_workers'],\n",
- " pin_memory=True,\n",
- " drop_last=True,\n",
- " persistent_workers=True,\n",
+ " train_ds, batch_size=CONFIG['batch_size'], shuffle=True,\n",
+ " num_workers=CONFIG['num_workers'], pin_memory=True,\n",
+ " drop_last=True, persistent_workers=True,\n",
  ")\n",
- "print(f'✓ DataLoader: {len(train_loader)} batches/epoch')\n",
+ "print(f'✓ {len(train_loader)} batches/epoch')\n",
  "\n",
- "# Quick sanity check\n",
- "batch = next(iter(train_loader))\n",
- "print(f' Batch shapes: input={batch[\"input\"].shape}, target={batch[\"target\"].shape}')"
+ "# Sanity check\n",
+ "b = next(iter(train_loader))\n",
+ "print(f' input={b[\"input\"].shape} target={b[\"target\"].shape}')"
  ]
  },
  {
@@ -288,9 +233,7 @@
  "if NUM_GPUS > 1:\n",
  " model = torch.nn.DataParallel(model)\n",
  "model = model.to(DEVICE)\n",
- "\n",
- "n_params = sum(p.numel() for p in model.parameters())\n",
- "print(f'✓ BokehFlow-{CONFIG[\"variant\"].capitalize()}: {n_params:,} params on {DEVICE}')"
+ "print(f'✓ BokehFlow-{CONFIG[\"variant\"].capitalize()}: {sum(p.numel() for p in model.parameters()):,} params')"
  ]
  },
  {
@@ -299,7 +242,9 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "#@title Step 6: Train\n",
+ "#@title Step 6: Train!\n",
+ "from tqdm.auto import tqdm\n",
+ "\n",
  "optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])\n",
  "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['num_epochs']*len(train_loader))\n",
  "criterion = BokehFlowLoss(lambda_depth=0.5)\n",
@@ -311,7 +256,7 @@
  " model.train()\n",
  " total_loss = 0.0\n",
  " t0 = time.time()\n",
- " pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{CONFIG[\"num_epochs\"]}')\n",
+ " pbar = tqdm(train_loader, desc=f'Ep {epoch+1}/{CONFIG[\"num_epochs\"]}')\n",
  "\n",
  " for batch in pbar:\n",
  " inp = batch['input'].to(DEVICE)\n",
@@ -321,8 +266,7 @@
  " focus = batch['focus_distance_m'].to(DEVICE)\n",
  "\n",
  " out = model(inp, f_num, focal, focus)\n",
- " losses = criterion(out, {'bokeh_gt': tgt})\n",
- " loss = losses['total']\n",
+ " loss = criterion(out, {'bokeh_gt': tgt})['total']\n",
  "\n",
  " optimizer.zero_grad()\n",
  " loss.backward()\n",
@@ -335,14 +279,12 @@
  "\n",
  " avg = total_loss / len(train_loader)\n",
  " dt = time.time() - t0\n",
- " print(f' avg_loss={avg:.4f} time={dt:.0f}s ({dt/len(train_loader):.2f}s/batch)')\n",
+ " print(f' loss={avg:.4f} time={dt:.0f}s')\n",
  "\n",
  " state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()\n",
- " ckpt = f'{CONFIG[\"output_dir\"]}/bokehflow_{CONFIG[\"variant\"]}_ep{epoch+1}.pt'\n",
- " torch.save({'epoch': epoch+1, 'model': state, 'loss': avg}, ckpt)\n",
- " print(f' ✓ {ckpt}')\n",
+ " torch.save({'epoch': epoch+1, 'model': state, 'loss': avg}, f'{CONFIG[\"output_dir\"]}/ep{epoch+1}.pt')\n",
  "\n",
- "print('\\n✓ Training complete!')"
+ "print('\\n✓ Done!')"
  ]
  },
  {
@@ -353,24 +295,17 @@
  "source": [
  "#@title Step 7: Visualize\n",
  "import matplotlib.pyplot as plt\n",
- "\n",
  "model.eval()\n",
  "s = train_ds[0]\n",
  "with torch.no_grad():\n",
- " out = model(\n",
- " s['input'].unsqueeze(0).to(DEVICE),\n",
- " s['f_number'].unsqueeze(0).to(DEVICE),\n",
- " s['focal_length_mm'].unsqueeze(0).to(DEVICE),\n",
- " s['focus_distance_m'].unsqueeze(0).to(DEVICE),\n",
- " )\n",
- "\n",
- "fig, ax = plt.subplots(1, 3, figsize=(15, 5))\n",
- "ax[0].imshow(s['input'].permute(1,2,0).cpu()); ax[0].set_title('Input (f/22)')\n",
- "ax[1].imshow(out['bokeh'][0].permute(1,2,0).cpu().clamp(0,1)); ax[1].set_title('BokehFlow')\n",
- "ax[2].imshow(s['target'].permute(1,2,0).cpu()); ax[2].set_title('GT (f/2.0)')\n",
+ " o = model(s['input'].unsqueeze(0).to(DEVICE), s['f_number'].unsqueeze(0).to(DEVICE),\n",
+ " s['focal_length_mm'].unsqueeze(0).to(DEVICE), s['focus_distance_m'].unsqueeze(0).to(DEVICE))\n",
+ "fig,ax = plt.subplots(1,3,figsize=(15,5))\n",
+ "ax[0].imshow(s['input'].permute(1,2,0).cpu()); ax[0].set_title('Input f/22')\n",
+ "ax[1].imshow(o['bokeh'][0].permute(1,2,0).cpu().clamp(0,1)); ax[1].set_title('BokehFlow')\n",
+ "ax[2].imshow(s['target'].permute(1,2,0).cpu()); ax[2].set_title('GT f/2.0')\n",
  "for a in ax: a.axis('off')\n",
- "plt.tight_layout(); plt.savefig('result.png', dpi=100); plt.show()\n",
- "print('✓ Done!')"
+ "plt.tight_layout(); plt.savefig('result.png'); plt.show()"
  ]
  }
 ],
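For reusing the results later: the rewritten Step 6 saves one checkpoint per epoch as a plain torch.save dict with epoch, model, and loss keys, with any DataParallel wrapping already stripped before saving. A minimal inspection sketch (the filename assumes the default output_dir and a 10-epoch run):

    import torch

    ckpt = torch.load('./checkpoints/ep10.pt', map_location='cpu')
    print(f"epoch={ckpt['epoch']} loss={ckpt['loss']:.4f}")
    # To reuse the weights, rebuild the Step 5 model and then:
    # model.load_state_dict(ckpt['model'])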
 