| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision import transforms |
| import os |
| import cv2 |
| import numpy as np |
| from torch.amp import autocast, GradScaler |
| from torch.utils.data.distributed import DistributedSampler |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| import os |
| from torch.optim.lr_scheduler import CosineAnnealingLR |
| import time |
| from datetime import timedelta |
|
|
| |
# Select the GPU when available; all models/tensors in this script are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Directory where per-epoch checkpoints and best-model snapshots are written.
checkpoint_dir = 'save_checkpoints/'
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
|
|
# Cache of normalized base sampling grids, keyed by (device, flow size).
# NOTE(review): entries are never evicted; with many distinct input sizes
# this cache grows without bound.
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    """Backward-warp `tenInput` by the optical flow `tenFlow` using grid_sample.

    Args:
        tenInput: (B, C, H, W) tensor to sample from.
        tenFlow: (B, 2, H, W) flow in pixels; channel 0 is horizontal,
            channel 1 is vertical displacement. A 3-D flow is unsqueezed on
            dim 1 first (presumably to accept a missing channel dim —
            TODO confirm callers actually rely on this path).

    Returns:
        The warped tensor with the spatial size of `tenFlow`.

    Raises:
        ValueError: if the flow does not have exactly 2 channels.
    """
    if tenFlow.dim() == 3:
        tenFlow = tenFlow.unsqueeze(1)
    if tenFlow.size(1) != 2:
        raise ValueError(f"tenFlow must have 2 channels. Got {tenFlow.size(1)} channels.")

    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        # Build the grid on the flow's own device (the old code used the
        # module-level `device`, which breaks CPU inference and multi-GPU
        # training where tensors live on other devices).
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)

    # Convert pixel displacements to the [-1, 1] normalized coordinates
    # grid_sample expects (relative to the *input* tensor's size).
    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return F.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
|
|
|
|
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1):
    """Conv2d followed by LeakyReLU(0.1); the basic building block of this net."""
    layers = [
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
        nn.LeakyReLU(0.1, inplace=True),
    ]
    return nn.Sequential(*layers)
|
|
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """ConvTranspose2d followed by LeakyReLU(0.1) (2x upsampling with defaults).

    Fix: the original hard-coded kernel_size=4, stride=2, padding=1 inside the
    layer, silently ignoring the function's parameters.
    """
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
                           kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
        nn.LeakyReLU(0.1, inplace=True)
    )
|
|
| |
class IFBlock(nn.Module):
    """Flow-estimation block: downsample 4x, process with a residual conv stack,
    then predict a 4-channel flow (two bidirectional 2-channel flows) and a
    1-channel blending mask, upsampled back to the input resolution.
    """

    def __init__(self, in_planes, c=48):
        super(IFBlock, self).__init__()
        # Two stride-2 convs: 4x spatial reduction, 2*c feature channels.
        self.conv0 = nn.Sequential(
            conv(in_planes, c, 3, 2, 1),
            conv(c, 2*c, 3, 2, 1),
        )
        # Residual trunk at the reduced resolution.
        self.convblock = nn.Sequential(
            conv(2*c, 2*c),
            conv(2*c, 2*c),
            conv(2*c, 2*c),
            conv(2*c, 2*c),
        )
        # 2x transposed-conv upsampling to 5 output channels (4 flow + 1 mask).
        self.lastconv = nn.ConvTranspose2d(2*c, 5, 4, 2, 1)

    def forward(self, x, flow=None, scale=1):
        """Estimate flow/mask at 1/scale resolution; outputs are rescaled back."""
        inv = 1. / scale
        if scale != 1:
            x = F.interpolate(x, scale_factor=inv, mode="bilinear", align_corners=False)
        if flow is not None:
            # Flow magnitudes must shrink with the spatial downscaling.
            flow = F.interpolate(flow, scale_factor=inv, mode="bilinear", align_corners=False) * inv
            x = torch.cat((x, flow), 1)
        feat = self.conv0(x)
        feat = self.convblock(feat) + feat
        out = self.lastconv(feat)
        # Net reduction so far is 2x (4x down, 2x up); scale*2 restores the
        # original resolution, and the flow values are scaled to match.
        out = F.interpolate(out, scale_factor=scale * 2, mode="bilinear", align_corners=False)
        predicted_flow = out[:, :4] * scale * 2
        predicted_mask = out[:, 4:5]
        return predicted_flow, predicted_mask
