"""
BokehFlow v3 Training Script

Trains on RealBokeh_3MP dataset (timseizinger/RealBokeh_3MP)
Self-contained — all model code is inline so this works as a standalone
script in HF Jobs or any GPU environment.

Usage:
    # Quick test (200 scenes, 3 epochs)
    VARIANT=small MAX_SCENES=200 EPOCHS=3 BATCH_SIZE=4 python train_v3.py

    # Full training (all 3960 scenes, 10 epochs)
    VARIANT=small EPOCHS=10 BATCH_SIZE=8 python train_v3.py

Environment variables:
    VARIANT: nano/small/base (default: small)
    MAX_SCENES: limit scenes for testing (default: 0 = all)
    EPOCHS: number of epochs (default: 10)
    BATCH_SIZE: batch size (default: 4)
    CROP_SIZE: random crop size (default: 256)
    LR: learning rate (default: 2e-4)
    HUB_MODEL_ID: HF model repo to push to (default: asdf98/BokehFlow)

Requirements:
    pip install torch torchvision Pillow huggingface_hub trackio aiohttp
"""

import os
import sys
import time
import json
import math
import random
import glob
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# ===================================================================
# Model (inline — identical to bokehflow_v3.py)
# ===================================================================

@dataclass
class BokehFlowConfig:
    """Hyper-parameters for BokehFlow.

    ``variant`` selects a preset in ``__post_init__`` which overwrites
    ``embed_dim``, ``depth_blocks`` and ``bokeh_blocks``; an unrecognised
    variant leaves the field values as passed.
    """

    variant: str = "small"
    embed_dim: int = 96
    depth_blocks: int = 6
    bokeh_blocks: int = 6
    fusion_every: int = 2
    stem_channels: int = 48
    patch_stride: int = 4
    max_coc_radius: int = 31
    num_depth_layers: int = 8
    aperture_embed_dim: int = 64
    dropout: float = 0.0
    sensor_width_mm: float = 36.0
    default_focal_mm: float = 50.0
    default_fnumber: float = 2.0
    default_focus_m: float = 2.0
    ffn_expansion: int = 2
    large_kernel: int = 7

    def __post_init__(self):
        if self.variant == "nano":
            self.embed_dim = 48
            self.depth_blocks = 4
            self.bokeh_blocks = 4
        elif self.variant == "small":
            self.embed_dim = 96
            self.depth_blocks = 6
            self.bokeh_blocks = 6
        elif self.variant == "base":
            self.embed_dim = 192
            self.depth_blocks = 8
            self.bokeh_blocks = 8


class GatedConvRecurrence(nn.Module):
    """Residual block: gated large-kernel depthwise convs followed by a 1x1 FFN.

    ``pw`` and the last FFN conv are zero-initialised, so both residual
    branches contribute nothing at initialisation.
    """

    def __init__(self, dim, kernel_size=7, ffn_expansion=2):
        super().__init__()
        k = kernel_size
        p = k // 2
        self.norm1 = nn.GroupNorm(8, dim)
        self.dw1 = nn.Conv2d(dim, dim, k, padding=p, groups=dim, bias=False)
        self.dw2 = nn.Conv2d(dim, dim, k, padding=p, groups=dim, bias=False)
        self.pw = nn.Conv2d(dim, dim, 1, bias=False)
        self.gate_proj = nn.Conv2d(dim, dim, 1, bias=True)
        self.norm2 = nn.GroupNorm(8, dim)
        h = dim * ffn_expansion
        self.ffn = nn.Sequential(
            nn.Conv2d(dim, h, 1, bias=False),
            nn.GELU(),
            nn.Conv2d(h, dim, 1, bias=False),
        )
        nn.init.zeros_(self.pw.weight)
        nn.init.zeros_(self.ffn[-1].weight)

    def forward(self, x):
        h = self.norm1(x)
        spatial = self.dw2(F.silu(self.dw1(h)))
        spatial = self.pw(spatial)
        gate = torch.sigmoid(self.gate_proj(h))
        x = x + spatial * gate
        x = x + self.ffn(self.norm2(x))
        return x


