| import torch |
| import torch.nn.functional as F |
| import cv2 |
| import numpy as np |
| from torchvision import transforms |
| import os |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| import time |
| from skimage.metrics import structural_similarity as ssim |
| from skimage.color import rgb2lab |
|
|
|
|
| from combined import IFNet, warp |
|
|
|
|
# Run on the GPU when one is available; inputs are moved to this same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = IFNet().to(device)

# Restore the trained weights. map_location places the loaded tensors on `device`.
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. Only load checkpoints you produced or trust.
checkpoint_path = "save_checkpoints/model_epoch_50.pth"
checkpoint = torch.load(checkpoint_path, weights_only=False, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Inference mode for layers whose behaviour differs between train/eval (if any).
model.eval()

# 'epoch' and 'psnr' are informational only; a checkpoint that lacks either key
# should not abort the script, so both are read with a fallback.
print(f"Loaded model from epoch {checkpoint.get('epoch', 'N/A')} with PSNR: {checkpoint.get('psnr', 'N/A')} dB")
|
|
|
|
# Input pipeline for the model: an HxWxC uint8 RGB array becomes a CxHxW float
# tensor. ToTensor scales to [0, 1]; Normalize with mean=std=0.5 then maps
# that range to [-1, 1] (tensor_to_image inverts this).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
|
|
def preprocess_images(img0_path, img1_path, gt_path=None, size=(256, 256)):
    """Load two input frames (and an optional ground-truth frame) for the model.

    Args:
        img0_path: Path to the first input frame.
        img1_path: Path to the second input frame.
        gt_path: Optional path to the ground-truth middle frame. If it is given
            but missing or unreadable, a warning is printed and ``gt`` is None.
        size: ``(width, height)`` the frames are resized to before being fed to
            the model (OpenCV argument order). Defaults to ``(256, 256)``.

    Returns:
        Tuple ``(input_tensor, original_size, orig_img0, orig_img1, gt)``:
        ``input_tensor`` is the two normalized frames concatenated along the
        channel axis with a leading batch dimension, placed on ``device``;
        ``original_size`` is ``(height, width)`` of the first frame;
        ``orig_img0`` / ``orig_img1`` are the full-resolution RGB frames;
        ``gt`` is the RGB ground truth, or None.

    Raises:
        ValueError: If either input frame cannot be read.
    """
    img0 = cv2.imread(img0_path)
    img1 = cv2.imread(img1_path)
    if img0 is None or img1 is None:
        raise ValueError(f"Could not read images: {img0_path}, {img1_path}")

    gt = None
    if gt_path:
        if not os.path.exists(gt_path):
            # Previously this case was dropped silently; make it visible.
            print(f"Warning: Ground truth image not found: {gt_path}")
        else:
            gt = cv2.imread(gt_path)
            if gt is None:
                print(f"Warning: Could not read ground truth image: {gt_path}")
            else:
                gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB)

    # OpenCV decodes to BGR; the model, metrics and plots all work in RGB.
    img0 = cv2.cvtColor(img0, cv2.COLOR_BGR2RGB)
    img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
    original_size = img0.shape[:2]  # (height, width)

    # cv2.resize returns new arrays, so img0/img1 stay untouched at full
    # resolution and can be returned directly without copying.
    frames = [transform(cv2.resize(img, size)) for img in (img0, img1)]
    # Concatenate the two frames channel-wise and add a batch dimension.
    input_tensor = torch.cat(frames, dim=0).unsqueeze(0).to(device)

    return input_tensor, original_size, img0, img1, gt
|
|
def tensor_to_image(tensor):
    """Convert a normalized (C, H, W) image tensor to an (H, W, C) uint8 array.

    Undoes the input normalization (``x * 0.5 + 0.5``), clamps to [0, 1] and
    scales to [0, 255] with round-to-nearest.

    Args:
        tensor: 3-D (C, H, W) tensor, expected in [-1, 1] to match the input
            normalization; out-of-range values are clamped. It may live on any
            device and may require grad.

    Returns:
        np.ndarray of dtype uint8 with shape (H, W, C).
    """
    # detach() so tensors that track gradients can still be converted to numpy.
    tensor = tensor.detach().cpu()
    tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
    # Round to nearest rather than truncate: a bare astype(uint8) floors, which
    # biases every pixel down by ~0.5 on average and skews PSNR/IE.
    img = tensor.numpy().transpose(1, 2, 0) * 255.0 + 0.5
    return img.astype(np.uint8)
