# finetune.py — end-to-end fine-tuning of pre-trained denoising networks
# (MLP/RBF unrolled models and the TNRD baseline) over sigma/gamma/tau sweeps.
import os
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import transforms
from train_network import (
DEFAULT_GAMMAS,
DEFAULT_SIGMA_LEVELS,
DEFAULT_TAU_INITS,
UnrolledNetwork as MLPUnrolledNetwork,
TrainDataset,
calculate_psnr,
gamma_tag,
resolve_train_dir,
softplus_inverse,
_train_dir_candidates,
_collect_image_paths,
_DEFAULT_TEST,
sigma_int_to_float,
sigma_tag,
tau_tag,
)
from train_network_rbf import UnrolledNetwork as RBFUnrolledNetwork
from train_tnrd_baseline import TNRDBaselineNetwork
def _model_label(model_type, use_wave):
if model_type == "mlp":
return f"{'Telegraph' if use_wave else 'No-wave'} MLP model"
if model_type == "rbf":
return f"{'Telegraph' if use_wave else 'No-wave'} RBF model"
if model_type == "tnrd":
return "TNRD baseline"
raise ValueError(f"Unknown model type: {model_type}")
def _build_model(model_type, stages, use_wave, damping_gamma, tau_init):
    """Instantiate the requested network family with the given hyperparameters.

    Raises:
        ValueError: If `model_type` is not one of "mlp", "rbf", "tnrd".
    """
    if model_type == "tnrd":
        # The TNRD baseline has no wave / damping-gamma variant.
        return TNRDBaselineNetwork(stages, tau_init=tau_init)
    if model_type in ("mlp", "rbf"):
        network_cls = MLPUnrolledNetwork if model_type == "mlp" else RBFUnrolledNetwork
        return network_cls(stages, use_wave, damping_gamma=damping_gamma, tau_init=tau_init)
    raise ValueError(f"Unknown model type: {model_type}")