|
|
| |
| c = 48 |
|
|
| |
class Contextnet(nn.Module):
    """Extracts a 3-level pyramid of context features, each warped by a
    progressively halved (and rescaled) copy of the given flow.
    """

    def __init__(self):
        super(Contextnet, self).__init__()
        self.conv1 = conv(3, c)
        self.conv2 = conv(c, 2*c)
        self.conv3 = conv(2*c, 4*c)

    def forward(self, x, flow):
        """Return [f1, f2, f3]: warped features after conv1/conv2/conv3."""
        pyramid = []
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
            # Halve the flow's resolution and magnitude for this level.
            flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
            pyramid.append(warp(x, flow))
        return pyramid
|
|
class Unet(nn.Module):
    """Refinement network: fuses images, warped images, mask, flow and the
    two context pyramids into a 3-channel residual for the merged frame.

    Fix: the original forward referenced `c0_1` at the second fusion level
    without ever defining it (only `c1_1` was interpolated), raising a
    NameError at runtime; `c0[1]` is now interpolated symmetrically.
    """

    def __init__(self):
        super(Unet, self).__init__()
        # 17 input channels: 4 RGB images (12) + 1 mask + 4 flow channels.
        self.down0 = conv(17, 2*c)
        self.down1 = conv(4*c, 4*c)   # s0 (2c) + c0[0] (c) + c1[0] (c)
        self.down2 = conv(8*c, 8*c)   # s1 (4c) + c0[1] (2c) + c1[1] (2c)

        self.up0 = deconv(8*c, 4*c)
        self.up1 = deconv(4*c, 2*c)
        self.up2 = deconv(2*c, c)
        self.conv = nn.Conv2d(c, 3, 3, 1, 1)

    def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
        """Return a (B, 3, H', W') refinement map.

        c0/c1 are the context pyramids from Contextnet for img0/img1.
        """
        s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))

        # Resize context features to match the current feature map before fusing.
        c0_0 = F.interpolate(c0[0], size=(s0.size(2), s0.size(3)), mode="bilinear", align_corners=False)
        c1_0 = F.interpolate(c1[0], size=(s0.size(2), s0.size(3)), mode="bilinear", align_corners=False)
        s1 = self.down1(torch.cat((s0, c0_0, c1_0), 1))

        c0_1 = F.interpolate(c0[1], size=(s1.size(2), s1.size(3)), mode="bilinear", align_corners=False)
        c1_1 = F.interpolate(c1[1], size=(s1.size(2), s1.size(3)), mode="bilinear", align_corners=False)
        s2 = self.down2(torch.cat((s1, c0_1, c1_1), 1))

        x = self.up0(s2)
        x = self.up1(x)
        x = self.up2(x)
        x = self.conv(x)
        return x