|
|
|
|
def calculate_psnr(img1, img2):
    """Return the peak signal-to-noise ratio, in dB, between two 8-bit images.

    Both arrays are cast to float32 before subtracting so uint8 differences do
    not wrap around. Identical images yield ``inf``.
    """
    residual = img1.astype(np.float32) - img2.astype(np.float32)
    mse = np.mean(np.square(residual))
    if mse != 0:
        return 10 * np.log10(255.0 ** 2 / mse)
    return float('inf')
|
|
def calculate_ssim(img1, img2):
    """Return the structural similarity index between two images.

    Three-channel inputs are treated as RGB and converted to grayscale before
    comparison; anything else is passed to ``ssim`` unchanged.
    """
    is_rgb = img1.ndim == 3 and img1.shape[2] == 3
    if not is_rgb:
        return ssim(img1, img2)
    return ssim(
        cv2.cvtColor(img1, cv2.COLOR_RGB2GRAY),
        cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY),
    )
|
|
|
|
def calculate_cd(img1, img2):
    """Return the mean CIE76 colour difference between two 8-bit RGB images.

    Each image is scaled to [0, 1], converted to CIELAB, and the per-pixel
    Euclidean distance in Lab space is averaged over the image.
    """
    lab_first = rgb2lab(img1 / 255.0)
    lab_second = rgb2lab(img2 / 255.0)
    squared = np.square(lab_first - lab_second)
    per_pixel = np.sqrt(squared.sum(axis=2))
    return np.mean(per_pixel)
|
|
def calculate_ie(interpolated, gt):
    """Return the interpolation error: mean absolute per-element difference.

    Inputs are cast to float32 first so uint8 subtraction cannot wrap around.
    """
    residual = interpolated.astype(np.float32) - gt.astype(np.float32)
    return np.abs(residual).mean()