def _base_checkpoint_name(model_type, stages, use_wave, sigma_name, gamma_name, tau_name):
if model_type == "mlp":
return f"model_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth"
if model_type == "rbf":
return f"rbf_model_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth"
if model_type == "tnrd":
return f"tnrd_baseline_{stages}stages_{sigma_name}_{tau_name}.pth"
raise ValueError(f"Unknown model type: {model_type}")
def _finetuned_checkpoint_name(model_type, stages, use_wave, sigma_name, gamma_name, tau_name):
if model_type == "mlp":
return f"finetuned_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth"
if model_type == "rbf":
return f"finetuned_rbf_model_{stages}stages_wave{use_wave}_{sigma_name}_{gamma_name}_{tau_name}.pth"
if model_type == "tnrd":
return f"finetuned_tnrd_baseline_{stages}stages_{sigma_name}_{tau_name}.pth"
raise ValueError(f"Unknown model type: {model_type}")
def fine_tune_single_config(args, sigma_value, damping_gamma, tau_init, train_root, test_root):
    """Fine-tune one pre-trained model end-to-end for a single (sigma, gamma, tau) config.

    Loads the stage-wise pre-trained checkpoint, unfreezes all parameters,
    trains for 10 epochs with Adam + AMP on noisy/clean pairs, evaluates
    average PSNR on the test set, and saves the fine-tuned checkpoint.

    Args:
        args: Parsed CLI namespace; reads `model_type`, `stages`, `use_wave`.
        sigma_value: Noise level in 0-255 integer units.
        damping_gamma: Fixed damping gamma (MLP/RBF; ignored by tnrd builder).
        tau_init: Initial tau value, used both for building and naming the model.
        train_root: Directory of training images.
        test_root: Directory of test images.

    Raises:
        FileNotFoundError: If `test_root` contains no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sigma = sigma_int_to_float(sigma_value)  # 0-255 integer level -> float noise std
    sigma_name = sigma_tag(sigma_value)
    gamma_name = gamma_tag(damping_gamma)
    tau_name = tau_tag(tau_init)
    model_label = _model_label(args.model_type, args.use_wave)
    print(
        f"\n[*] Loading stage-wise weights for {args.stages}-Stage "
        f"{model_label} at sigma={int(sigma_value)}/255, "
        f"gamma={damping_gamma}, tau_init={tau_init}..."
    )
    # Fail fast on a missing evaluation set before spending time on training.
    test_paths = _collect_image_paths(test_root)
    if not test_paths:
        raise FileNotFoundError(
            f"No test images in {os.path.abspath(test_root)}. "
            "Run download_data.py for FFDNet testsets or pass --test_dir."
        )
    model = _build_model(
        args.model_type,
        args.stages,
        args.use_wave,
        damping_gamma=damping_gamma,
        tau_init=tau_init,
    ).to(device)
    # Resume from the stage-wise pre-trained checkpoint matching this configuration.
    weight_file = _base_checkpoint_name(
        args.model_type, args.stages, args.use_wave, sigma_name, gamma_name, tau_name
    )
    state = torch.load(weight_file, map_location=device)
    model.load_state_dict(state)
    # Unfreeze everything — presumably stage-wise pre-training froze some
    # parameters; TODO confirm against train_network.py.
    for param in model.parameters():
        param.requires_grad = True
    criterion = nn.MSELoss()
    train_loader = DataLoader(
        TrainDataset(train_root, sigma=sigma),
        batch_size=64,
        shuffle=True,
        num_workers=8,
        pin_memory=device.type == "cuda",
    )
    # AMP is enabled only on CUDA; on CPU the scaler/autocast are pass-throughs.
    scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda")
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    print("\n--- Starting end-to-end fine-tuning ---")
    for epoch in range(10):
        model.train()
        total_loss = 0
        for clean, noisy in train_loader:
            clean, noisy = clean.to(device), noisy.to(device)
            optimizer.zero_grad()
            with torch.amp.autocast(device.type, enabled=device.type == "cuda"):
                output = model(noisy)
                loss = criterion(output, clean)
            # Scaled backward + step is the standard mixed-precision recipe.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
        print(
            f"Fine-tune epoch [{epoch + 1}/10], loss: {total_loss / len(train_loader):.6f}"
        )
    print("\n[+] Fine-tuning complete! Evaluating...")
    model.eval()
    test_transform = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])
    total_psnr = 0.0
    with torch.no_grad():
        for path in test_paths:
            clean = test_transform(Image.open(path)).unsqueeze(0).to(device)
            # Synthesize the noisy input with Gaussian noise, clamped to [0, 1].
            noisy = torch.clamp(clean + torch.randn_like(clean) * sigma, 0.0, 1.0)
            with torch.amp.autocast(device.type, enabled=device.type == "cuda"):
                output = model(noisy)
            total_psnr += calculate_psnr(clean, output)
    avg_psnr = total_psnr / len(test_paths)
    print(
        f"\n[!!!] Fine-tuned {model_label} PSNR ({os.path.basename(test_root)}) "
        f"at sigma={int(sigma_value)}/255, gamma={damping_gamma}, tau_init={tau_init}: "
        f"{avg_psnr:.2f} dB"
    )
    output_file = _finetuned_checkpoint_name(
        args.model_type, args.stages, args.use_wave, sigma_name, gamma_name, tau_name
    )
    torch.save(model.state_dict(), output_file)
    print(f"[+] Saved checkpoint: {output_file}")
def fine_tune(args):
    """Run the fine-tuning sweep over every requested sigma/gamma/tau combination.

    Raises:
        ValueError: If --use_wave is combined with the tnrd model type.
        FileNotFoundError: If no training images can be located.
    """
    if args.model_type == "tnrd" and args.use_wave:
        raise ValueError("--use_wave is only valid for MLP/RBF models, not for tnrd.")
    train_root = resolve_train_dir(args.train_dir)
    if train_root is None:
        tried = "\n ".join(os.path.abspath(p) for p in _train_dir_candidates())
        raise FileNotFoundError(
            "No training images found under datasets/. Tried:\n "
            f"{tried}\n"
            "Pass --train_dir to your PNG folder (e.g. .../DIV2K_Train_HR/DIV2K_train_HR)."
        )
    print(f"[*] Fine-tune data: {train_root} ({len(_collect_image_paths(train_root))} images)")
    test_root = args.test_dir or _DEFAULT_TEST
    sigma_sweep = [int(value) for value in args.sigmas]
    gamma_sweep = [float(value) for value in args.gammas]
    tau_sweep = [float(value) for value in args.tau_inits]
    print(f"[*] Sigma sweep: {', '.join(str(s) for s in sigma_sweep)}")
    if args.model_type != "tnrd":
        print(f"[*] Gamma sweep: {', '.join(str(g) for g in gamma_sweep)}")
    print(f"[*] Tau-init sweep: {', '.join(str(t) for t in tau_sweep)}")
    for sigma_value in sigma_sweep:
        # TNRD has no gamma hyperparameter, so it runs once per tau with gamma=1.0.
        effective_gammas = [1.0] if args.model_type == "tnrd" else gamma_sweep
        for damping_gamma in effective_gammas:
            for tau_init in tau_sweep:
                fine_tune_single_config(
                    args, sigma_value, damping_gamma, tau_init, train_root, test_root
                )
if __name__ == "__main__":
    # CLI entry point: every option mirrors train_network.py's sweep settings.
    parser = argparse.ArgumentParser()
    parser.add_argument("--stages", type=int, required=True)
    parser.add_argument(
        "--model_type",
        type=str,
        choices=("mlp", "rbf", "tnrd"),
        default="mlp",
        help="Which model family to fine-tune (default: mlp)",
    )
    # Only meaningful for MLP/RBF; fine_tune() rejects it for tnrd.
    parser.add_argument("--use_wave", action="store_true")
    parser.add_argument(
        "--sigmas",
        type=int,
        nargs="+",
        default=list(DEFAULT_SIGMA_LEVELS),
        help="Noise levels to sweep, specified in 0-255 units (default: 15 25 50 75)",
    )
    parser.add_argument(
        "--gammas",
        type=float,
        nargs="+",
        default=list(DEFAULT_GAMMAS),
        help="Fixed damping gamma values to sweep for MLP/RBF models (default: 1.0)",
    )
    parser.add_argument(
        "--tau_inits",
        type=float,
        nargs="+",
        default=list(DEFAULT_TAU_INITS),
        help="Initial tau values to sweep (default: 0.1)",
    )
    parser.add_argument(
        "--train_dir",
        type=str,
        default=None,
        help="Training PNG folder (default: same auto-detect as train_network.py)",
    )
    parser.add_argument(
        "--test_dir",
        type=str,
        default=None,
        # Fix: was an f-string with no placeholders; plain literal is equivalent.
        help="Test image folder (default: BSD68 under FFDNet testsets)",
    )
    args = parser.parse_args()
    fine_tune(args)