|
|
class IFNet(nn.Module):
    """Two-stage intermediate-frame synthesis network.

    Stage 0 estimates a coarse bidirectional flow + blend mask from the two
    input frames; stage 1 refines them using the warped frames; a context
    network and a U-Net then produce a residual refinement of the blended
    result.
    """

    def __init__(self):
        super(IFNet, self).__init__()
        # Stage 0 input: img0 + img1 (3 + 3 channels).
        self.block0 = IFBlock(6, c=64)
        # Stage 1 input: img0, img1, warped_img0, warped_img1, mask (13) + flow (4).
        self.block1 = IFBlock(13+4, c=64)
        self.contextnet = Contextnet()
        self.unet = Unet()

    # NOTE(review): `scale` uses a mutable list as default (harmless here since
    # it is only read); `scale[2]` and `timestep` are never used in this forward.
    def forward(self, x, scale=[4, 2, 1], timestep=0.5):
        # x is (B, >=6, H, W): channels 0:3 = frame 0, 3:6 = frame 1,
        # 6: = ground-truth middle frame when present (training batches).
        img0 = x[:, :3]
        img1 = x[:, 3:6]
        gt = x[:, 6:] if x.shape[1] > 6 else None  # NOTE(review): extracted but unused

        # Coarse flow/mask from the raw frame pair.
        flow, mask = self.block0(torch.cat((img0, img1), 1), None, scale=scale[0])

        # flow[:, :2] warps img0 toward the middle frame, flow[:, 2:4] warps img1.
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])

        # Stage 1 predicts residual corrections to flow and mask.
        flow_d, mask_d = self.block1(torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow, scale=scale[1])
        flow = flow + flow_d
        mask = mask + mask_d

        # Re-warp with the refined flow.
        warped_img0 = warp(img0, flow[:, :2])
        warped_img1 = warp(img1, flow[:, 2:4])

        # Sigmoid turns the mask logits into per-pixel blend weights.
        mask_final = torch.sigmoid(mask)
        merged = warped_img0 * mask_final + warped_img1 * (1 - mask_final)

        # Context pyramids for each source frame, warped by its flow half.
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])

        # U-Net residual, mapped from [0, 1]-ish to [-1, 1] range.
        refined = self.unet(img0, img1, warped_img0, warped_img1, mask_final, flow, c0, c1)
        refined = refined[:, :3] * 2 - 1

        # Guard against off-by-a-few spatial sizes from the up/down sampling chain.
        if refined.size(2) != merged.size(2) or refined.size(3) != merged.size(3):
            refined = F.interpolate(refined, size=(merged.size(2), merged.size(3)), mode='bilinear', align_corners=False)

        # NOTE(review): clamping to [0, 1] assumes images in that range, but the
        # training transform normalizes to [-1, 1] — confirm the intended range.
        final_output = torch.clamp(merged + refined, 0, 1)

        return flow, mask_final, final_output
|
|
class FrameInterpolationDataset(Dataset):
    """Triplet dataset for frame interpolation.

    Each item is a single (9, H, W) tensor: frame t, frame t+2 and the
    ground-truth middle frame t+1, concatenated on the channel axis.

    Fix: items are now cached only when `augment` is False. The original
    cached the *augmented* tensor, so a cached index replayed the exact same
    random flips/brightness on every epoch, silently defeating augmentation.
    """

    def __init__(self, data_dir, transform=None, resize=None, cache_size=100, augment=True):
        self.data_dir = data_dir
        self.transform = transform
        self.resize = resize          # (width, height) for cv2.resize, or None
        self.frame_pairs = self._load_frame_pairs()
        self.cache = {}               # idx -> prepared tensor (only when augment=False)
        self.cache_size = cache_size
        self.augment = augment

    def _load_frame_pairs(self):
        """Scan sequence subdirectories and build (img0, img1, gt) path triplets.

        Raises:
            ValueError: if the dataset directory is missing or yields no triplets.
        """
        frame_pairs = []

        if not os.path.exists(self.data_dir):
            raise ValueError(f"Dataset directory does not exist: {self.data_dir}")

        for seq in os.listdir(self.data_dir):
            seq_dir = os.path.join(self.data_dir, seq)

            if not os.path.isdir(seq_dir):
                continue

            frames = sorted([f for f in os.listdir(seq_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])

            if len(frames) < 3:
                continue

            # Sliding window: frames i and i+2 are inputs, i+1 is the target.
            for i in range(len(frames) - 2):
                frame_pairs.append((
                    os.path.join(seq_dir, frames[i]),
                    os.path.join(seq_dir, frames[i+2]),
                    os.path.join(seq_dir, frames[i+1])
                ))

        if not frame_pairs:
            raise ValueError(f"No valid frame pairs found in {self.data_dir}")

        return frame_pairs

    def __len__(self):
        return len(self.frame_pairs)

    def __getitem__(self, idx):
        if idx in self.cache:
            return self.cache[idx]

        img0_path, img1_path, gt_path = self.frame_pairs[idx]

        img0 = cv2.imread(img0_path)
        img1 = cv2.imread(img1_path)
        gt = cv2.imread(gt_path)

        if img0 is None or img1 is None or gt is None:
            raise ValueError(f"Could not read one of the images: {self.frame_pairs[idx]}")

        # OpenCV reads BGR; the rest of the pipeline expects RGB.
        img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
        img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
        gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)

        if self.resize:
            img0 = cv2.resize(img0, self.resize, interpolation=cv2.INTER_AREA)
            img1 = cv2.resize(img1, self.resize, interpolation=cv2.INTER_AREA)
            gt = cv2.resize(gt, self.resize, interpolation=cv2.INTER_AREA)

        if self.augment:
            # The same random flip/brightness must be applied to all three
            # frames so the triplet stays geometrically consistent.
            if np.random.random() > 0.5:
                img0 = np.flip(img0, axis=1).copy()
                img1 = np.flip(img1, axis=1).copy()
                gt = np.flip(gt, axis=1).copy()

            if np.random.random() > 0.5:
                img0 = np.flip(img0, axis=0).copy()
                img1 = np.flip(img1, axis=0).copy()
                gt = np.flip(gt, axis=0).copy()

            if np.random.random() > 0.5:
                brightness = 0.9 + np.random.random() * 0.2
                img0 = np.clip(img0 * brightness, 0, 255).astype(np.uint8)
                img1 = np.clip(img1 * brightness, 0, 255).astype(np.uint8)
                gt = np.clip(gt * brightness, 0, 255).astype(np.uint8)

        if self.transform:
            img0 = self.transform(img0)
            img1 = self.transform(img1)
            gt = self.transform(gt)

        result = torch.cat((img0, img1, gt), 0)

        # Cache only deterministic (non-augmented) samples; caching an
        # augmented tensor would freeze one random augmentation for this index.
        if not self.augment and len(self.cache) < self.cache_size:
            self.cache[idx] = result

        return result
