asdf98 committed
Commit 800dcfc · verified · Parent(s): 373e0ae

Add Colab/Kaggle training notebook — just run all cells

Files changed (1)
  1. train_bokehflow.ipynb +414 -0
train_bokehflow.ipynb ADDED
@@ -0,0 +1,414 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🎬 BokehFlow Training Notebook\n",
+ "## Train on Free Colab T4 or Kaggle Dual-GPU\n",
+ "\n",
+ "**Just run all cells.** The default config trains BokehFlow-Nano on the RealBokeh dataset.\n",
+ "\n",
+ "| Platform | GPU | VRAM | Expected Time (1 epoch) |\n",
+ "|----------|-----|------|------------------------|\n",
+ "| Colab Free | T4 | 16GB | ~45 min |\n",
+ "| Kaggle | 2×T4 | 2×16GB | ~25 min |\n",
+ "| Colab Pro | A100 | 40GB | ~10 min |"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ============================================================\n",
+ "# STEP 0: Install dependencies (matplotlib is used in STEP 7)\n",
+ "# ============================================================\n",
+ "!pip install -q torch torchvision Pillow matplotlib huggingface_hub tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ============================================================\n",
+ "# STEP 1: Download BokehFlow architecture\n",
+ "# ============================================================\n",
+ "from huggingface_hub import hf_hub_download\n",
+ "hf_hub_download(repo_id='asdf98/BokehFlow', filename='bokehflow.py', local_dir='.')\n",
+ "print('✓ BokehFlow downloaded')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ============================================================\n",
+ "# STEP 2: Configuration — CHANGE THESE IF YOU WANT\n",
+ "# ============================================================\n",
+ "CONFIG = {\n",
+ "    # Model\n",
+ "    'variant': 'nano',  # 'nano'=583K params, 'small'=3.1M, 'base'=12M\n",
+ "\n",
+ "    # Training\n",
+ "    'batch_size': 4,  # 4 for T4 16GB, 8 for A100\n",
+ "    'crop_size': 256,  # 256x256 random crops\n",
+ "    'num_epochs': 5,  # 5 epochs for demo, 50+ for full training\n",
+ "    'lr': 3e-4,\n",
+ "    'weight_decay': 0.05,\n",
+ "    'max_grad_norm': 1.0,\n",
+ "\n",
+ "    # Data\n",
+ "    'num_workers': 2,  # 2 for Colab, 4 for Kaggle\n",
+ "    'max_train_samples': 500,  # Limit for quick test. Set None for full dataset.\n",
+ "\n",
+ "    # Target f-stop (train on f/2.0 bokeh)\n",
+ "    'target_fstop': 2.0,\n",
+ "\n",
+ "    # Save\n",
+ "    'save_every': 1,  # Save checkpoint every N epochs\n",
+ "    'output_dir': './checkpoints',\n",
+ "}\n",
+ "\n",
+ "# Auto-detect Kaggle dual GPU\n",
+ "import torch\n",
+ "NUM_GPUS = torch.cuda.device_count()\n",
+ "print(f'GPUs: {NUM_GPUS}, Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"CPU\"}')\n",
+ "if NUM_GPUS > 1:\n",
+ "    print('Kaggle dual-GPU detected! Will use DataParallel.')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ============================================================\n",
+ "# STEP 3: Dataset — Download RealBokeh (raw images, ~19GB)\n",
+ "# On free Colab/Kaggle we download only the files we need\n",
+ "# ============================================================\n",
+ "import os, json\n",
+ "from pathlib import Path\n",
+ "from huggingface_hub import snapshot_download\n",
+ "\n",
+ "# Only download the train split input images + f/2.0 GT + metadata\n",
+ "# This saves bandwidth vs the full 19GB\n",
+ "DATA_DIR = './realbokeh'\n",
+ "\n",
+ "if not os.path.exists(f'{DATA_DIR}/train/in'):\n",
+ "    print('Downloading RealBokeh train split (input + metadata)...')\n",
+ "    print('This downloads ~5GB. On Colab it takes ~3-5 minutes.')\n",
+ "    snapshot_download(\n",
+ "        repo_id='timseizinger/RealBokeh_3MP',\n",
+ "        repo_type='dataset',\n",
+ "        local_dir=DATA_DIR,\n",
+ "        allow_patterns=['train/in/*', 'train/metadata/*', 'train/gt/*/f2.0*',\n",
+ "                        'train/gt/*/*_f2.0*',\n",
+ "                        'validation/in/*', 'validation/metadata/*',\n",
+ "                        'validation/gt/*/*_f2.0*'],\n",
+ "    )\n",
+ "    print('✓ Dataset downloaded')\n",
+ "else:\n",
+ "    print('✓ Dataset already exists')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ============================================================\n",
+ "# STEP 4: PyTorch Dataset class for RealBokeh\n",
+ "# ============================================================\n",
+ "import torch\n",
+ "from torch.utils.data import Dataset, DataLoader\n",
+ "from torchvision import transforms\n",
+ "from PIL import Image\n",
+ "import random\n",
+ "\n",
+ "class RealBokehDataset(Dataset):\n",
+ "    \"\"\"RealBokeh dataset for BokehFlow training.\n",
+ "\n",
+ "    Each sample returns:\n",
+ "        input_img: (3, crop_size, crop_size) sharp f/22 image\n",
+ "        target_img: (3, crop_size, crop_size) bokeh GT at target f-stop\n",
+ "        f_number: scalar f-stop value\n",
+ "        focal_length_mm: scalar focal length\n",
+ "        focus_distance_m: scalar focus distance in meters\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, data_dir, split='train', crop_size=256,\n",
+ "                 target_fstop=2.0, max_samples=None):\n",
+ "        self.data_dir = Path(data_dir) / split\n",
+ "        self.crop_size = crop_size\n",
+ "        self.target_fstop = target_fstop\n",
+ "\n",
+ "        # Load metadata\n",
+ "        self.samples = []\n",
+ "        meta_dir = self.data_dir / 'metadata'\n",
+ "        if not meta_dir.exists():\n",
+ "            raise FileNotFoundError(f'No metadata at {meta_dir}')\n",
+ "\n",
+ "        for meta_file in sorted(meta_dir.glob('*.json')):\n",
+ "            with open(meta_file) as f:\n",
+ "                meta = json.load(f)\n",
+ "\n",
+ "            # Find the GT image at the target f-stop\n",
+ "            gt_path = None\n",
+ "            for img, av in zip(meta['target_images'], meta['target_avs']):\n",
+ "                if abs(av - target_fstop) < 0.01:\n",
+ "                    gt_path = self.data_dir / img\n",
+ "                    break\n",
+ "\n",
+ "            if gt_path is None or not gt_path.exists():\n",
+ "                continue\n",
+ "\n",
+ "            input_path = self.data_dir / meta['source_image']\n",
+ "            if not input_path.exists():\n",
+ "                continue\n",
+ "\n",
+ "            self.samples.append({\n",
+ "                'input': str(input_path),\n",
+ "                'target': str(gt_path),\n",
+ "                'f_number': target_fstop,\n",
+ "                'focal_length_mm': float(meta['focal_length']),\n",
+ "                'focus_distance_m': float(meta['focus_plane_distance']),\n",
+ "            })\n",
+ "\n",
+ "        if max_samples:\n",
+ "            self.samples = self.samples[:max_samples]\n",
+ "\n",
+ "        print(f'{split}: {len(self.samples)} paired samples found')\n",
+ "\n",
+ "        self.to_tensor = transforms.ToTensor()\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self.samples)\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        s = self.samples[idx]\n",
+ "\n",
+ "        # Load images\n",
+ "        inp = Image.open(s['input']).convert('RGB')\n",
+ "        tgt = Image.open(s['target']).convert('RGB')\n",
+ "\n",
+ "        # Random crop (same crop for both)\n",
+ "        w, h = inp.size\n",
+ "        cs = self.crop_size\n",
+ "        if w >= cs and h >= cs:\n",
+ "            x = random.randint(0, w - cs)\n",
+ "            y = random.randint(0, h - cs)\n",
+ "            inp = inp.crop((x, y, x+cs, y+cs))\n",
+ "            tgt = tgt.crop((x, y, x+cs, y+cs))\n",
+ "        else:\n",
+ "            inp = inp.resize((cs, cs), Image.LANCZOS)\n",
+ "            tgt = tgt.resize((cs, cs), Image.LANCZOS)\n",
+ "\n",
+ "        # Random horizontal flip\n",
+ "        if random.random() > 0.5:\n",
+ "            inp = inp.transpose(Image.FLIP_LEFT_RIGHT)\n",
+ "            tgt = tgt.transpose(Image.FLIP_LEFT_RIGHT)\n",
+ "\n",
+ "        inp_t = self.to_tensor(inp)  # [0,1] range\n",
+ "        tgt_t = self.to_tensor(tgt)\n",
+ "\n",
+ "        return {\n",
+ "            'input': inp_t,\n",
+ "            'target': tgt_t,\n",
+ "            'f_number': torch.tensor(s['f_number'], dtype=torch.float32),\n",
+ "            'focal_length_mm': torch.tensor(s['focal_length_mm'], dtype=torch.float32),\n",
+ "            'focus_distance_m': torch.tensor(s['focus_distance_m'], dtype=torch.float32),\n",
+ "        }\n",
+ "\n",
+ "# Create datasets\n",
+ "train_ds = RealBokehDataset(\n",
+ "    DATA_DIR, split='train',\n",
+ "    crop_size=CONFIG['crop_size'],\n",
+ "    target_fstop=CONFIG['target_fstop'],\n",
+ "    max_samples=CONFIG['max_train_samples'],\n",
+ ")\n",
+ "\n",
+ "train_loader = DataLoader(\n",
+ "    train_ds,\n",
+ "    batch_size=CONFIG['batch_size'],\n",
+ "    shuffle=True,\n",
+ "    num_workers=CONFIG['num_workers'],\n",
+ "    pin_memory=True,\n",
+ "    drop_last=True,\n",
+ ")\n",
+ "\n",
+ "print(f'\\n✓ DataLoader ready: {len(train_loader)} batches per epoch')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ============================================================\n",
+ "# STEP 5: Create model\n",
+ "# ============================================================\n",
+ "from bokehflow import BokehFlow, BokehFlowConfig, BokehFlowLoss, model_summary\n",
+ "\n",
+ "config = BokehFlowConfig(variant=CONFIG['variant'])\n",
+ "model = BokehFlow(config)\n",
+ "\n",
+ "# Multi-GPU support for Kaggle\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "if NUM_GPUS > 1:\n",
+ "    model = torch.nn.DataParallel(model)\n",
+ "    print(f'Using DataParallel on {NUM_GPUS} GPUs')\n",
+ "model = model.to(device)\n",
+ "\n",
+ "# Print summary\n",
+ "print(model_summary(config))\n",
+ "print(f'Device: {device}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ============================================================\n",
+ "# STEP 6: Training loop\n",
+ "# ============================================================\n",
+ "from tqdm.auto import tqdm\n",
+ "import time\n",
+ "\n",
+ "optimizer = torch.optim.AdamW(\n",
+ "    model.parameters(),\n",
+ "    lr=CONFIG['lr'],\n",
+ "    weight_decay=CONFIG['weight_decay']\n",
+ ")\n",
+ "\n",
+ "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
+ "    optimizer, T_max=CONFIG['num_epochs'] * len(train_loader)\n",
+ ")\n",
+ "\n",
+ "criterion = BokehFlowLoss(lambda_depth=0.5)\n",
+ "\n",
+ "os.makedirs(CONFIG['output_dir'], exist_ok=True)\n",
+ "\n",
+ "# Training\n",
+ "print(f'\\n{\"=\"*60}')\n",
+ "print(f'Starting training: {CONFIG[\"num_epochs\"]} epochs')\n",
+ "print(f'{\"=\"*60}\\n')\n",
+ "\n",
+ "for epoch in range(CONFIG['num_epochs']):\n",
+ "    model.train()\n",
+ "    epoch_loss = 0.0\n",
+ "    epoch_start = time.time()\n",
+ "\n",
+ "    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{CONFIG[\"num_epochs\"]}')\n",
+ "    for step, batch in enumerate(pbar):\n",
+ "        # Move to device\n",
+ "        inp = batch['input'].to(device)\n",
+ "        tgt = batch['target'].to(device)\n",
+ "        f_num = batch['f_number'].to(device)\n",
+ "        focal = batch['focal_length_mm'].to(device)\n",
+ "        focus = batch['focus_distance_m'].to(device)\n",
+ "\n",
+ "        # Forward\n",
+ "        output = model(inp, f_num, focal, focus)\n",
+ "\n",
+ "        # Loss (wrap a raw tensor output in a dict for the criterion)\n",
+ "        losses = criterion(\n",
+ "            output if isinstance(output, dict) else {'bokeh': output},\n",
+ "            {'bokeh_gt': tgt}\n",
+ "        )\n",
+ "        loss = losses['total']\n",
+ "\n",
+ "        # Backward\n",
+ "        optimizer.zero_grad()\n",
+ "        loss.backward()\n",
+ "        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])\n",
+ "        optimizer.step()\n",
+ "        scheduler.step()\n",
+ "\n",
+ "        epoch_loss += loss.item()\n",
+ "        pbar.set_postfix({\n",
+ "            'loss': f'{loss.item():.4f}',\n",
+ "            'lr': f'{scheduler.get_last_lr()[0]:.2e}',\n",
+ "        })\n",
+ "\n",
+ "    avg_loss = epoch_loss / len(train_loader)\n",
+ "    elapsed = time.time() - epoch_start\n",
+ "    print(f'Epoch {epoch+1}: avg_loss={avg_loss:.4f}, time={elapsed:.0f}s')\n",
+ "\n",
+ "    # Save checkpoint\n",
+ "    if (epoch + 1) % CONFIG['save_every'] == 0:\n",
+ "        ckpt_path = f'{CONFIG[\"output_dir\"]}/bokehflow_{CONFIG[\"variant\"]}_epoch{epoch+1}.pt'\n",
+ "        state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()\n",
+ "        torch.save({\n",
+ "            'epoch': epoch + 1,\n",
+ "            'model_state_dict': state,\n",
+ "            'optimizer_state_dict': optimizer.state_dict(),\n",
+ "            'loss': avg_loss,\n",
+ "            'config': CONFIG,\n",
+ "        }, ckpt_path)\n",
+ "        print(f'  ✓ Saved checkpoint: {ckpt_path}')\n",
+ "\n",
+ "print('\\n✓ Training complete!')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ============================================================\n",
+ "# STEP 7: Quick inference test\n",
+ "# ============================================================\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ "    sample = train_ds[0]\n",
+ "    inp = sample['input'].unsqueeze(0).to(device)\n",
+ "    out = model(\n",
+ "        inp,\n",
+ "        sample['f_number'].unsqueeze(0).to(device),\n",
+ "        sample['focal_length_mm'].unsqueeze(0).to(device),\n",
+ "        sample['focus_distance_m'].unsqueeze(0).to(device),\n",
+ "    )\n",
+ "\n",
+ "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
+ "axes[0].imshow(sample['input'].permute(1,2,0).numpy())\n",
+ "axes[0].set_title('Input (f/22)')\n",
+ "axes[1].imshow(out['bokeh'][0].cpu().permute(1,2,0).clamp(0,1).numpy())\n",
+ "axes[1].set_title('BokehFlow Output')\n",
+ "axes[2].imshow(sample['target'].permute(1,2,0).numpy())\n",
+ "axes[2].set_title('Ground Truth (f/2.0)')\n",
+ "for ax in axes: ax.axis('off')\n",
+ "plt.tight_layout()\n",
+ "plt.savefig('result.png', dpi=100)\n",
+ "plt.show()\n",
+ "print('✓ Inference test complete')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
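
After training, the checkpoints written by STEP 6 can be reloaded outside the notebook. Below is a minimal sketch, assuming the bokehflow.py API the notebook itself uses (BokehFlow, BokehFlowConfig, and a dict output with a 'bokeh' key) and the checkpoint layout saved in STEP 6; the checkpoint path and lens parameters are illustrative, not taken from the commit:

# Minimal sketch: reload a STEP 6 checkpoint for inference.
# Assumes bokehflow.py (downloaded in STEP 1) is importable and that the
# checkpoint follows the dict layout torch.save'd in STEP 6.
import torch
from bokehflow import BokehFlow, BokehFlowConfig

ckpt_path = './checkpoints/bokehflow_nano_epoch5.pt'  # illustrative path
ckpt = torch.load(ckpt_path, map_location='cpu')

# STEP 6 stores the notebook's CONFIG dict alongside the weights.
model = BokehFlow(BokehFlowConfig(variant=ckpt['config']['variant']))
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

# Same conditioning inputs as training: image in [0,1], f-number,
# focal length (mm), focus distance (m).
img = torch.rand(1, 3, 256, 256)  # stand-in for a real photo tensor
with torch.no_grad():
    out = model(img, torch.tensor([2.0]), torch.tensor([50.0]),
                torch.tensor([3.0]))
bokeh = out['bokeh'] if isinstance(out, dict) else out
print(bokeh.shape)  # expect torch.Size([1, 3, 256, 256])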