"""End-to-end fine-tuning of pre-trained unrolled denoising networks.

Loads stage-wise pre-trained weights for an MLP/RBF unrolled network (or the
TNRD baseline), fine-tunes the whole model end-to-end with MSE loss under
mixed precision, evaluates PSNR on a test set, and saves a fine-tuned
checkpoint — sweeping over noise levels, damping gammas, and tau inits.
"""

import argparse
import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

from train_network import (
    DEFAULT_GAMMAS,
    DEFAULT_SIGMA_LEVELS,
    DEFAULT_TAU_INITS,
    UnrolledNetwork as MLPUnrolledNetwork,
    TrainDataset,
    calculate_psnr,
    gamma_tag,
    resolve_train_dir,
    softplus_inverse,  # NOTE(review): unused in this file — kept for import parity; confirm before removing
    _train_dir_candidates,
    _collect_image_paths,
    _DEFAULT_TEST,
    sigma_int_to_float,
    sigma_tag,
    tau_tag,
)
from train_network_rbf import UnrolledNetwork as RBFUnrolledNetwork
from train_tnrd_baseline import TNRDBaselineNetwork


def _model_label(model_type, use_wave):
    """Return a human-readable label for progress messages.

    Raises:
        ValueError: if ``model_type`` is not one of "mlp", "rbf", "tnrd".
    """
    if model_type == "mlp":
        return f"{'Telegraph' if use_wave else 'No-wave'} MLP model"
    if model_type == "rbf":
        return f"{'Telegraph' if use_wave else 'No-wave'} RBF model"
    if model_type == "tnrd":
        return "TNRD baseline"
    raise ValueError(f"Unknown model type: {model_type}")


def _build_model(model_type, stages, use_wave, damping_gamma, tau_init):
    """Instantiate the requested network family.

    The TNRD baseline takes neither ``use_wave`` nor ``damping_gamma``.

    Raises:
        ValueError: if ``model_type`` is not one of "mlp", "rbf", "tnrd".
    """
    if model_type == "mlp":
        return MLPUnrolledNetwork(stages, use_wave, damping_gamma=damping_gamma, tau_init=tau_init)
    if model_type == "rbf":
        return RBFUnrolledNetwork(stages, use_wave, damping_gamma=damping_gamma, tau_init=tau_init)
    if model_type == "tnrd":
        return TNRDBaselineNetwork(stages, tau_init=tau_init)
    raise ValueError(f"Unknown model type: {model_type}")


def _base_checkpoint_name(model_type, stages, use_wave, sigma_name, gamma_name, tau_name):
    """Filename of the stage-wise pre-trained checkpoint to load."""
    if model_type == "mlp":
        return f"model_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth"
    if model_type == "rbf":
        return f"rbf_model_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth"
    if model_type == "tnrd":
        return f"tnrd_baseline_{stages}stages_{sigma_name}_{tau_name}.pth"
    raise ValueError(f"Unknown model type: {model_type}")


def _finetuned_checkpoint_name(model_type, stages, use_wave, sigma_name, gamma_name, tau_name):
    """Filename under which the fine-tuned checkpoint is saved."""
    if model_type == "mlp":
        return f"finetuned_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth"
    if model_type == "rbf":
        return f"finetuned_rbf_model_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth"
    if model_type == "tnrd":
        return f"finetuned_tnrd_baseline_{stages}stages_{sigma_name}_{tau_name}.pth"
    raise ValueError(f"Unknown model type: {model_type}")


