import os import glob import torch import torch.nn.functional as F import math from PIL import Image from torchvision import transforms from torch.utils.data import Dataset, DataLoader # --- Configuration --- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) _SIGMAS = (15.0 / 255.0, 50.0 / 255.0, 75.0 / 255.0) def _testset_root(name): return os.path.join( _SCRIPT_DIR, "datasets", "Test_Datasets", "FFDNet-master", "testsets", name, ) class TestDataset(Dataset): def __init__(self, root_dir, sigma): self.sigma = sigma self.image_paths = glob.glob(os.path.join(root_dir, "*.png")) + glob.glob( os.path.join(root_dir, "*.jpg") ) self.transform = transforms.Compose( [transforms.Grayscale(num_output_channels=1), transforms.ToTensor()] ) def __len__(self): return len(self.image_paths) def __getitem__(self, idx): img = Image.open(self.image_paths[idx]) clean = self.transform(img) noisy = clean + torch.randn_like(clean) * self.sigma return clean, torch.clamp(noisy, 0.0, 1.0) def calculate_psnr(img1, img2): mse = torch.mean((img1 - img2) ** 2) if mse == 0: return float("inf") return 20 * math.log10(1.0 / math.sqrt(mse)) def classical_telegraph_step(u_n, u_n_minus_1, tau=0.2, gamma=1.0): kx = torch.tensor([[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]], device=DEVICE).view( 1, 1, 3, 3 ) ky = torch.tensor([[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]], device=DEVICE).view( 1, 1, 3, 3 ) grad_x = F.conv2d(u_n, kx, padding=1) grad_y = F.conv2d(u_n, ky, padding=1) grad_mag = torch.sqrt(grad_x**2 + grad_y**2 + 1e-8) c = 1.0 / (1.0 + (grad_mag / 0.1) ** 2) divergence = F.conv2d(c * grad_x, kx, padding=1) + F.conv2d( c * grad_y, ky, padding=1 ) alpha = (2 + gamma * tau) / (1 + gamma * tau) beta = -1 / (1 + gamma * tau) lam = (tau**2) / (1 + gamma * tau) return alpha * u_n + beta * u_n_minus_1 + lam * divergence def run_eval(dataset_name, sigma): root = _testset_root(dataset_name) dataset = TestDataset(root, sigma) if len(dataset) == 0: print(f"[!] Skip {dataset_name}: no images in {os.path.abspath(root)}") return None dataloader = DataLoader(dataset, batch_size=1, shuffle=False) total_psnr = 0.0 for clean, noisy in dataloader: clean, noisy = clean.to(DEVICE), noisy.to(DEVICE) u_n_minus_1, u_n = noisy.clone(), noisy.clone() for _ in range(20): u_next = classical_telegraph_step(u_n, u_n_minus_1) u_n_minus_1, u_n = u_n, u_next total_psnr += calculate_psnr(clean, u_n) avg_psnr = total_psnr / len(dataset) sigma_int = int(round(sigma * 255.0)) print( f"[+] {dataset_name} sigma={sigma_int}/255 PSNR: {avg_psnr:.2f} dB " f"({len(dataset)} images)" ) return avg_psnr def main(): print("[*] Classical Majee 2020 baseline — Set12 & BSD68") for dataset_name in ("Set12", "BSD68"): for sigma in _SIGMAS: run_eval(dataset_name, sigma) print() if __name__ == "__main__": main()