class GatedConvRecurrenceWithACFM(GatedConvRecurrence):
    """``GatedConvRecurrence`` plus a per-channel scale/shift from an aperture embedding.

    The modulation layer starts with zero weight, scale-bias 1 and shift-bias 0,
    so it leaves the features unchanged at initialisation.
    """

    def __init__(self, dim, kernel_size=7, ffn_expansion=2, aperture_embed_dim=64):
        super().__init__(dim, kernel_size, ffn_expansion)
        self.acfm = nn.Linear(aperture_embed_dim, dim * 2)
        nn.init.zeros_(self.acfm.weight)
        self.acfm.bias.data[:dim] = 1.0
        self.acfm.bias.data[dim:] = 0.0

    def forward(self, x, aperture_embed=None):
        x = super().forward(x)
        if aperture_embed is not None:
            B, C, _, _ = x.shape
            ss = self.acfm(aperture_embed)
            scale = ss[:, :C].view(B, C, 1, 1)
            shift = ss[:, C:].view(B, C, 1, 1)
            x = x * scale + shift
        return x


class ConvStem(nn.Module):
    """Two stride-2 convolutions (4x spatial downsampling) into ``embed_dim`` channels."""

    def __init__(self, in_ch=3, stem_ch=48, embed_dim=96):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, stem_ch, 7, stride=2, padding=3, bias=False),
            nn.GroupNorm(8, stem_ch),
            nn.GELU(),
            nn.Conv2d(stem_ch, stem_ch, 3, stride=2, padding=1, groups=stem_ch, bias=False),
            nn.Conv2d(stem_ch, embed_dim, 1, bias=False),
            nn.GroupNorm(8, embed_dim),
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)


class ApertureEncoder(nn.Module):
    """MLP embedding of (f-number, focal length, focus distance).

    The three inputs are min-max normalised with the ``p_min``/``p_max``
    buffers and clamped to [0, 1] before the MLP.
    """

    def __init__(self, embed_dim=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(3, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU(),
        )
        self.register_buffer('p_min', torch.tensor([1., 10., 0.1]))
        self.register_buffer('p_max', torch.tensor([22., 200., 100.]))

    def forward(self, f_number, focal_mm, focus_m):
        p = torch.stack([f_number, focal_mm, focus_m], -1)
        normed = (p - self.p_min) / (self.p_max - self.p_min + 1e-6)
        return self.mlp(normed.clamp(0, 1))


class CrossFusion(nn.Module):
    """Gated bidirectional exchange between the depth and bokeh feature streams.

    Both projections are zero-initialised, so fusion is an identity at start.
    """

    def __init__(self, d):
        super().__init__()
        self.gate_d = nn.Sequential(nn.Conv2d(d, d, 1), nn.Sigmoid())
        self.gate_b = nn.Sequential(nn.Conv2d(d, d, 1), nn.Sigmoid())
        self.proj_d = nn.Conv2d(d, d, 1, bias=False)
        self.proj_b = nn.Conv2d(d, d, 1, bias=False)
        nn.init.zeros_(self.proj_d.weight)
        nn.init.zeros_(self.proj_b.weight)

    def forward(self, d_feat, b_feat):
        new_d = d_feat + self.gate_d(b_feat) * self.proj_d(b_feat)
        new_b = b_feat + self.gate_b(d_feat) * self.proj_b(d_feat)
        return new_d, new_b


class DepthHead(nn.Module):
    """Upsamples features 4x to a single positive channel, clamped to at most 100."""

    def __init__(self, dim=96):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim // 2, 3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(dim // 2, dim // 4, 3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(dim // 4, 1, 3, padding=1),
            nn.Softplus(),
        )

    def forward(self, x):
        return self.net(x).clamp(max=100.0)