| |
def train_with_amp(model, train_dataloader, val_dataloader, optimizer, scheduler, criterion,
                   num_epochs=10, patience=5, start_epoch=0, best_val_loss=float('inf'), best_val_psnr=0.0):
    """Train `model` with automatic mixed precision, validating every epoch.

    Saves a checkpoint per epoch plus separate "best PSNR" and "best loss"
    snapshots, and stops early after `patience` epochs with no improvement
    in either metric.

    Fixes over the original:
      * PSNR and loss improvements are tracked independently — the old
        if/elif never updated `best_val_loss` in an epoch where PSNR also
        improved, causing stale "improved loss" saves later.
      * `epoch` is pre-seeded so the post-loop statistics don't raise a
        NameError when `start_epoch >= num_epochs` (empty loop).

    Args:
        model: network to train (already on `device`).
        train_dataloader / val_dataloader: yield (B, 9, H, W) batches where
            channels 6:9 are the ground-truth middle frame.
        optimizer, scheduler, criterion: standard torch training objects;
            the scheduler is stepped once per epoch.
        num_epochs, patience, start_epoch: loop control / resume support.
        best_val_loss, best_val_psnr: best metrics so far (for resuming).

    Returns:
        (best_val_psnr, total_time_str)
    """
    scaler = GradScaler()
    patience_counter = 0

    total_start_time = time.time()
    epoch_times = []

    # Pre-seed for the post-loop statistics in case the loop never runs.
    epoch = max(start_epoch - 1, 0)

    for epoch in range(start_epoch, num_epochs):
        epoch_start_time = time.time()
        model.train()
        train_loss = 0.0

        for i, data in enumerate(train_dataloader):
            data = data.to(device, non_blocking=True)

            # Forward pass under autocast; loss compares the synthesized
            # middle frame against channels 6:9 (ground truth).
            with autocast(device_type='cuda'):
                flow, mask, final_output = model(data)
                loss = criterion(final_output, data[:, 6:9])

            optimizer.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()

            if i % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(train_dataloader)}], Loss: {loss.item():.6f}, LR: {scheduler.get_last_lr()[0]:.6f}")

        scheduler.step()

        avg_train_loss = train_loss / len(train_dataloader)

        val_loss, val_psnr = validate_with_amp(model, val_dataloader, criterion)

        # Timing bookkeeping for progress and ETA reporting.
        epoch_end_time = time.time()
        epoch_time = epoch_end_time - epoch_start_time
        epoch_times.append(epoch_time)

        avg_epoch_time = sum(epoch_times) / len(epoch_times)
        epochs_remaining = num_epochs - (epoch + 1)
        est_time_remaining = avg_epoch_time * epochs_remaining

        epoch_time_str = str(timedelta(seconds=int(epoch_time)))
        est_remaining_str = str(timedelta(seconds=int(est_time_remaining)))
        total_elapsed_str = str(timedelta(seconds=int(time.time() - total_start_time)))

        print(f"Epoch [{epoch+1}/{num_epochs}] completed. Train Loss: {avg_train_loss:.6f}, "
              f"Validation Loss: {val_loss:.6f}, Validation PSNR: {val_psnr:.4f} dB")
        print(f"Time: {epoch_time_str} | Total: {total_elapsed_str} | Remaining: {est_remaining_str}")

        # Full resumable checkpoint for this epoch.
        checkpoint_path = f"{checkpoint_dir}/model_epoch_{epoch+1}.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': val_loss,
            'psnr': val_psnr,
        }, checkpoint_path)

        # Track both metrics independently; either improving resets patience.
        improved = False
        if val_psnr > best_val_psnr:
            best_val_psnr = val_psnr
            torch.save(model.state_dict(), f"{checkpoint_dir}/best_psnr_model.pth")
            print(f"Model saved with improved validation PSNR: {best_val_psnr:.4f} dB")
            improved = True
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"{checkpoint_dir}/best_loss_model.pth")
            print(f"Model saved with improved validation loss: {best_val_loss:.6f}")
            improved = True

        if improved:
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

    total_training_time = time.time() - total_start_time
    total_time_str = str(timedelta(seconds=int(total_training_time)))
    # Guard against division by zero when no epoch actually ran.
    avg_epoch_time = total_training_time / max(1, min(num_epochs, epoch+1))
    avg_epoch_time_str = str(timedelta(seconds=int(avg_epoch_time)))

    print(f"Training completed in {total_time_str} ({avg_epoch_time_str} per epoch)")
    print(f"Best validation PSNR: {best_val_psnr:.4f} dB")

    # Persist a human-readable timing summary alongside the checkpoints.
    with open(f"{checkpoint_dir}/training_time_stats.txt", "w") as f:
        f.write(f"Total training time: {total_time_str}\n")
        f.write(f"Average epoch time: {avg_epoch_time_str}\n")
        f.write(f"Total epochs: {epoch+1}\n")
        f.write(f"Best validation PSNR: {best_val_psnr:.4f} dB\n")
        f.write(f"Final learning rate: {scheduler.get_last_lr()[0]:.8f}\n")

        f.write("\nEpoch times:\n")
        for i, e_time in enumerate(epoch_times):
            e_time_str = str(timedelta(seconds=int(e_time)))
            f.write(f"Epoch {i+1}: {e_time_str}\n")

    return best_val_psnr, total_time_str
