"""Evaluate trained denoising checkpoints on the Set12 and BSD68 test sets.

For every combination of noise level, damping gamma and initial tau requested
on the command line, this script looks for the matching TNRD-baseline, MLP and
RBF checkpoint files (paths relative to the current working directory), adds
Gaussian noise to each grayscale test image, runs the model and prints the
average PSNR per checkpoint and test set.
"""

import argparse
import os

import torch
from PIL import Image
from torchvision import transforms

from train_network import (
    DEFAULT_GAMMAS,
    DEFAULT_SIGMA_LEVELS,
    DEFAULT_TAU_INITS,
    _collect_image_paths,
    _SCRIPT_DIR,
    calculate_psnr,
    gamma_tag,
    sigma_int_to_float,
    sigma_tag,
    tau_tag,
)
from train_network import UnrolledNetwork as MLPUnrolledNetwork
from train_network_rbf import UnrolledNetwork as RBFUnrolledNetwork
from train_tnrd_baseline import TNRDBaselineNetwork

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TESTSETS = ("Set12", "BSD68")

# MLP/RBF checkpoint families:
# (label prefix, model_type, checkpoint file prefix, finetuned checkpoint file prefix).
# Note the finetuned MLP files use a bare "finetuned" prefix (no "model").
_UNROLLED_FAMILIES = (
    ("MLP", "mlp", "model", "finetuned"),
    ("RBF", "rbf", "rbf_model", "finetuned_rbf_model"),
)

# Width of the label column when every label fits (the historical table width).
_MIN_LABEL_WIDTH = 38


def _testset_root(name):
    """Return the directory that holds the images of test set *name*."""
    return os.path.join(
        _SCRIPT_DIR,
        "datasets",
        "Test_Datasets",
        "FFDNet-master",
        "testsets",
        name,
    )


def _autocast_context():
    """Return the inference autocast context.

    On CUDA this enables mixed-precision autocast; on CPU it returns a
    disabled autocast context, so inference runs at full precision.
    """
    return (
        torch.amp.autocast("cuda")
        if DEVICE.type == "cuda"
        else torch.autocast("cpu", enabled=False)
    )


def _build_model(model_type, stages, use_wave, damping_gamma, tau_init):
    """Instantiate the network for *model_type* on ``DEVICE``.

    Args:
        model_type: one of ``"mlp"``, ``"rbf"`` or ``"tnrd"``.
        stages: number of unrolled stages.
        use_wave: passed to the MLP/RBF networks; ignored for ``"tnrd"``.
        damping_gamma: passed to the MLP/RBF networks; ignored for ``"tnrd"``.
        tau_init: initial tau passed to every network type.

    Raises:
        ValueError: if *model_type* is not recognised.
    """
    if model_type == "mlp":
        return MLPUnrolledNetwork(
            stages, use_wave, damping_gamma=damping_gamma, tau_init=tau_init
        ).to(DEVICE)
    if model_type == "rbf":
        return RBFUnrolledNetwork(
            stages, use_wave, damping_gamma=damping_gamma, tau_init=tau_init
        ).to(DEVICE)
    if model_type == "tnrd":
        return TNRDBaselineNetwork(stages, tau_init=tau_init).to(DEVICE)
    raise ValueError(f"Unknown model type: {model_type}")


def _make_spec(*, label, model_type, use_wave, damping_gamma, tau_init, sigma, path):
    """Bundle everything needed to rebuild, load and report one checkpoint."""
    return {
        "label": label,
        "model_type": model_type,
        "use_wave": use_wave,
        "damping_gamma": damping_gamma,
        "tau_init": tau_init,
        # Stored explicitly so callers never have to parse it back out of the label.
        "sigma": sigma,
        "path": path,
    }


def _checkpoint_specs(stages, sigmas, gammas, tau_inits, include_finetuned):
    """List the checkpoints to evaluate, in display order.

    For each sigma this yields the TNRD baseline for every tau. It then yields
    the MLP/RBF models (Telegraph and No-wave variants) for every (gamma, tau)
    pair. When *include_finetuned* is true, each group is followed by its
    finetuned counterparts.

    Returns:
        A list of dicts with keys ``label``, ``model_type``, ``use_wave``,
        ``damping_gamma``, ``tau_init``, ``sigma`` and ``path``.
    """
    variants = (False, True) if include_finetuned else (False,)
    specs = []
    for sigma in sigmas:
        sigma_name = sigma_tag(sigma)

        # TNRD baseline: no wave term and a fixed damping gamma of 1.0.
        for tau_init in tau_inits:
            tau_name = tau_tag(tau_init)
            for finetuned in variants:
                label_prefix = "Finetuned " if finetuned else ""
                file_prefix = "finetuned_" if finetuned else ""
                specs.append(
                    _make_spec(
                        label=f"{label_prefix}TNRD baseline sigma={sigma} tau={tau_init}",
                        model_type="tnrd",
                        use_wave=False,
                        damping_gamma=1.0,
                        tau_init=tau_init,
                        sigma=sigma,
                        path=f"{file_prefix}tnrd_baseline_{stages}stages_{sigma_name}_{tau_name}.pth",
                    )
                )

        # MLP / RBF unrolled networks, with and without the wave term.
        for damping_gamma in gammas:
            gamma_name = gamma_tag(damping_gamma)
            for tau_init in tau_inits:
                tau_name = tau_tag(tau_init)
                for finetuned in variants:
                    label_prefix = "Finetuned " if finetuned else ""
                    for family, model_type, base_prefix, ft_prefix in _UNROLLED_FAMILIES:
                        file_prefix = ft_prefix if finetuned else base_prefix
                        for use_wave in (True, False):
                            variant = "Telegraph" if use_wave else "No-wave"
                            specs.append(
                                _make_spec(
                                    label=(
                                        f"{label_prefix}{family} {variant} sigma={sigma} "
                                        f"gamma={damping_gamma} tau={tau_init}"
                                    ),
                                    model_type=model_type,
                                    use_wave=use_wave,
                                    damping_gamma=damping_gamma,
                                    tau_init=tau_init,
                                    sigma=sigma,
                                    # f"wave{use_wave}" renders as waveTrue / waveFalse.
                                    path=(
                                        f"{file_prefix}_{stages}stages_wave{use_wave}_"
                                        f"{sigma_name}_{gamma_name}_{tau_name}.pth"
                                    ),
                                )
                            )
    return specs


