| """ |
| BokehFlow v3 Training Script |
| Trains on RealBokeh_3MP dataset (timseizinger/RealBokeh_3MP) |
| |
| Self-contained — all model code is inline so this works as a standalone |
| script in HF Jobs or any GPU environment. |
| |
| Usage: |
| # Quick test (200 scenes, 3 epochs) |
| VARIANT=small MAX_SCENES=200 EPOCHS=3 BATCH_SIZE=4 python train_v3.py |
| |
| # Full training (all 3960 scenes, 10 epochs) |
| VARIANT=small EPOCHS=10 BATCH_SIZE=8 python train_v3.py |
| |
| Environment variables: |
| VARIANT: nano/small/base (default: small) |
| MAX_SCENES: limit scenes for testing (default: 0 = all) |
| EPOCHS: number of epochs (default: 10) |
| BATCH_SIZE: batch size (default: 4) |
| CROP_SIZE: random crop size (default: 256) |
| LR: learning rate (default: 2e-4) |
| HUB_MODEL_ID: HF model repo to push to (default: asdf98/BokehFlow) |
| |
| Requirements: |
| pip install torch torchvision Pillow huggingface_hub trackio aiohttp |
| """ |
|
|
import os, time, json, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from dataclasses import dataclass
|
|
|
|
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
|
|
@dataclass
class BokehFlowConfig:
    variant: str = "small"
    embed_dim: int = 96
    depth_blocks: int = 6
    bokeh_blocks: int = 6
    fusion_every: int = 2
    stem_channels: int = 48
    patch_stride: int = 4
    max_coc_radius: int = 31
    num_depth_layers: int = 8
    aperture_embed_dim: int = 64
    dropout: float = 0.0
    sensor_width_mm: float = 36.0
    default_focal_mm: float = 50.0
    default_fnumber: float = 2.0
    default_focus_m: float = 2.0
    ffn_expansion: int = 2
    large_kernel: int = 7

    def __post_init__(self):
        if self.variant == "nano":
            self.embed_dim = 48
            self.depth_blocks = 4
            self.bokeh_blocks = 4
        elif self.variant == "small":
            self.embed_dim = 96
            self.depth_blocks = 6
            self.bokeh_blocks = 6
        elif self.variant == "base":
            self.embed_dim = 192
            self.depth_blocks = 8
            self.bokeh_blocks = 8
|
|
|
class GatedConvRecurrence(nn.Module):
    """Large-kernel depthwise conv block with a sigmoid spatial gate and a pointwise FFN."""

    def __init__(self, dim, kernel_size=7, ffn_expansion=2):
        super().__init__()
        k = kernel_size; p = k // 2
        self.norm1 = nn.GroupNorm(8, dim)
        self.dw1 = nn.Conv2d(dim, dim, k, padding=p, groups=dim, bias=False)
        self.dw2 = nn.Conv2d(dim, dim, k, padding=p, groups=dim, bias=False)
        self.pw = nn.Conv2d(dim, dim, 1, bias=False)
        self.gate_proj = nn.Conv2d(dim, dim, 1, bias=True)
        self.norm2 = nn.GroupNorm(8, dim)
        h = dim * ffn_expansion
        self.ffn = nn.Sequential(nn.Conv2d(dim, h, 1, bias=False), nn.GELU(), nn.Conv2d(h, dim, 1, bias=False))
        # Zero-init the output projections so each block starts as an identity residual.
        nn.init.zeros_(self.pw.weight)
        nn.init.zeros_(self.ffn[-1].weight)

    def forward(self, x):
        h = self.norm1(x)
        spatial = self.dw2(F.silu(self.dw1(h)))
        spatial = self.pw(spatial)
        gate = torch.sigmoid(self.gate_proj(h))
        x = x + spatial * gate
        x = x + self.ffn(self.norm2(x))
        return x
|
|
|
|
class GatedConvRecurrenceWithACFM(GatedConvRecurrence):
    """Aperture-conditioned feature modulation: per-channel scale/shift predicted
    from the aperture embedding, applied after the base block."""

    def __init__(self, dim, kernel_size=7, ffn_expansion=2, aperture_embed_dim=64):
        super().__init__(dim, kernel_size, ffn_expansion)
        self.acfm = nn.Linear(aperture_embed_dim, dim * 2)
        # Zero weights plus a (scale=1, shift=0) bias make the modulation an identity at init.
        nn.init.zeros_(self.acfm.weight)
        self.acfm.bias.data[:dim] = 1.0
        self.acfm.bias.data[dim:] = 0.0

    def forward(self, x, aperture_embed=None):
        x = super().forward(x)
        if aperture_embed is not None:
            B, C, H, W = x.shape
            ss = self.acfm(aperture_embed)
            scale = ss[:, :C].view(B, C, 1, 1)
            shift = ss[:, C:].view(B, C, 1, 1)
            x = x * scale + shift
        return x
|
|
|
|
class ConvStem(nn.Module):
    """4x downsampling stem: two stride-2 convs take 3 -> stem_ch -> embed_dim."""

    def __init__(self, in_ch=3, stem_ch=48, embed_dim=96):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, stem_ch, 7, stride=2, padding=3, bias=False),
            nn.GroupNorm(8, stem_ch), nn.GELU(),
            nn.Conv2d(stem_ch, stem_ch, 3, stride=2, padding=1, groups=stem_ch, bias=False),
            nn.Conv2d(stem_ch, embed_dim, 1, bias=False),
            nn.GroupNorm(8, embed_dim), nn.GELU())

    def forward(self, x):
        return self.net(x)
|
|
|
|
class ApertureEncoder(nn.Module):
    """Embeds (f-number, focal length, focus distance), min-max normalized to [0, 1]."""

    def __init__(self, embed_dim=64):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(3, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim), nn.GELU())
        self.register_buffer('p_min', torch.tensor([1., 10., 0.1]))     # f/1, 10 mm, 0.1 m
        self.register_buffer('p_max', torch.tensor([22., 200., 100.]))  # f/22, 200 mm, 100 m

    def forward(self, f_number, focal_mm, focus_m):
        p = torch.stack([f_number, focal_mm, focus_m], -1)
        return self.mlp(((p - self.p_min) / (self.p_max - self.p_min + 1e-6)).clamp(0, 1))
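# Worked example of the normalization above, using this script's defaults
# (f/2.0, 50 mm focal length, 2.0 m focus distance):
#   f-number: (2.0 - 1)   / (22 - 1)     ~ 0.048
#   focal:    (50 - 10)   / (200 - 10)   ~ 0.211
#   focus:    (2.0 - 0.1) / (100 - 0.1)  ~ 0.019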
|
|
|
|
class CrossFusion(nn.Module):
    """Bidirectional gated exchange between the depth and bokeh streams.
    Zero-init projections make the fusion a no-op at the start of training."""

    def __init__(self, d):
        super().__init__()
        self.gate_d = nn.Sequential(nn.Conv2d(d, d, 1), nn.Sigmoid())
        self.gate_b = nn.Sequential(nn.Conv2d(d, d, 1), nn.Sigmoid())
        self.proj_d = nn.Conv2d(d, d, 1, bias=False)
        self.proj_b = nn.Conv2d(d, d, 1, bias=False)
        nn.init.zeros_(self.proj_d.weight); nn.init.zeros_(self.proj_b.weight)

    def forward(self, d_feat, b_feat):
        return (d_feat + self.gate_d(b_feat) * self.proj_d(b_feat),
                b_feat + self.gate_b(d_feat) * self.proj_b(d_feat))
|
|
|
|
class DepthHead(nn.Module):
    """Upsamples 4x back toward input resolution and predicts depth in meters, capped at 100."""

    def __init__(self, dim=96):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim // 2, 3, padding=1), nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(dim // 2, dim // 4, 3, padding=1), nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(dim // 4, 1, 3, padding=1), nn.Softplus())

    def forward(self, x):
        return self.net(x).clamp(max=100.0)
|
|
|
|
class BokehHead(nn.Module):
    """Upsamples 4x and predicts an RGB residual that is later added to the input image."""

    def __init__(self, dim=96):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1), nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(dim, dim // 2, 3, padding=1), nn.GELU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(dim // 2, 3, 3, padding=1))

    def forward(self, x):
        return self.net(x)
|
|
|
|
class PGCoC(nn.Module):
    """Physics-guided circle-of-confusion renderer: a thin-lens CoC map selects
    per-pixel interpolation between fixed Gaussian blur levels, followed by a
    small learned refinement."""

    def __init__(self, sensor_width=36.0, max_radius=31, n_levels=5):
        super().__init__()
        self.sensor_width = sensor_width
        self.max_radius = max_radius
        self.n_levels = n_levels
        self.kernels = nn.ParameterList()
        for i in range(n_levels):
            sigma = (i + 1) * max_radius / n_levels / 3.0
            ks = int(sigma * 6) | 1; ks = max(ks, 3); ks = min(ks, 31)  # odd kernel size in [3, 31]
            k1d = torch.exp(-torch.arange(-(ks // 2), ks // 2 + 1).float()**2 / (2 * sigma**2 + 1e-6))
            k1d = k1d / k1d.sum()
            k2d = k1d.unsqueeze(1) @ k1d.unsqueeze(0)  # separable 1D -> 2D Gaussian
            self.kernels.append(nn.Parameter(k2d.unsqueeze(0).unsqueeze(0), requires_grad=False))
        self.refine = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.GELU(), nn.Conv2d(16, 3, 3, padding=1))

    def _blur_at_level(self, image, kernel):
        B, C, H, W = image.shape
        k = kernel.expand(C, -1, -1, -1)
        p = kernel.shape[-1] // 2
        return F.conv2d(F.pad(image, [p] * 4, mode='reflect'), k, groups=C)

    def forward(self, image, depth, f_number, focal_mm, focus_m):
        B, C, H, W = image.shape
        f = focal_mm.view(-1, 1, 1, 1); N = f_number.view(-1, 1, 1, 1)
        S1 = (focus_m.view(-1, 1, 1, 1) * 1000).clamp(min=51)  # focus distance in mm
        D = (depth * 1000).clamp(min=100)                      # scene depth in mm
        # Thin-lens CoC diameter: c = f^2 / (N * (S1 - f)) * |D - S1| / D
        coc = (f**2 / (N * (S1 - f).clamp(min=1))) * (D - S1).abs() / D
        coc_px = (coc * W / self.sensor_width / 2).clamp(0, self.max_radius)  # mm -> pixel radius
        coc_norm = coc_px / self.max_radius
        blurred_levels = [self._blur_at_level(image, kernel) for kernel in self.kernels]
        # Per-pixel linear interpolation between adjacent blur levels.
        level_float = coc_norm * (self.n_levels - 1)
        level_low = level_float.long().clamp(0, self.n_levels - 2)
        level_frac = (level_float - level_low.float()).clamp(0, 1)
        rendered = image.clone()
        for lv in range(self.n_levels - 1):
            mask = (level_low == lv).float()
            if mask.sum() > 0:
                interp = blurred_levels[lv] * (1 - level_frac) + blurred_levels[lv + 1] * level_frac
                rendered = rendered * (1 - mask) + interp * mask
        mask_top = (level_low >= self.n_levels - 2).float() * (level_frac > 0.99).float()
        rendered = rendered * (1 - mask_top) + blurred_levels[-1] * mask_top
        rendered = rendered + self.refine(rendered) * 0.1
        return rendered, coc_px
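# Worked example of the thin-lens CoC computed above: f = 50 mm, N = 2 (f/2),
# S1 = 2 m = 2000 mm, D = 4 m = 4000 mm gives
#   coc    = 50^2 / (2 * (2000 - 50)) * |4000 - 2000| / 4000 ~ 0.32 mm
#   coc_px = 0.32 * 256 / 36 / 2 ~ 1.1 px   (256 px crop, 36 mm sensor width)
# i.e. a background point 2 m behind the focus plane falls between blur
# levels 0 and 1 of the 5-level pyramid.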
|
|
|
|
class BokehFlow(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        if config is None: config = BokehFlowConfig()
        self.config = config; c = config
        self.stem = ConvStem(3, c.stem_channels, c.embed_dim)
        self.aperture_enc = ApertureEncoder(c.aperture_embed_dim)
        self.depth_blocks = nn.ModuleList([
            GatedConvRecurrence(c.embed_dim, c.large_kernel, c.ffn_expansion)
            for _ in range(c.depth_blocks)])
        self.bokeh_blocks = nn.ModuleList([
            GatedConvRecurrenceWithACFM(c.embed_dim, c.large_kernel, c.ffn_expansion, c.aperture_embed_dim)
            for _ in range(c.bokeh_blocks)])
        n_fusions = max(c.depth_blocks, c.bokeh_blocks) // c.fusion_every
        self.fusions = nn.ModuleList([CrossFusion(c.embed_dim) for _ in range(n_fusions)])
        self.depth_head = DepthHead(c.embed_dim)
        self.bokeh_head = BokehHead(c.embed_dim)
        self.pgcoc = PGCoC(c.sensor_width_mm, c.max_coc_radius)
        self.blend_w = nn.Parameter(torch.tensor(0.5))  # learned physics-vs-learned blend weight

    def forward(self, image, f_number=None, focal_length_mm=None, focus_distance_m=None, **kw):
        B = image.shape[0]; dev = image.device; c = self.config
        if f_number is None: f_number = torch.full((B,), c.default_fnumber, device=dev)
        if focal_length_mm is None: focal_length_mm = torch.full((B,), c.default_focal_mm, device=dev)
        if focus_distance_m is None: focus_distance_m = torch.full((B,), c.default_focus_m, device=dev)
        ae = self.aperture_enc(f_number, focal_length_mm, focus_distance_m)
        feat = self.stem(image)
        d_feat = feat; b_feat = feat; fi = 0
        n_blk = max(c.depth_blocks, c.bokeh_blocks)
        for i in range(n_blk):
            if i < c.depth_blocks: d_feat = self.depth_blocks[i](d_feat)
            if i < c.bokeh_blocks: b_feat = self.bokeh_blocks[i](b_feat, ae)
            if (i + 1) % c.fusion_every == 0 and fi < len(self.fusions):
                d_feat, b_feat = self.fusions[fi](d_feat, b_feat); fi += 1
        depth = self.depth_head(d_feat)
        if depth.shape[2:] != image.shape[2:]:
            depth = F.interpolate(depth, image.shape[2:], mode='bilinear', align_corners=False)
        physics_bokeh, coc_map = self.pgcoc(image, depth, f_number, focal_length_mm, focus_distance_m)
        learned_bokeh = self.bokeh_head(b_feat)
        if learned_bokeh.shape[2:] != image.shape[2:]:
            learned_bokeh = F.interpolate(learned_bokeh, image.shape[2:], mode='bilinear', align_corners=False)
        w = torch.sigmoid(self.blend_w)
        bokeh = (w * physics_bokeh + (1 - w) * (image + learned_bokeh)).clamp(0, 1)
        return {'bokeh': bokeh, 'depth': depth, 'coc_map': coc_map}
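# Quick shape sanity check for BokehFlow (a convenience helper; nothing in
# train() calls it -- run it by hand to verify the CPU forward pass).
def _smoke_test_model():
    model = BokehFlow(BokehFlowConfig(variant="nano")).eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 64, 64))  # camera params fall back to config defaults
    assert out['bokeh'].shape == (1, 3, 64, 64)
    assert out['depth'].shape == (1, 1, 64, 64)
    assert out['coc_map'].shape == (1, 1, 64, 64)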
|
|
|
|
class BokehFlowLoss(nn.Module):
    """L1 + (1 - SSIM), with SSIM computed from 11x11 box-filtered local statistics."""

    def forward(self, pred, targets):
        bp, bg = pred['bokeh'], targets['bokeh_gt']
        l1 = F.l1_loss(bp, bg)
        C1, C2 = 0.01**2, 0.03**2
        mu_p = F.avg_pool2d(bp, 11, 1, 5); mu_g = F.avg_pool2d(bg, 11, 1, 5)
        mu_pp = mu_p * mu_p; mu_gg = mu_g * mu_g; mu_pg = mu_p * mu_g
        sig_pp = F.avg_pool2d(bp * bp, 11, 1, 5) - mu_pp
        sig_gg = F.avg_pool2d(bg * bg, 11, 1, 5) - mu_gg
        sig_pg = F.avg_pool2d(bp * bg, 11, 1, 5) - mu_pg
        ssim_map = ((2 * mu_pg + C1) * (2 * sig_pg + C2)) / ((mu_pp + mu_gg + C1) * (sig_pp + sig_gg + C2))
        ssim_loss = 1.0 - ssim_map.mean()
        return {'total': l1 + ssim_loss, 'l1': l1.detach(), 'ssim': ssim_loss.detach()}
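# Sanity property of this loss: for a perfect prediction (bp == bg) the local
# means and (co)variances coincide, so ssim_map == 1 everywhere and the total
# loss is exactly 0 (l1 = 0, ssim = 0).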
|
|
|
|
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
|
|
class RealBokehDataset(Dataset):
    """Loads (sharp f/22 input, wide-aperture GT) pairs from local disk after snapshot_download."""

    def __init__(self, root, crop_size=256, split='train', target_fstop='2.0'):
        self.crop = crop_size
        self.pairs = []
        in_dir = Path(root) / split / 'in'
        gt_dir = Path(root) / split / 'gt'
        meta_dir = Path(root) / split / 'metadata'

        for in_path in sorted(in_dir.glob('*_f22.JPG')):
            sid = in_path.stem.split('_')[0]
            gt_path = gt_dir / sid / f'{sid}_f{target_fstop}.JPG'
            meta_path = meta_dir / f'{sid}.json'
            if gt_path.exists():
                meta = {}
                if meta_path.exists():
                    with open(meta_path) as f:
                        meta = json.load(f)
                self.pairs.append((str(in_path), str(gt_path), meta))

        print(f"RealBokehDataset: {len(self.pairs)} pairs found (target f/{target_fstop})")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        from PIL import Image
        import torchvision.transforms.functional as TF

        in_path, gt_path, meta = self.pairs[idx]
        inp = Image.open(in_path).convert('RGB')
        gt = Image.open(gt_path).convert('RGB')

        # Downscale so the short side is at most 512 px before cropping.
        short = min(inp.size)
        if short > 512:
            scale = 512.0 / short
            new_w = int(inp.size[0] * scale)
            new_h = int(inp.size[1] * scale)
            inp = inp.resize((new_w, new_h), Image.LANCZOS)
            gt = gt.resize((new_w, new_h), Image.LANCZOS)

        inp = TF.to_tensor(inp)
        gt = TF.to_tensor(gt)

        # Aligned random crop; fall back to a resize when the image is smaller than the crop.
        _, h, w = inp.shape
        cs = self.crop
        if h >= cs and w >= cs:
            i = random.randint(0, h - cs)
            j = random.randint(0, w - cs)
            inp = inp[:, i:i+cs, j:j+cs]
            gt = gt[:, i:i+cs, j:j+cs]
        else:
            inp = F.interpolate(inp.unsqueeze(0), (cs, cs), mode='bilinear', align_corners=False).squeeze(0)
            gt = F.interpolate(gt.unsqueeze(0), (cs, cs), mode='bilinear', align_corners=False).squeeze(0)

        # Random horizontal flip, applied to both images.
        if random.random() > 0.5:
            inp = inp.flip(-1)
            gt = gt.flip(-1)

        focal = meta.get('focal_length', 50.0)
        focus = meta.get('focus_plane_distance', 2.0)
        fnum = 2.0  # hard-coded to match the default target_fstop='2.0'

        return {
            'image': inp,
            'bokeh_gt': gt,
            'f_number': torch.tensor(fnum, dtype=torch.float32),
            'focal_length_mm': torch.tensor(float(focal), dtype=torch.float32),
            'focus_distance_m': torch.tensor(float(focus), dtype=torch.float32),
        }
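# Expected on-disk layout, matching the globs above (scene id 17 as an example):
#   {root}/train/in/17_f22.JPG        sharp f/22 input
#   {root}/train/gt/17/17_f2.0.JPG    wide-aperture ground truth
#   {root}/train/metadata/17.json     focal_length, focus_plane_distance, ...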
|
|
|
|
# ---------------------------------------------------------------------------
# Data download
# ---------------------------------------------------------------------------
|
|
def download_realbokeh(max_scenes=None):
    """Download RealBokeh_3MP via snapshot_download, fetching only the files we need."""
    from huggingface_hub import snapshot_download

    data_dir = '/tmp/realbokeh'
    # Cheap cache check: if scene 1 is present, assume a previous download completed.
    check_file = Path(data_dir) / 'train' / 'in' / '1_f22.JPG'
    if check_file.exists():
        n = len(list(Path(data_dir).glob('train/in/*_f22.JPG')))
        print(f"Data already cached: {n} scenes")
        return data_dir

    print("Fetching metadata to build download list...")
    snapshot_download(
        'timseizinger/RealBokeh_3MP',
        repo_type='dataset',
        local_dir=data_dir,
        allow_patterns=['train/metadata/*.json'],
    )

    meta_dir = Path(data_dir) / 'train' / 'metadata'
    scene_ids = sorted([p.stem for p in meta_dir.glob('*.json')], key=lambda x: int(x))

    if max_scenes:
        scene_ids = scene_ids[:max_scenes]

    print(f"Found {len(scene_ids)} scenes. Downloading input + f/2.0 GT images...")

    # One exact pattern per file keeps the download limited to the f/22 input
    # and the f/2.0 ground truth for each selected scene.
    patterns = []
    for sid in scene_ids:
        patterns.append(f'train/in/{sid}_f22.JPG')
        patterns.append(f'train/gt/{sid}/{sid}_f2.0.JPG')

    t0 = time.time()
    snapshot_download(
        'timseizinger/RealBokeh_3MP',
        repo_type='dataset',
        local_dir=data_dir,
        allow_patterns=patterns,
    )
    elapsed = time.time() - t0
    n_found = len(list(Path(data_dir).glob('train/in/*_f22.JPG')))
    print(f"Downloaded {n_found} scenes in {elapsed:.0f}s")
    return data_dir
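# For example, scene id '1' contributes exactly two patterns:
#   train/in/1_f22.JPG  and  train/gt/1/1_f2.0.JPG
# so the download stays at two files per scene rather than every aperture.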
|
|
|
|
# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
|
|
def train():
    import trackio

    VARIANT = os.environ.get('VARIANT', 'small')
    MAX_SCENES = int(os.environ.get('MAX_SCENES', '0')) or None
    EPOCHS = int(os.environ.get('EPOCHS', '10'))
    BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
    CROP_SIZE = int(os.environ.get('CROP_SIZE', '256'))
    LR = float(os.environ.get('LR', '2e-4'))
    HUB_MODEL_ID = os.environ.get('HUB_MODEL_ID', 'asdf98/BokehFlow')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")
    if device == 'cuda':
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    trackio.init(project="bokehflow", name=f"v3-{VARIANT}-e{EPOCHS}-lr{LR}")

    data_dir = download_realbokeh(max_scenes=MAX_SCENES)

    ds = RealBokehDataset(data_dir, crop_size=CROP_SIZE)
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4,
                    pin_memory=True, drop_last=True, persistent_workers=True)
    print(f"Batches per epoch: {len(dl)}")

    config = BokehFlowConfig(variant=VARIANT)
    model = BokehFlow(config).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model: BokehFlow-{VARIANT}, {n_params:,} params")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    total_steps = EPOCHS * len(dl)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps, eta_min=LR / 20)
    loss_fn = BokehFlowLoss()

    scaler = torch.amp.GradScaler('cuda', enabled=(device == 'cuda'))

    global_step = 0
    best_loss = float('inf')
|
|
    for epoch in range(EPOCHS):
        model.train()
        epoch_loss = 0.0
        t_epoch = time.time()

        for batch_idx, batch in enumerate(dl):
            t_step = time.time()
            image = batch['image'].to(device)
            bokeh_gt = batch['bokeh_gt'].to(device)
            f_number = batch['f_number'].to(device)
            focal_mm = batch['focal_length_mm'].to(device)
            focus_m = batch['focus_distance_m'].to(device)

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast('cuda', enabled=(device == 'cuda')):
                out = model(image, f_number, focal_mm, focus_m)
                losses = loss_fn(out, {'bokeh_gt': bokeh_gt})
                loss = losses['total']

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            epoch_loss += loss.item()
            global_step += 1
            step_time = time.time() - t_step

            if global_step % 10 == 0 or batch_idx == 0:
                lr_now = scheduler.get_last_lr()[0]
                print(f"Ep {epoch+1}/{EPOCHS} [{batch_idx+1}/{len(dl)}] "
                      f"loss={loss.item():.4f} l1={losses['l1'].item():.4f} "
                      f"ssim={losses['ssim'].item():.4f} lr={lr_now:.2e} "
                      f"step={step_time*1000:.0f}ms")
                trackio.log({
                    "loss": loss.item(),
                    "l1": losses['l1'].item(),
                    "ssim_loss": losses['ssim'].item(),
                    "lr": lr_now,
                    "step_ms": step_time * 1000,
                    "epoch": epoch + 1,
                })

            if device == 'cuda' and global_step == 1:
                vram = torch.cuda.max_memory_allocated() / 1e9
                print(f"Peak VRAM after 1st step: {vram:.2f} GB")
                trackio.log({"peak_vram_gb": vram})

        epoch_time = time.time() - t_epoch
        avg_loss = epoch_loss / len(dl)
        print(f"Epoch {epoch+1}/{EPOCHS} done in {epoch_time:.0f}s, avg_loss={avg_loss:.4f}")
        trackio.log({"epoch_avg_loss": avg_loss, "epoch_time_s": epoch_time})

        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'model_state_dict': model.state_dict(),
                'config': config.__dict__,
                'epoch': epoch + 1,
                'loss': avg_loss,
            }, '/tmp/bokehflow_best.pt')
            print(f"  Saved best model (loss={avg_loss:.4f})")
|
|
    # Push checkpoints to the Hub.
    print("\nPushing model to Hub...")
    from huggingface_hub import HfApi
    api = HfApi()
    api.create_repo(HUB_MODEL_ID, exist_ok=True)  # no-op if the repo already exists

    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config.__dict__,
        'epoch': EPOCHS,
        'loss': avg_loss,
    }, '/tmp/bokehflow_final.pt')

    for fname in ['bokehflow_best.pt', 'bokehflow_final.pt']:
        fpath = f'/tmp/{fname}'
        if os.path.exists(fpath):
            api.upload_file(
                path_or_fileobj=fpath,
                path_in_repo=f'checkpoints/{fname}',
                repo_id=HUB_MODEL_ID,
            )
            print(f"  Uploaded {fname}")

    trackio.finish()
    print(f"\nTraining complete! Model: https://huggingface.co/{HUB_MODEL_ID}")
|
|
|
|
if __name__ == '__main__':
    train()
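# To reload a pushed checkpoint later (a sketch; the repo path matches the
# path_in_repo used in train() above):
#
#   from huggingface_hub import hf_hub_download
#   ckpt_path = hf_hub_download('asdf98/BokehFlow', 'checkpoints/bokehflow_best.pt')
#   ckpt = torch.load(ckpt_path, map_location='cpu')
#   model = BokehFlow(BokehFlowConfig(**ckpt['config']))
#   model.load_state_dict(ckpt['model_state_dict'])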
|
|