def fine_tune_single_config(args, sigma_value, damping_gamma, tau_init, train_root, test_root):
    """Fine-tune one (sigma, gamma, tau) configuration end-to-end and evaluate it.

    Loads the stage-wise pre-trained checkpoint, unfreezes all parameters,
    trains 10 epochs with Adam + AMP, reports average test-set PSNR, and
    saves the fine-tuned weights.

    Raises:
        FileNotFoundError: if ``test_root`` contains no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sigma = sigma_int_to_float(sigma_value)
    sigma_name = sigma_tag(sigma_value)
    gamma_name = gamma_tag(damping_gamma)
    tau_name = tau_tag(tau_init)
    model_label = _model_label(args.model_type, args.use_wave)
    print(
        f"\n[*] Loading stage-wise weights for {args.stages}-Stage "
        f"{model_label} at sigma={int(sigma_value)}/255, "
        f"gamma={damping_gamma}, tau_init={tau_init}..."
    )

    # Fail fast before any training if the test set is missing.
    test_paths = _collect_image_paths(test_root)
    if not test_paths:
        raise FileNotFoundError(
            f"No test images in {os.path.abspath(test_root)}. "
            "Run download_data.py for FFDNet testsets or pass --test_dir."
        )

    model = _build_model(
        args.model_type,
        args.stages,
        args.use_wave,
        damping_gamma=damping_gamma,
        tau_init=tau_init,
    ).to(device)
    weight_file = _base_checkpoint_name(
        args.model_type, args.stages, args.use_wave, sigma_name, gamma_name, tau_name
    )
    # weights_only=True: checkpoints are plain state dicts (tensors only);
    # avoids unpickling arbitrary objects from the file.
    state = torch.load(weight_file, map_location=device, weights_only=True)
    model.load_state_dict(state)

    # Stage-wise pre-training may have frozen parameters; unfreeze everything
    # for end-to-end fine-tuning.
    for param in model.parameters():
        param.requires_grad = True

    criterion = nn.MSELoss()
    train_loader = DataLoader(
        TrainDataset(train_root, sigma=sigma),
        batch_size=64,
        shuffle=True,
        num_workers=8,
        pin_memory=device.type == "cuda",
    )
    scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda")
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    print("\n--- Starting end-to-end fine-tuning ---")
    for epoch in range(10):
        model.train()
        total_loss = 0.0
        for clean, noisy in train_loader:
            clean, noisy = clean.to(device), noisy.to(device)
            optimizer.zero_grad()
            # Standard AMP pattern: forward + loss under autocast, scaled
            # backward/step outside of it.
            with torch.amp.autocast(device.type, enabled=device.type == "cuda"):
                output = model(noisy)
                loss = criterion(output, clean)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
        print(
            f"Fine-tune epoch [{epoch + 1}/10], loss: {total_loss / len(train_loader):.6f}"
        )

    print("\n[+] Fine-tuning complete! Evaluating...")
    model.eval()
    test_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    total_psnr = 0.0
    with torch.no_grad():
        for path in test_paths:
            # Context manager ensures the image file handle is closed.
            with Image.open(path) as img:
                clean = test_transform(img).unsqueeze(0).to(device)
            noisy = torch.clamp(clean + torch.randn_like(clean) * sigma, 0.0, 1.0)
            with torch.amp.autocast(device.type, enabled=device.type == "cuda"):
                output = model(noisy)
            total_psnr += calculate_psnr(clean, output)
    avg_psnr = total_psnr / len(test_paths)
    print(
        f"\n[!!!] Fine-tuned {model_label} PSNR ({os.path.basename(test_root)}) "
        f"at sigma={int(sigma_value)}/255, gamma={damping_gamma}, tau_init={tau_init}: "
        f"{avg_psnr:.2f} dB"
    )

    output_file = _finetuned_checkpoint_name(
        args.model_type, args.stages, args.use_wave, sigma_name, gamma_name, tau_name
    )
    torch.save(model.state_dict(), output_file)
    print(f"[+] Saved checkpoint: {output_file}")


def fine_tune(args):
    """Drive the sweep: resolve data folders, then fine-tune every configuration.

    For the TNRD baseline only tau is swept (gamma is fixed at 1.0 and
    ``--use_wave`` is rejected); MLP/RBF models sweep gamma x tau per sigma.

    Raises:
        ValueError: if ``--use_wave`` is combined with the tnrd model type.
        FileNotFoundError: if no training images can be located.
    """
    if args.model_type == "tnrd" and args.use_wave:
        raise ValueError("--use_wave is only valid for MLP/RBF models, not for tnrd.")
    train_root = resolve_train_dir(args.train_dir)
    if train_root is None:
        tried = "\n ".join(os.path.abspath(p) for p in _train_dir_candidates())
        raise FileNotFoundError(
            "No training images found under datasets/. Tried:\n "
            f"{tried}\n"
            "Pass --train_dir to your PNG folder (e.g. .../DIV2K_Train_HR/DIV2K_train_HR)."
        )
    print(f"[*] Fine-tune data: {train_root} ({len(_collect_image_paths(train_root))} images)")
    test_root = args.test_dir or _DEFAULT_TEST

    sigmas = [int(s) for s in args.sigmas]
    gammas = [float(g) for g in args.gammas]
    tau_inits = [float(t) for t in args.tau_inits]
    print(f"[*] Sigma sweep: {', '.join(str(s) for s in sigmas)}")
    if args.model_type != "tnrd":
        print(f"[*] Gamma sweep: {', '.join(str(g) for g in gammas)}")
    print(f"[*] Tau-init sweep: {', '.join(str(t) for t in tau_inits)}")

    for sigma_value in sigmas:
        if args.model_type == "tnrd":
            # TNRD has no damping gamma; pass the neutral value 1.0.
            for tau_init in tau_inits:
                fine_tune_single_config(args, sigma_value, 1.0, tau_init, train_root, test_root)
        else:
            for damping_gamma in gammas:
                for tau_init in tau_inits:
                    fine_tune_single_config(
                        args, sigma_value, damping_gamma, tau_init, train_root, test_root
                    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--stages", type=int, required=True)
    parser.add_argument(
        "--model_type",
        type=str,
        choices=("mlp", "rbf", "tnrd"),
        default="mlp",
        help="Which model family to fine-tune (default: mlp)",
    )
    parser.add_argument("--use_wave", action="store_true")
    parser.add_argument(
        "--sigmas",
        type=int,
        nargs="+",
        default=list(DEFAULT_SIGMA_LEVELS),
        help="Noise levels to sweep, specified in 0-255 units (default: 15 25 50 75)",
    )
    parser.add_argument(
        "--gammas",
        type=float,
        nargs="+",
        default=list(DEFAULT_GAMMAS),
        help="Fixed damping gamma values to sweep for MLP/RBF models (default: 1.0)",
    )
    parser.add_argument(
        "--tau_inits",
        type=float,
        nargs="+",
        default=list(DEFAULT_TAU_INITS),
        help="Initial tau values to sweep (default: 0.1)",
    )
    parser.add_argument(
        "--train_dir",
        type=str,
        default=None,
        help="Training PNG folder (default: same auto-detect as train_network.py)",
    )
    parser.add_argument(
        "--test_dir",
        type=str,
        default=None,
        help="Test image folder (default: BSD68 under FFDNet testsets)",
    )
    args = parser.parse_args()
    fine_tune(args)