|
|
def validate_with_amp(model, dataloader, criterion):
    """Evaluate `model` on `dataloader` under autocast.

    Returns:
        (avg_loss, avg_psnr): criterion loss and PSNR in dB, both averaged
        per batch over the whole loader.
    """
    model.eval()
    total_loss = 0.0
    total_psnr = 0.0

    with torch.no_grad():
        for data in dataloader:
            data = data.to(device, non_blocking=True)

            with autocast(device_type='cuda'):
                flow, mask, final_output = model(data)
                gt = data[:, 6:9]  # ground-truth middle frame channels
                loss = criterion(final_output, gt)

            total_loss += loss.item()

            # PSNR assuming a peak signal value of 1.0. NOTE(review): the
            # training transform normalizes images to [-1, 1], which would
            # make this peak assumption wrong — confirm the intended range.
            mse = F.mse_loss(final_output, gt).item()
            if mse > 0:
                psnr = 10 * np.log10(1.0 / mse)
            else:
                psnr = float('inf')  # identical tensors; avoid log of zero
            total_psnr += psnr

    avg_loss = total_loss / len(dataloader)
    avg_psnr = total_psnr / len(dataloader)

    return avg_loss, avg_psnr
|
|
def load_checkpoint_and_resume(checkpoint_path, model, optimizer, scheduler):
    """Restore model/optimizer/scheduler state from a training checkpoint.

    Returns:
        (model, optimizer, scheduler, start_epoch, best_val_loss, best_val_psnr)
        where start_epoch is the epoch after the saved one.

    Raises:
        Re-raises any exception encountered while loading, after printing it.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")
    try:
        # weights_only=False: the checkpoint stores optimizer/scheduler dicts,
        # not just tensors.
        state = torch.load(checkpoint_path, weights_only=False)

        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        # Scheduler state is optional for older checkpoints.
        if 'scheduler_state_dict' in state:
            scheduler.load_state_dict(state['scheduler_state_dict'])

        start_epoch = state['epoch'] + 1
        best_val_loss = state.get('loss', float('inf'))
        best_val_psnr = state.get('psnr', 0.0)

        print(f"Successfully loaded checkpoint from epoch {state['epoch']}")
        print(f"Resuming training from epoch {start_epoch}")
        print(f"Best validation loss: {best_val_loss:.6f}, Best PSNR: {best_val_psnr:.4f} dB")

        return model, optimizer, scheduler, start_epoch, best_val_loss, best_val_psnr
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        raise e
| |
if __name__ == "__main__":
    # --- configuration ---
    data_dir_train = "datasets/train_10k"
    data_dir_val = "datasets/test_2k"
    batch_size = 16
    resize = (256, 256)  # (width, height) passed to cv2.resize in the dataset
    world_size = torch.cuda.device_count()

    load_from_checkpoint = True
    checkpoint_path = f"{checkpoint_dir}/model_epoch_49.pth"

    # Single-GPU / CPU path; the distributed branch below is a stub.
    if world_size <= 1:
        model = IFNet().to(device)
        optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
        scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)

        start_epoch = 0
        best_val_loss = float('inf')
        best_val_psnr = 0.0

        if load_from_checkpoint:
            try:
                model, optimizer, scheduler, start_epoch, best_val_loss, best_val_psnr = load_checkpoint_and_resume(
                    checkpoint_path, model, optimizer, scheduler
                )
            except Exception as e:
                print(f"Failed to load checkpoint: {e}")
                print("Starting training from scratch instead.")

                # Rebuild fresh training state after a failed (possibly
                # partial) checkpoint load.
                model = IFNet().to(device)
                optimizer = optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-5)
                scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
                start_epoch = 0
                best_val_loss = float('inf')
                best_val_psnr = 0.0

        # NOTE(review): this maps pixels to [-1, 1], while IFNet clamps its
        # output to [0, 1] and validation PSNR assumes peak 1.0 — confirm
        # the intended value range is consistent across the pipeline.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

        train_dataset = FrameInterpolationDataset(
            data_dir=data_dir_train,
            transform=transform,
            resize=resize,
            augment=True
        )

        val_dataset = FrameInterpolationDataset(
            data_dir=data_dir_val,
            transform=transform,
            resize=resize,
            augment=False
        )

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
            drop_last=True  # keep batch sizes constant for stable AMP steps
        )

        val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        # L1 reconstruction loss on the synthesized middle frame.
        criterion = nn.L1Loss()

        best_psnr, total_time = train_with_amp(
            model,
            train_dataloader,
            val_dataloader,
            optimizer,
            scheduler,
            criterion,
            num_epochs=50,
            patience=10,
            start_epoch=start_epoch,
            best_val_loss=best_val_loss,
            best_val_psnr=best_val_psnr
        )
        print(f"Training completed in {total_time} with best PSNR of {best_psnr:.4f} dB")
    else:
        # Multi-GPU (DDP) path is intentionally not implemented; the DDP/dist
        # imports at the top of the file are currently unused.
        print("Distributed training not implemented in this example")