def _load_model(spec, stages):
    """Build the architecture described by *spec*, load its weights and set eval mode."""
    model = _build_model(
        spec["model_type"],
        stages,
        spec["use_wave"],
        spec["damping_gamma"],
        spec["tau_init"],
    )
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # if checkpoints may ever come from an untrusted source.
    state = torch.load(spec["path"], map_location=DEVICE)
    model.load_state_dict(state)
    model.eval()
    return model


def evaluate_checkpoint(spec, dataset_name, sigma, stages, model=None):
    """Return the mean PSNR of one checkpoint on one test set.

    Args:
        spec: checkpoint description as produced by ``_checkpoint_specs``.
        dataset_name: test set directory name (e.g. ``"Set12"``).
        sigma: noise level in 0-255 units; converted with ``sigma_int_to_float``.
        stages: number of unrolled stages used to rebuild the model.
        model: optional already-loaded model. When omitted the checkpoint is
            loaded from ``spec["path"]``.

    Raises:
        FileNotFoundError: if the test set directory yields no images.
    """
    if model is None:
        model = _load_model(spec, stages)

    test_root = _testset_root(dataset_name)
    test_paths = _collect_image_paths(test_root)
    if not test_paths:
        raise FileNotFoundError(
            f"No test images found in {os.path.abspath(test_root)} for {dataset_name}."
        )

    sigma_float = sigma_int_to_float(sigma)
    test_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    # Fixed seed: every checkpoint sees the same noise realisation per test set.
    torch.manual_seed(42)
    total_psnr = 0.0
    with torch.no_grad():
        for path in test_paths:
            # Context manager closes the underlying file handle after decoding.
            with Image.open(path) as image:
                clean = test_transform(image).unsqueeze(0).to(DEVICE)
            noisy = torch.clamp(clean + torch.randn_like(clean) * sigma_float, 0.0, 1.0)
            with _autocast_context():
                output = model(noisy)
            total_psnr += calculate_psnr(clean, output)
    return total_psnr / len(test_paths)


def main(args):
    """Evaluate every configured checkpoint on each test set and print a table."""
    specs = _checkpoint_specs(
        args.stages,
        [int(s) for s in args.sigmas],
        [float(g) for g in args.gammas],
        [float(t) for t in args.tau_inits],
        args.include_finetuned,
    )

    # Widen the label column to the longest label so rows stay aligned.
    label_width = max([_MIN_LABEL_WIDTH] + [len(spec["label"]) for spec in specs])
    rule = "-" * (label_width + 52)

    print(f"[*] Evaluating checkpoints on {', '.join(TESTSETS)}")
    print(f"[*] Device: {DEVICE}")
    print(rule)
    print(f"{'Model':<{label_width}} {'Dataset':<8} {'PSNR':>8} Checkpoint")
    print(rule)

    for spec in specs:
        if not os.path.exists(spec["path"]):
            print(f"{spec['label']:<{label_width}} {'-':<8} {'[missing]':>8} {spec['path']}")
            continue
        # Load each checkpoint once and reuse it for every test set.
        model = _load_model(spec, args.stages)
        for dataset_name in TESTSETS:
            psnr = evaluate_checkpoint(
                spec, dataset_name, spec["sigma"], args.stages, model=model
            )
            print(f"{spec['label']:<{label_width}} {dataset_name:<8} {psnr:>8.2f} {spec['path']}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--stages", type=int, default=5)
    parser.add_argument(
        "--sigmas",
        type=int,
        nargs="+",
        default=list(DEFAULT_SIGMA_LEVELS),
        help="Noise levels to evaluate, specified in 0-255 units.",
    )
    parser.add_argument(
        "--gammas",
        type=float,
        nargs="+",
        default=list(DEFAULT_GAMMAS),
        help="Fixed damping gamma values to evaluate for MLP/RBF models.",
    )
    parser.add_argument(
        "--tau_inits",
        type=float,
        nargs="+",
        default=list(DEFAULT_TAU_INITS),
        help="Initial tau values to evaluate.",
    )
    parser.add_argument(
        "--include_finetuned",
        action="store_true",
        help="Also evaluate finetuned checkpoints.",
    )
    args = parser.parse_args()
    main(args)