|
|
def interpolate_frames(img0_path, img1_path, output_path, gt_path=None):
    """Synthesize the middle frame between two images, save it, and score it.

    Args:
        img0_path: Path to the first input frame.
        img1_path: Path to the second input frame.
        output_path: Where the interpolated frame is written (via cv2.imwrite).
        gt_path: Optional ground-truth middle frame used for metrics.

    Returns:
        ``(img0, img1, interpolated_img, gt, metrics)``. All images are RGB
        uint8 arrays at the first frame's resolution (``gt`` is resized to
        that resolution if it differs, or is None). ``metrics`` holds 'psnr',
        'ssim', 'cd' and 'ie' when a ground truth is available, else is empty.
    """
    input_tensor, original_size, img0, img1, gt = preprocess_images(img0_path, img1_path, gt_path)

    start_time = time.perf_counter()
    with torch.no_grad():
        _flow, _mask, interpolated = model(input_tensor)
    # CUDA kernels run asynchronously; wait for them so the measured time
    # covers the actual computation and not just the kernel launches.
    if input_tensor.is_cuda:
        torch.cuda.synchronize()
    inference_time = time.perf_counter() - start_time
    print(f"Inference time: {inference_time:.4f} seconds")

    height, width = original_size
    interpolated_img = tensor_to_image(interpolated[0])
    # Undo the model-input resize; cv2.resize takes (width, height).
    interpolated_img = cv2.resize(interpolated_img, (width, height))

    interpolated_img_bgr = cv2.cvtColor(interpolated_img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite reports failure via its return value instead of raising.
    if not cv2.imwrite(output_path, interpolated_img_bgr):
        print(f"Warning: Could not write interpolated frame to: {output_path}")

    metrics = {}
    if gt is not None:
        if gt.shape[:2] != interpolated_img.shape[:2]:
            # Pixel-wise metrics need equal shapes; compare at output resolution.
            print(
                f"Warning: Ground truth size {gt.shape[1]}x{gt.shape[0]} differs from "
                f"output size {width}x{height}; resizing ground truth for comparison."
            )
            gt = cv2.resize(gt, (width, height))

        metrics['psnr'] = calculate_psnr(interpolated_img, gt)
        metrics['ssim'] = calculate_ssim(interpolated_img, gt)
        metrics['cd'] = calculate_cd(interpolated_img, gt)
        metrics['ie'] = calculate_ie(interpolated_img, gt)

        print("Metrics (compared to ground truth):")
        print(f" PSNR: {metrics['psnr']:.4f} dB")
        print(f" SSIM: {metrics['ssim']:.4f}")
        print(f" Color Difference (CD): {metrics['cd']:.4f}")
        print(f" Interpolation Error (IE): {metrics['ie']:.4f}")

    return img0, img1, interpolated_img, gt, metrics
|
|
|
|
def display_results(img0, img1, interpolated, gt, metrics, output_path):
    """Plot the input frames, the interpolated frame and an optional GT comparison.

    The figure is saved next to ``output_path`` as ``<stem>_comparison.png``
    (the original extension is replaced, so a non-.png output path never gets
    overwritten by the figure), then shown and closed.

    Args:
        img0, img1: RGB input frames.
        interpolated: RGB interpolated frame.
        gt: RGB ground-truth frame with the same shape as ``interpolated``, or None.
        metrics: Dict with 'psnr', 'ssim', 'cd', 'ie'; read only when ``gt`` is given.
        output_path: Path of the saved interpolated frame; used to name the figure.
    """
    has_gt = gt is not None
    rows = 2 if has_gt else 1

    fig = plt.figure(figsize=(15, 5 * rows))

    panels = [(img0, 'Frame 1'), (interpolated, 'Interpolated Frame'), (img1, 'Frame 2')]
    if has_gt:
        # Float subtraction avoids uint8 wrap-around; |diff| fits back in uint8.
        diff = np.abs(interpolated.astype(np.float32) - gt.astype(np.float32))
        panels += [(gt, 'Ground Truth'), (diff.astype(np.uint8), 'Difference')]

    for index, (image, title) in enumerate(panels, start=1):
        plt.subplot(rows, 3, index)
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')

    if has_gt:
        plt.subplot(rows, 3, 6)
        plt.axis('off')
        metrics_text = "\n".join([
            f"PSNR: {metrics['psnr']:.2f} dB",
            f"SSIM: {metrics['ssim']:.4f}",
            f"CD: {metrics['cd']:.2f}",
            f"IE: {metrics['ie']:.2f}"
        ])
        plt.text(0.1, 0.5, metrics_text, fontsize=12)
        plt.title('Metrics')

    plt.tight_layout()
    stem, _ext = os.path.splitext(output_path)
    plt.savefig(f"{stem}_comparison.png")
    plt.show()
    # Release the figure; with non-interactive backends show() does not, and
    # figures would otherwise accumulate across loop iterations.
    plt.close(fig)
|
|
| |
# Each entry: (first_frame, second_frame, output_path[, ground_truth_middle_frame]).
test_pairs = [
    ("test_frames/frame1.png", "test_frames/frame3.png", "results/scene1_interpolated.png", "test_frames/frame2.png"),
]

os.makedirs("results", exist_ok=True)

for test_item in test_pairs:
    img0_path, img1_path, output_path, *extra = test_item
    gt_path = extra[0] if extra else None

    print(f"Processing: {img0_path} and {img1_path}")
    try:
        # cv2.imwrite does not create directories, so make sure each output's
        # folder exists, not only the default "results" one.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        img0, img1, interpolated, gt, metrics = interpolate_frames(img0_path, img1_path, output_path, gt_path)
        display_results(img0, img1, interpolated, gt, metrics, output_path)
    except Exception as e:
        # Top-level boundary: report and continue with the next pair.
        print(f"Error processing frames: {e}")
        import traceback
        traceback.print_exc()