| import os |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import argparse |
| from torch.utils.data import DataLoader |
| from PIL import Image |
| from torchvision import transforms |
|
|
| from train_network import ( |
| DEFAULT_GAMMAS, |
| DEFAULT_SIGMA_LEVELS, |
| DEFAULT_TAU_INITS, |
| UnrolledNetwork as MLPUnrolledNetwork, |
| TrainDataset, |
| calculate_psnr, |
| gamma_tag, |
| resolve_train_dir, |
| softplus_inverse, |
| _train_dir_candidates, |
| _collect_image_paths, |
| _DEFAULT_TEST, |
| sigma_int_to_float, |
| sigma_tag, |
| tau_tag, |
| ) |
| from train_network_rbf import UnrolledNetwork as RBFUnrolledNetwork |
| from train_tnrd_baseline import TNRDBaselineNetwork |
|
|
|
|
| def _model_label(model_type, use_wave): |
| if model_type == "mlp": |
| return f"{'Telegraph' if use_wave else 'No-wave'} MLP model" |
| if model_type == "rbf": |
| return f"{'Telegraph' if use_wave else 'No-wave'} RBF model" |
| if model_type == "tnrd": |
| return "TNRD baseline" |
| raise ValueError(f"Unknown model type: {model_type}") |
|
|
|
|
def _build_model(model_type, stages, use_wave, damping_gamma, tau_init):
    """Instantiate an untrained network of the requested family.

    TNRD ignores `use_wave` and `damping_gamma`; MLP/RBF share the same
    constructor signature and differ only in the class used.
    """
    if model_type == "tnrd":
        return TNRDBaselineNetwork(stages, tau_init=tau_init)
    if model_type == "mlp":
        network_cls = MLPUnrolledNetwork
    elif model_type == "rbf":
        network_cls = RBFUnrolledNetwork
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    return network_cls(stages, use_wave, damping_gamma=damping_gamma, tau_init=tau_init)
|
|
|
|
| def _base_checkpoint_name(model_type, stages, use_wave, sigma_name, gamma_name, tau_name): |
| if model_type == "mlp": |
| return f"model_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth" |
| if model_type == "rbf": |
| return f"rbf_model_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth" |
| if model_type == "tnrd": |
| return f"tnrd_baseline_{stages}stages_{sigma_name}_{tau_name}.pth" |
| raise ValueError(f"Unknown model type: {model_type}") |
|
|
|
|
| def _finetuned_checkpoint_name(model_type, stages, use_wave, sigma_name, gamma_name, tau_name): |
| if model_type == "mlp": |
| return f"finetuned_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth" |
| if model_type == "rbf": |
| return f"finetuned_rbf_model_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth" |
| if model_type == "tnrd": |
| return f"finetuned_tnrd_baseline_{stages}stages_{sigma_name}_{tau_name}.pth" |
| raise ValueError(f"Unknown model type: {model_type}") |
|
|
|
|
def fine_tune_single_config(args, sigma_value, damping_gamma, tau_init, train_root, test_root):
    """Fine-tune one pre-trained configuration end-to-end, evaluate, and save it.

    Loads the stage-wise checkpoint matching (model_type, stages, use_wave,
    sigma, gamma, tau), unfreezes every parameter, trains jointly for 10
    epochs with Adam + AMP, reports average PSNR on the test set, and writes
    the fine-tuned weights to a new checkpoint file.

    Args:
        args: parsed CLI namespace (model_type, stages, use_wave).
        sigma_value: noise level in 0-255 units.
        damping_gamma: fixed damping gamma for this run (ignored by tnrd).
        tau_init: initial tau used when the base model was built/trained.
        train_root: directory of training PNGs.
        test_root: directory of test images.

    Raises:
        FileNotFoundError: if `test_root` contains no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sigma = sigma_int_to_float(sigma_value)
    sigma_name = sigma_tag(sigma_value)
    gamma_name = gamma_tag(damping_gamma)
    tau_name = tau_tag(tau_init)
    model_label = _model_label(args.model_type, args.use_wave)
    print(
        f"\n[*] Loading stage-wise weights for {args.stages}-Stage "
        f"{model_label} at sigma={int(sigma_value)}/255, "
        f"gamma={damping_gamma}, tau_init={tau_init}..."
    )


    # Fail early (before training) if evaluation would have nothing to score.
    test_paths = _collect_image_paths(test_root)
    if not test_paths:
        raise FileNotFoundError(
            f"No test images in {os.path.abspath(test_root)}. "
            "Run download_data.py for FFDNet testsets or pass --test_dir."
        )


    model = _build_model(
        args.model_type,
        args.stages,
        args.use_wave,
        damping_gamma=damping_gamma,
        tau_init=tau_init,
    ).to(device)


    # Checkpoint naming must match what the stage-wise training scripts wrote.
    weight_file = _base_checkpoint_name(
        args.model_type, args.stages, args.use_wave, sigma_name, gamma_name, tau_name
    )
    state = torch.load(weight_file, map_location=device)
    model.load_state_dict(state)


    # Unfreeze everything: stage-wise training may have left some stages frozen.
    for param in model.parameters():
        param.requires_grad = True


    criterion = nn.MSELoss()
    train_loader = DataLoader(
        TrainDataset(train_root, sigma=sigma),
        batch_size=64,
        shuffle=True,
        num_workers=8,
        pin_memory=device.type == "cuda",
    )
    # AMP is only enabled on CUDA; on CPU the scaler/autocast become no-ops.
    scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda")


    optimizer = optim.Adam(model.parameters(), lr=1e-4)


    print("\n--- Starting end-to-end fine-tuning ---")
    for epoch in range(10):
        model.train()
        total_loss = 0
        for clean, noisy in train_loader:
            clean, noisy = clean.to(device), noisy.to(device)
            optimizer.zero_grad()


            with torch.amp.autocast(device.type, enabled=device.type == "cuda"):
                output = model(noisy)
                loss = criterion(output, clean)


            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()


        print(
            f"Fine-tune epoch [{epoch + 1}/10], loss: {total_loss / len(train_loader):.6f}"
        )


    print("\n[+] Fine-tuning complete! Evaluating...")


    model.eval()
    test_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    total_psnr = 0.0
    with torch.no_grad():
        for path in test_paths:
            # Use a context manager so PIL closes the file handle promptly;
            # Image.open is lazy and otherwise leaks one handle per image.
            with Image.open(path) as img:
                clean = test_transform(img).unsqueeze(0).to(device)
            noisy = torch.clamp(clean + torch.randn_like(clean) * sigma, 0.0, 1.0)


            with torch.amp.autocast(device.type, enabled=device.type == "cuda"):
                output = model(noisy)
            total_psnr += calculate_psnr(clean, output)


    avg_psnr = total_psnr / len(test_paths)
    print(
        f"\n[!!!] Fine-tuned {model_label} PSNR ({os.path.basename(test_root)}) "
        f"at sigma={int(sigma_value)}/255, gamma={damping_gamma}, tau_init={tau_init}: "
        f"{avg_psnr:.2f} dB"
    )
    output_file = _finetuned_checkpoint_name(
        args.model_type, args.stages, args.use_wave, sigma_name, gamma_name, tau_name
    )
    torch.save(model.state_dict(), output_file)
    print(f"[+] Saved checkpoint: {output_file}")
|
|
|
|
def fine_tune(args):
    """Validate CLI options, resolve data folders, and run the full sweep.

    Fine-tunes one model per (sigma, gamma, tau) combination; for the tnrd
    baseline, gamma is not swept and is pinned to 1.0.
    """
    if args.model_type == "tnrd" and args.use_wave:
        raise ValueError("--use_wave is only valid for MLP/RBF models, not for tnrd.")


    train_root = resolve_train_dir(args.train_dir)
    if train_root is None:
        attempted = "\n ".join(os.path.abspath(c) for c in _train_dir_candidates())
        raise FileNotFoundError(
            "No training images found under datasets/. Tried:\n "
            f"{attempted}\n"
            "Pass --train_dir to your PNG folder (e.g. .../DIV2K_Train_HR/DIV2K_train_HR)."
        )
    print(f"[*] Fine-tune data: {train_root} ({len(_collect_image_paths(train_root))} images)")


    test_root = args.test_dir or _DEFAULT_TEST
    sigma_sweep = [int(s) for s in args.sigmas]
    gamma_sweep = [float(g) for g in args.gammas]
    tau_sweep = [float(t) for t in args.tau_inits]
    is_tnrd = args.model_type == "tnrd"
    print(f"[*] Sigma sweep: {', '.join(str(s) for s in sigma_sweep)}")
    if not is_tnrd:
        print(f"[*] Gamma sweep: {', '.join(str(g) for g in gamma_sweep)}")
    print(f"[*] Tau-init sweep: {', '.join(str(t) for t in tau_sweep)}")
    # TNRD has no damping gamma, so it runs once per (sigma, tau) with gamma=1.0.
    effective_gammas = [1.0] if is_tnrd else gamma_sweep
    for sigma_value in sigma_sweep:
        for damping_gamma in effective_gammas:
            for tau_init in tau_sweep:
                fine_tune_single_config(
                    args, sigma_value, damping_gamma, tau_init, train_root, test_root
                )
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: collect the sweep configuration and launch fine-tuning.
    parser = argparse.ArgumentParser()
    parser.add_argument("--stages", type=int, required=True)
    parser.add_argument(
        "--model_type",
        type=str,
        choices=("mlp", "rbf", "tnrd"),
        default="mlp",
        help="Which model family to fine-tune (default: mlp)",
    )
    parser.add_argument("--use_wave", action="store_true")
    parser.add_argument(
        "--sigmas",
        type=int,
        nargs="+",
        default=list(DEFAULT_SIGMA_LEVELS),
        help="Noise levels to sweep, specified in 0-255 units (default: 15 25 50 75)",
    )
    parser.add_argument(
        "--gammas",
        type=float,
        nargs="+",
        default=list(DEFAULT_GAMMAS),
        help="Fixed damping gamma values to sweep for MLP/RBF models (default: 1.0)",
    )
    parser.add_argument(
        "--tau_inits",
        type=float,
        nargs="+",
        default=list(DEFAULT_TAU_INITS),
        help="Initial tau values to sweep (default: 0.1)",
    )
    parser.add_argument(
        "--train_dir",
        type=str,
        default=None,
        help="Training PNG folder (default: same auto-detect as train_network.py)",
    )
    parser.add_argument(
        "--test_dir",
        type=str,
        default=None,
        # Was an f-string with no placeholders (lint F541); plain string is correct.
        help="Test image folder (default: BSD68 under FFDNet testsets)",
    )
    args = parser.parse_args()
    fine_tune(args)
|
|