class BokehHead(nn.Module):
    """Upsamples features 4x to a 3-channel map (added to the input image by ``BokehFlow``)."""

    def __init__(self, dim=96):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(dim, dim // 2, 3, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(dim // 2, 3, 3, padding=1),
        )

    def forward(self, x):
        return self.net(x)


class PGCoC(nn.Module):
    """Circle-of-confusion renderer using a stack of fixed Gaussian blurs.

    A per-pixel CoC radius is computed from depth and lens parameters, then
    used to interpolate between ``n_levels`` pre-blurred copies of the image.
    A small learned conv refines the result.
    """

    def __init__(self, sensor_width=36.0, max_radius=31, n_levels=5):
        super().__init__()
        self.sensor_width = sensor_width
        self.max_radius = max_radius
        self.n_levels = n_levels
        # Kept as non-trainable parameters (not buffers) so state_dict keys
        # stay compatible with existing checkpoints.
        self.kernels = nn.ParameterList()
        for i in range(n_levels):
            sigma = (i + 1) * max_radius / n_levels / 3.0
            ks = int(sigma * 6) | 1  # force odd size
            ks = max(ks, 3)
            ks = min(ks, 31)
            k1d = torch.exp(
                -torch.arange(-(ks // 2), ks // 2 + 1).float() ** 2 / (2 * sigma ** 2 + 1e-6)
            )
            k1d = k1d / k1d.sum()
            k2d = k1d.unsqueeze(1) @ k1d.unsqueeze(0)
            self.kernels.append(
                nn.Parameter(k2d.unsqueeze(0).unsqueeze(0), requires_grad=False)
            )
        self.refine = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(16, 3, 3, padding=1),
        )

    def _blur_at_level(self, image, kernel):
        """Depthwise-convolve ``image`` with one 2-D kernel using reflect padding."""
        C = image.shape[1]
        k = kernel.expand(C, -1, -1, -1)
        p = kernel.shape[-1] // 2
        return F.conv2d(F.pad(image, [p] * 4, mode='reflect'), k, groups=C)

    def forward(self, image, depth, f_number, focal_mm, focus_m):
        """Return ``(rendered, coc_px)`` for the given image, depth and lens settings."""
        W = image.shape[-1]
        f = focal_mm.view(-1, 1, 1, 1)
        N = f_number.view(-1, 1, 1, 1)
        # focus_m and depth are multiplied by 1000 so they share units with f.
        S1 = (focus_m.view(-1, 1, 1, 1) * 1000).clamp(min=51)
        D = (depth * 1000).clamp(min=100)
        coc = (f ** 2 / (N * (S1 - f).clamp(min=1))) * (D - S1).abs() / D
        coc_px = (coc * W / self.sensor_width / 2).clamp(0, self.max_radius)
        coc_norm = coc_px / self.max_radius

        blurred_levels = [self._blur_at_level(image, kernel) for kernel in self.kernels]

        level_float = coc_norm * (self.n_levels - 1)
        level_low = level_float.long().clamp(0, self.n_levels - 2)
        level_frac = (level_float - level_low.float()).clamp(0, 1)

        rendered = image.clone()
        for lv in range(self.n_levels - 1):
            mask = (level_low == lv).float()
            if mask.sum() > 0:
                interp = (
                    blurred_levels[lv] * (1 - level_frac)
                    + blurred_levels[lv + 1] * level_frac
                )
                rendered = rendered * (1 - mask) + interp * mask

        mask_top = (level_low >= self.n_levels - 2).float() * (level_frac > 0.99).float()
        rendered = rendered * (1 - mask_top) + blurred_levels[-1] * mask_top
        rendered = rendered + self.refine(rendered) * 0.1
        return rendered, coc_px


class BokehFlow(nn.Module):
    """Two-stream network that blends a CoC-rendered bokeh with a learned residual.

    ``forward`` returns a dict with keys ``'bokeh'``, ``'depth'`` and ``'coc_map'``.
    """

    def __init__(self, config=None):
        super().__init__()
        if config is None:
            config = BokehFlowConfig()
        self.config = config
        c = config
        self.stem = ConvStem(3, c.stem_channels, c.embed_dim)
        self.aperture_enc = ApertureEncoder(c.aperture_embed_dim)
        self.depth_blocks = nn.ModuleList([
            GatedConvRecurrence(c.embed_dim, c.large_kernel, c.ffn_expansion)
            for _ in range(c.depth_blocks)
        ])
        self.bokeh_blocks = nn.ModuleList([
            GatedConvRecurrenceWithACFM(
                c.embed_dim, c.large_kernel, c.ffn_expansion, c.aperture_embed_dim
            )
            for _ in range(c.bokeh_blocks)
        ])
        n_fusions = max(c.depth_blocks, c.bokeh_blocks) // c.fusion_every
        self.fusions = nn.ModuleList([CrossFusion(c.embed_dim) for _ in range(n_fusions)])
        self.depth_head = DepthHead(c.embed_dim)
        self.bokeh_head = BokehHead(c.embed_dim)
        self.pgcoc = PGCoC(c.sensor_width_mm, c.max_coc_radius)
        self.blend_w = nn.Parameter(torch.tensor(0.5))

    def forward(self, image, f_number=None, focal_length_mm=None,
                focus_distance_m=None, **kw):
        """Render bokeh for ``image``; missing lens parameters fall back to config defaults."""
        B = image.shape[0]
        dev = image.device
        c = self.config
        if f_number is None:
            f_number = torch.full((B,), c.default_fnumber, device=dev)
        if focal_length_mm is None:
            focal_length_mm = torch.full((B,), c.default_focal_mm, device=dev)
        if focus_distance_m is None:
            focus_distance_m = torch.full((B,), c.default_focus_m, device=dev)

        ae = self.aperture_enc(f_number, focal_length_mm, focus_distance_m)
        feat = self.stem(image)
        d_feat = feat
        b_feat = feat
        fi = 0
        n_blk = max(c.depth_blocks, c.bokeh_blocks)
        for i in range(n_blk):
            if i < c.depth_blocks:
                d_feat = self.depth_blocks[i](d_feat)
            if i < c.bokeh_blocks:
                b_feat = self.bokeh_blocks[i](b_feat, ae)
            if (i + 1) % c.fusion_every == 0 and fi < len(self.fusions):
                d_feat, b_feat = self.fusions[fi](d_feat, b_feat)
                fi += 1

        depth = self.depth_head(d_feat)
        if depth.shape[2:] != image.shape[2:]:
            depth = F.interpolate(depth, image.shape[2:], mode='bilinear', align_corners=False)

        physics_bokeh, coc_map = self.pgcoc(
            image, depth, f_number, focal_length_mm, focus_distance_m
        )

        learned_bokeh = self.bokeh_head(b_feat)
        if learned_bokeh.shape[2:] != image.shape[2:]:
            learned_bokeh = F.interpolate(
                learned_bokeh, image.shape[2:], mode='bilinear', align_corners=False
            )

        w = torch.sigmoid(self.blend_w)
        bokeh = (w * physics_bokeh + (1 - w) * (image + learned_bokeh)).clamp(0, 1)
        return {'bokeh': bokeh, 'depth': depth, 'coc_map': coc_map}


class BokehFlowLoss(nn.Module):
    """L1 + (1 - SSIM) loss between ``pred['bokeh']`` and ``targets['bokeh_gt']``.

    Returns a dict with ``'total'`` (differentiable) and detached ``'l1'`` /
    ``'ssim'`` components for logging.
    """

    def forward(self, pred, targets):
        # Compute in float32: the SSIM variance terms are differences of
        # near-equal quantities and lose precision if the prediction arrives
        # in half precision from an autocast region.
        bp = pred['bokeh'].float()
        bg = targets['bokeh_gt'].float()

        l1 = F.l1_loss(bp, bg)

        C1, C2 = 0.01 ** 2, 0.03 ** 2
        mu_p = F.avg_pool2d(bp, 11, 1, 5)
        mu_g = F.avg_pool2d(bg, 11, 1, 5)
        mu_pp = mu_p * mu_p
        mu_gg = mu_g * mu_g
        mu_pg = mu_p * mu_g
        sig_pp = F.avg_pool2d(bp * bp, 11, 1, 5) - mu_pp
        sig_gg = F.avg_pool2d(bg * bg, 11, 1, 5) - mu_gg
        sig_pg = F.avg_pool2d(bp * bg, 11, 1, 5) - mu_pg
        ssim_map = ((2 * mu_pg + C1) * (2 * sig_pg + C2)) / (
            (mu_pp + mu_gg + C1) * (sig_pp + sig_gg + C2)
        )
        ssim_loss = 1.0 - ssim_map.mean()
        return {'total': l1 + ssim_loss, 'l1': l1.detach(), 'ssim': ssim_loss.detach()}


# ===================================================================
# Dataset
# ===================================================================

class RealBokehDataset(Dataset):
    """Loads from local disk after snapshot_download.

    Pairs each ``<split>/in/<sid>_f22.JPG`` input with
    ``<split>/gt/<sid>/<sid>_f<target_fstop>.JPG``; scenes without a ground
    truth file are skipped. Per-scene JSON metadata is loaded when present.
    """

    def __init__(self, root, crop_size=256, split='train', target_fstop='2.0'):
        self.crop = crop_size
        # The f-number fed to the model must match the ground-truth aperture.
        self.f_number = float(target_fstop)
        self.pairs = []
        in_dir = Path(root) / split / 'in'
        gt_dir = Path(root) / split / 'gt'
        meta_dir = Path(root) / split / 'metadata'
        for in_path in sorted(in_dir.glob('*_f22.JPG')):
            sid = in_path.stem.split('_')[0]
            gt_path = gt_dir / sid / f'{sid}_f{target_fstop}.JPG'
            meta_path = meta_dir / f'{sid}.json'
            if not gt_path.exists():
                continue
            meta = {}
            if meta_path.exists():
                with open(meta_path) as f:
                    meta = json.load(f)
            self.pairs.append((str(in_path), str(gt_path), meta))
        print(f"RealBokehDataset: {len(self.pairs)} pairs found (target f/{target_fstop})")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        # Imported lazily so the module can be imported without Pillow/torchvision.
        from PIL import Image
        import torchvision.transforms.functional as TF

        in_path, gt_path, meta = self.pairs[idx]
        inp = Image.open(in_path).convert('RGB')
        gt = Image.open(gt_path).convert('RGB')

        # Resize to manageable size first, then crop
        short = min(inp.size)
        if short > 512:
            scale = 512.0 / short
            new_w = int(inp.size[0] * scale)
            new_h = int(inp.size[1] * scale)
            inp = inp.resize((new_w, new_h), Image.LANCZOS)
            gt = gt.resize((new_w, new_h), Image.LANCZOS)

        inp = TF.to_tensor(inp)
        gt = TF.to_tensor(gt)

        # Random crop (same window for input and target); images smaller than
        # the crop are stretched to the crop size instead.
        _, h, w = inp.shape
        cs = self.crop
        if h >= cs and w >= cs:
            i = random.randint(0, h - cs)
            j = random.randint(0, w - cs)
            inp = inp[:, i:i + cs, j:j + cs]
            gt = gt[:, i:i + cs, j:j + cs]
        else:
            inp = F.interpolate(
                inp.unsqueeze(0), (cs, cs), mode='bilinear', align_corners=False
            ).squeeze(0)
            gt = F.interpolate(
                gt.unsqueeze(0), (cs, cs), mode='bilinear', align_corners=False
            ).squeeze(0)

        # Random horizontal flip
        if random.random() > 0.5:
            inp = inp.flip(-1)
            gt = gt.flip(-1)

        focal = meta.get('focal_length', 50.0)
        focus = meta.get('focus_plane_distance', 2.0)

        return {
            'image': inp,
            'bokeh_gt': gt,
            'f_number': torch.tensor(self.f_number, dtype=torch.float32),
            'focal_length_mm': torch.tensor(float(focal), dtype=torch.float32),
            'focus_distance_m': torch.tensor(float(focus), dtype=torch.float32),
        }


# ===================================================================
# Data download
# ===================================================================

def download_realbokeh(max_scenes=None):
    """Download RealBokeh_3MP using snapshot_download with exact patterns.

    Returns the local data directory. If ``train/in/1_f22.JPG`` already
    exists the download is skipped entirely.

    NOTE(review): the cache check does not compare the cached scene count
    with ``max_scenes``, so a previous smaller download is reused as-is.
    """
    from huggingface_hub import snapshot_download

    data_dir = '/tmp/realbokeh'
    check_file = Path(data_dir) / 'train' / 'in' / '1_f22.JPG'
    if check_file.exists():
        n = len(list(Path(data_dir).glob('train/in/*_f22.JPG')))
        print(f"Data already cached: {n} scenes")
        return data_dir

    print("Fetching metadata to build download list...")
    snapshot_download(
        'timseizinger/RealBokeh_3MP',
        repo_type='dataset',
        local_dir=data_dir,
        allow_patterns=['train/metadata/*.json'],
    )

    meta_dir = Path(data_dir) / 'train' / 'metadata'
    scene_ids = sorted([p.stem for p in meta_dir.glob('*.json')], key=lambda x: int(x))
    if max_scenes:
        scene_ids = scene_ids[:max_scenes]
    print(f"Found {len(scene_ids)} scenes. Downloading input + f/2.0 GT images...")

    patterns = []
    for sid in scene_ids:
        patterns.append(f'train/in/{sid}_f22.JPG')
        patterns.append(f'train/gt/{sid}/{sid}_f2.0.JPG')

    t0 = time.time()
    snapshot_download(
        'timseizinger/RealBokeh_3MP',
        repo_type='dataset',
        local_dir=data_dir,
        allow_patterns=patterns,
    )
    elapsed = time.time() - t0
    n_found = len(list(Path(data_dir).glob('train/in/*_f22.JPG')))
    print(f"Downloaded {n_found} scenes in {elapsed:.0f}s")
    return data_dir


# ===================================================================
# Training loop
# ===================================================================

def train():
    """Train BokehFlow from environment-variable settings and upload checkpoints.

    Raises:
        ValueError: if ``EPOCHS`` is less than 1.
        RuntimeError: if the dataset yields no full batch.
    """
    import trackio

    VARIANT = os.environ.get('VARIANT', 'small')
    MAX_SCENES = int(os.environ.get('MAX_SCENES', '0')) or None
    EPOCHS = int(os.environ.get('EPOCHS', '10'))
    BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
    CROP_SIZE = int(os.environ.get('CROP_SIZE', '256'))
    LR = float(os.environ.get('LR', '2e-4'))
    HUB_MODEL_ID = os.environ.get('HUB_MODEL_ID', 'asdf98/BokehFlow')

    # Fail fast: with zero epochs the final checkpoint has no loss to record.
    if EPOCHS < 1:
        raise ValueError(f"EPOCHS must be >= 1, got {EPOCHS}")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")
    if device == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    trackio.init(project="bokehflow", name=f"v3-{VARIANT}-e{EPOCHS}-lr{LR}")

    data_dir = download_realbokeh(max_scenes=MAX_SCENES)
    ds = RealBokehDataset(data_dir, crop_size=CROP_SIZE)
    dl = DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=(device == 'cuda'),
        drop_last=True,
        persistent_workers=True,
    )
    print(f"Batches per epoch: {len(dl)}")
    # drop_last=True yields zero batches when there are fewer samples than
    # BATCH_SIZE; the per-epoch average would then divide by zero.
    if len(dl) == 0:
        raise RuntimeError(
            f"No training batches: {len(ds)} samples with BATCH_SIZE={BATCH_SIZE}"
        )

    config = BokehFlowConfig(variant=VARIANT)
    model = BokehFlow(config).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model: BokehFlow-{VARIANT}, {n_params:,} params")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    total_steps = EPOCHS * len(dl)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, total_steps, eta_min=LR / 20
    )
    loss_fn = BokehFlowLoss()
    scaler = torch.amp.GradScaler('cuda', enabled=(device == 'cuda'))

    global_step = 0
    best_loss = float('inf')

    for epoch in range(EPOCHS):
        model.train()
        epoch_loss = 0.0
        t_epoch = time.time()

        for batch_idx, batch in enumerate(dl):
            t_step = time.time()
            image = batch['image'].to(device)
            bokeh_gt = batch['bokeh_gt'].to(device)
            f_number = batch['f_number'].to(device)
            focal_mm = batch['focal_length_mm'].to(device)
            focus_m = batch['focus_distance_m'].to(device)

            optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast('cuda', enabled=(device == 'cuda')):
                out = model(image, f_number, focal_mm, focus_m)
                losses = loss_fn(out, {'bokeh_gt': bokeh_gt})
                loss = losses['total']

            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            loss_val = loss.item()  # single host sync per step
            epoch_loss += loss_val
            global_step += 1
            step_time = time.time() - t_step

            if global_step % 10 == 0 or batch_idx == 0:
                lr_now = scheduler.get_last_lr()[0]
                l1_val = losses['l1'].item()
                ssim_val = losses['ssim'].item()
                print(f"Ep {epoch+1}/{EPOCHS} [{batch_idx+1}/{len(dl)}] "
                      f"loss={loss_val:.4f} l1={l1_val:.4f} "
                      f"ssim={ssim_val:.4f} lr={lr_now:.2e} "
                      f"step={step_time*1000:.0f}ms")
                trackio.log({
                    "loss": loss_val,
                    "l1": l1_val,
                    "ssim_loss": ssim_val,
                    "lr": lr_now,
                    "step_ms": step_time * 1000,
                    "epoch": epoch + 1,
                })

            if device == 'cuda' and global_step == 1:
                vram = torch.cuda.max_memory_allocated() / 1e9
                print(f"Peak VRAM after 1st step: {vram:.2f} GB")
                trackio.log({"peak_vram_gb": vram})

        epoch_time = time.time() - t_epoch
        avg_loss = epoch_loss / len(dl)
        print(f"Epoch {epoch+1}/{EPOCHS} done in {epoch_time:.0f}s, avg_loss={avg_loss:.4f}")
        trackio.log({"epoch_avg_loss": avg_loss, "epoch_time_s": epoch_time})

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'config': config.__dict__,
                'epoch': epoch + 1,
                'loss': avg_loss,
            }, '/tmp/bokehflow_best.pt')
            print(f"  Saved best model (loss={avg_loss:.4f})")

    # Push to hub
    print("\nPushing model to Hub...")
    from huggingface_hub import HfApi
    api = HfApi()

    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config.__dict__,
        'epoch': EPOCHS,
        'loss': avg_loss,
    }, '/tmp/bokehflow_final.pt')

    for fname in ['bokehflow_best.pt', 'bokehflow_final.pt']:
        fpath = f'/tmp/{fname}'
        if os.path.exists(fpath):
            api.upload_file(
                path_or_fileobj=fpath,
                path_in_repo=f'checkpoints/{fname}',
                repo_id=HUB_MODEL_ID,
            )
            print(f"  Uploaded {fname}")

    print(f"\nTraining complete! Model: https://huggingface.co/{HUB_MODEL_ID}")


if __name__ == '__main__':
    train()