| |
| """ |
| Simple Metrics Evaluation for Frequency-Aware Super-Denoiser |
| ============================================================ |
| Calculates PSNR, SSIM, and MSE metrics using existing sampling methods |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from PIL import Image |
| import os |
| from skimage.metrics import structural_similarity as ssim |
| import matplotlib.pyplot as plt |
|
|
| |
| from model import SmoothDiffusionUNet |
| from noise_scheduler import FrequencyAwareNoise |
| from config import Config |
| from dataloader import get_dataloaders |
| from sample import frequency_aware_sample |
|
|
def calculate_psnr(img1, img2, max_val=2.0):
    """Return the peak signal-to-noise ratio (in dB) between two tensors.

    Args:
        img1, img2: tensors of the same shape.
        max_val: dynamic range of the pixel values. The default of 2.0 matches
            images scaled to [-1, 1].

    Returns:
        A 0-dim tensor. It is ``+inf`` when the inputs are identical, so callers
        can always call ``.item()`` (or ``float()``) on the result.
    """
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        # Return a tensor rather than a Python float so the return type does
        # not depend on the input values (callers call .item() on it).
        return torch.tensor(float('inf'), dtype=mse.dtype, device=mse.device)
    # Dividing by a plain scalar keeps the result on mse's device.
    return 20 * torch.log10(max_val / torch.sqrt(mse))
|
|
def calculate_ssim(img1, img2):
    """Return the structural similarity (SSIM) between two image tensors.

    Each input is permuted with ``transpose(1, 2, 0)``, i.e. it is expected to
    be channel-first (C, H, W). Values are mapped from [-1, 1] to [0, 1] and
    clipped before comparison with ``data_range=1.0``.
    """
    def _to_unit_hwc(img):
        # Channel-first tensor -> channel-last numpy array in [0, 1].
        arr = img.detach().cpu().numpy().transpose(1, 2, 0)
        return np.clip((arr + 1) / 2, 0, 1)

    # ``channel_axis`` is the supported way to mark the channel dimension;
    # the old ``multichannel`` flag is deprecated (scikit-image >= 0.19), so
    # passing it as well was redundant.
    return ssim(_to_unit_hwc(img1), _to_unit_hwc(img2), channel_axis=2, data_range=1.0)
|
|
def add_noise(image, noise_level=0.2):
    """Corrupt ``image`` with zero-mean Gaussian noise and clip to [-1, 1].

    Args:
        image: input tensor. The noise is drawn with the same shape, dtype and
            device.
        noise_level: standard deviation of the added noise.

    Returns:
        A new tensor equal to ``clamp(image + noise_level * N(0, 1), -1, 1)``.
    """
    gaussian = torch.randn_like(image)
    noisy = image + noise_level * gaussian
    return noisy.clamp(-1, 1)
|
|
def _grade_metric(value, thresholds, metric_name):
    """Return a four-level grade label for ``value``.

    ``thresholds`` holds three cut-offs ordered from best to worst grade. For
    ``metric_name == 'MSE'`` lower values are better; for any other name,
    higher values are better. A value that beats none of the cut-offs
    (including NaN) is graded "Fair".
    """
    labels = ("Excellent ✅", "Very Good 🟢", "Good 🔵")
    lower_is_better = metric_name == 'MSE'
    for threshold, label in zip(thresholds, labels):
        beats = value < threshold if lower_is_better else value > threshold
        if beats:
            return label
    return "Fair 🟡"


def _load_trained_model(config, device, checkpoint_path):
    """Build the UNet on ``device`` and load its weights; return None if the file is missing."""
    if not os.path.exists(checkpoint_path):
        print("❌ No trained model found! Please run training first.")
        return None
    model = SmoothDiffusionUNet(config).to(device)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects -- only point this at checkpoints you trust. The result goes
    # straight into load_state_dict, so weights_only=True would presumably
    # work as well; confirm against how the checkpoint is saved.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint)
    model.eval()
    print("✅ Model loaded successfully")
    return model


def _timestep_batch(t_value, batch_size, device):
    """Return a (batch_size,) long tensor filled with ``t_value``."""
    return torch.full((batch_size,), t_value, device=device, dtype=torch.long)


def _reconstruct(model, scheduler, images, t_value):
    """Noise ``images`` to step ``t_value``, then undo it with one predicted-noise step."""
    t = _timestep_batch(t_value, images.shape[0], images.device)
    noisy, _ = scheduler.apply_noise(images, t)
    predicted_noise = model(noisy, t)
    alpha_bar = scheduler.alpha_bars[t_value].item()
    # Closed-form x0 estimate. This assumes apply_noise follows the DDPM
    # forward form x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps
    # -- TODO confirm against FrequencyAwareNoise.
    return (noisy - (1 - alpha_bar) ** 0.5 * predicted_noise) / alpha_bar ** 0.5


def _enhance(model, scheduler, images, start_t, timesteps, step_size):
    """Noise ``images`` to ``start_t``, then repeatedly subtract scaled predicted noise.

    After each step the result is clamped to [-1, 1].
    """
    batch_size = images.shape[0]
    start = _timestep_batch(start_t, batch_size, images.device)
    enhanced, _ = scheduler.apply_noise(images, start)
    for t_value in timesteps:
        t = _timestep_batch(t_value, batch_size, images.device)
        enhanced = torch.clamp(enhanced - step_size * model(enhanced, t), -1, 1)
    return enhanced


def _record_metrics(metrics, prefix, references, candidates):
    """Append per-image MSE, PSNR and SSIM to ``metrics['<prefix>_mse' / '_psnr' / '_ssim']``."""
    for reference, candidate in zip(references, candidates):
        metrics[f'{prefix}_mse'].append(F.mse_loss(reference, candidate).item())
        # float() accepts both a 0-dim tensor and a plain float, so an exact
        # match (infinite PSNR) cannot crash the evaluation.
        metrics[f'{prefix}_psnr'].append(float(calculate_psnr(reference, candidate, max_val=2.0)))
        metrics[f'{prefix}_ssim'].append(calculate_ssim(reference, candidate))


def _print_mean_std(metrics, prefix):
    """Print mean ± std of the ``prefix`` MSE/PSNR/SSIM lists and return the three means."""
    mse_vals = metrics[f'{prefix}_mse']
    psnr_vals = metrics[f'{prefix}_psnr']
    ssim_vals = metrics[f'{prefix}_ssim']
    print(f" MSE: {np.mean(mse_vals):.6f} ± {np.std(mse_vals):.6f}")
    print(f" PSNR: {np.mean(psnr_vals):.2f} ± {np.std(psnr_vals):.2f} dB")
    print(f" SSIM: {np.mean(ssim_vals):.4f} ± {np.std(ssim_vals):.4f}")
    return np.mean(mse_vals), np.mean(psnr_vals), np.mean(ssim_vals)


def evaluate_model(num_batches=20, samples_per_batch=4, checkpoint_path='model_final.pth'):
    """Evaluate the trained super-denoiser on the test split and print metrics.

    Two tasks are scored against the clean test images:

    * reconstruction: noise to t=50 with ``scheduler.apply_noise``, then
      recover the image with a single predicted-noise step;
    * enhancement: noise to t=150, then apply the timestep schedule
      (150, 100, 50, 25, 10, 5, 1), subtracting 0.1 * predicted noise and
      clamping to [-1, 1] at each step.

    Args:
        num_batches: maximum number of test batches to evaluate.
        samples_per_batch: number of images taken from the front of each batch.
        checkpoint_path: path of the weights file passed to ``torch.load``.

    Returns:
        A dict with keys ``recon_mse``, ``recon_psnr``, ``recon_ssim``,
        ``enh_mse``, ``enh_psnr`` and ``enh_ssim`` (means over all scored
        images). Returns ``None`` if the checkpoint is missing, the test data
        cannot be loaded, or there are no batches to evaluate.
    """
    print("📊 FREQUENCY-AWARE SUPER-DENOISER METRICS EVALUATION")
    print("=" * 60)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    config = Config()

    model = _load_trained_model(config, device, checkpoint_path)
    if model is None:
        return None
    scheduler = FrequencyAwareNoise(config)

    try:
        _, test_loader = get_dataloaders(config)
        total_batches = len(test_loader)
    except Exception as exc:
        # Narrowed from a bare ``except:`` so Ctrl-C still interrupts, and the
        # cause is reported instead of being silently swallowed.
        print(f"❌ Could not load test data: {exc}")
        return None
    print(f"✅ Test data loaded: {total_batches} batches")

    batches_to_run = min(num_batches, total_batches)
    if batches_to_run <= 0:
        # Without this guard, np.mean([]) would report NaN metrics.
        print("❌ No test batches to evaluate")
        return None

    metrics = {
        f'{task}_{name}': []
        for task in ('reconstruction', 'enhancement')
        for name in ('mse', 'psnr', 'ssim')
    }

    print("\n🔍 Evaluating reconstruction quality...")
    with torch.no_grad():
        for i, (images, _) in enumerate(test_loader):
            if i >= batches_to_run:
                break
            print(f" Processing batch {i + 1}/{batches_to_run}...")
            images = images[:samples_per_batch].to(device)

            # Task 1: light forward noise (t=50), single-step reconstruction.
            reconstructed = _reconstruct(model, scheduler, images, t_value=50)
            _record_metrics(metrics, 'reconstruction', images, reconstructed)

            # Task 2: heavier forward noise (t=150), iterative enhancement.
            enhanced = _enhance(
                model, scheduler, images,
                start_t=150,
                timesteps=(150, 100, 50, 25, 10, 5, 1),
                step_size=0.1,
            )
            _record_metrics(metrics, 'enhancement', images, enhanced)

    print("\n📊 FINAL METRICS RESULTS:")
    print("=" * 60)

    print("🎯 RECONSTRUCTION PERFORMANCE (Light Noise → Original):")
    recon_mse, recon_psnr, recon_ssim = _print_mean_std(metrics, 'reconstruction')

    print("\n🧹 ENHANCEMENT PERFORMANCE (Heavy Noise → Original):")
    enh_mse, enh_psnr, enh_ssim = _print_mean_std(metrics, 'enhancement')

    print("\n📈 RECONSTRUCTION GRADES:")
    print(f" MSE: {_grade_metric(recon_mse, [0.01, 0.05, 0.1], 'MSE')}")
    print(f" PSNR: {_grade_metric(recon_psnr, [35, 30, 25], 'PSNR')}")
    print(f" SSIM: {_grade_metric(recon_ssim, [0.9, 0.8, 0.7], 'SSIM')}")

    print("\n📈 ENHANCEMENT GRADES:")
    print(f" MSE: {_grade_metric(enh_mse, [0.05, 0.1, 0.2], 'MSE')}")
    print(f" PSNR: {_grade_metric(enh_psnr, [30, 25, 20], 'PSNR')}")
    print(f" SSIM: {_grade_metric(enh_ssim, [0.85, 0.75, 0.65], 'SSIM')}")

    print("\n📝 SUMMARY FOR README:")
    print("=" * 60)
    print("Reconstruction Performance:")
    print(f"- MSE: {recon_mse:.6f}")
    print(f"- PSNR: {recon_psnr:.1f} dB")
    print(f"- SSIM: {recon_ssim:.4f}")
    print("\nEnhancement Performance:")
    print(f"- MSE: {enh_mse:.6f}")
    print(f"- PSNR: {enh_psnr:.1f} dB")
    print(f"- SSIM: {enh_ssim:.4f}")

    print("\n🎉 Metrics evaluation completed!")
    return {
        'recon_mse': recon_mse,
        'recon_psnr': recon_psnr,
        'recon_ssim': recon_ssim,
        'enh_mse': enh_mse,
        'enh_psnr': enh_psnr,
        'enh_ssim': enh_ssim,
    }
|
|
if __name__ == "__main__":
    # evaluate_model() returns None when the checkpoint or test data is
    # unavailable; surface that as a non-zero exit status for scripts/CI
    # instead of always exiting with 0.
    raise SystemExit(0 if evaluate_model() is not None else 1)
|
|