# personal_math / classical_baseline.py
# psidharth567
# Sync full project: code, checkpoints, datasets, logs
# dcd2bd2 verified
import os
import glob
import torch
import torch.nn.functional as F
import math
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
# --- Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_SIGMAS = (15.0 / 255.0, 50.0 / 255.0, 75.0 / 255.0)
def _testset_root(name):
    """Return the directory of the FFDNet test set called *name*.

    The directory is
    ``<script dir>/datasets/Test_Datasets/FFDNet-master/testsets/<name>``.
    The path is only built here; its existence is not checked.
    """
    relative_parts = ("datasets", "Test_Datasets", "FFDNet-master", "testsets", name)
    return os.path.join(_SCRIPT_DIR, *relative_parts)
class TestDataset(Dataset):
    """Grayscale test images paired with synthetically noised copies.

    Each item is a ``(clean, noisy)`` pair of tensors. ``clean`` is the image
    converted to one grayscale channel and passed through
    ``transforms.ToTensor`` (values in ``[0, 1]`` for 8-bit images).
    ``noisy`` is ``clean`` plus zero-mean Gaussian noise with standard
    deviation ``sigma``, clamped back to ``[0, 1]``.

    Args:
        root_dir: Directory searched (non-recursively) for ``.png``/``.jpg``
            files. A missing directory yields an empty dataset.
        sigma: Noise standard deviation on the ``[0, 1]`` intensity scale.
        seed: Optional base seed. When given, item ``idx`` draws its noise from
            a private generator seeded with ``seed + idx``. The noisy images
            are then reproducible and independent of the global torch RNG.
            When ``None`` (default), the global RNG is used, as before.
    """

    def __init__(self, root_dir, sigma, seed=None):
        self.sigma = sigma
        self.seed = seed
        # Sorted so image order (and any per-image output) is stable across
        # filesystems and runs.
        self.image_paths = self._list_images(root_dir)
        self.transform = transforms.Compose(
            [transforms.Grayscale(num_output_channels=1), transforms.ToTensor()]
        )

    @staticmethod
    def _list_images(root_dir, extensions=(".png", ".jpg")):
        """Return the sorted paths of image files directly inside *root_dir*.

        Extension matching is case-insensitive. Hidden files (leading dot) and
        sub-directories are skipped. An unreadable or missing directory gives
        ``[]``.
        """
        suffixes = tuple(ext.lower() for ext in extensions)
        try:
            names = os.listdir(root_dir)
        except OSError:
            return []
        candidates = (
            os.path.join(root_dir, name)
            for name in names
            if not name.startswith(".") and name.lower().endswith(suffixes)
        )
        return sorted(path for path in candidates if os.path.isfile(path))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Context manager releases the file handle promptly; ToTensor copies
        # the pixel data, so the tensor does not depend on the open file.
        with Image.open(self.image_paths[idx]) as img:
            clean = self.transform(img)
        if self.seed is None:
            noise = torch.randn_like(clean)
        else:
            generator = torch.Generator().manual_seed(self.seed + idx)
            noise = torch.randn(clean.shape, generator=generator, dtype=clean.dtype)
        noisy = clean + noise * self.sigma
        # NOTE(review): many AWGN benchmarks score on unclipped noisy inputs;
        # confirm that clipping here is intended.
        return clean, torch.clamp(noisy, 0.0, 1.0)
def calculate_psnr(img1, img2, data_range=1.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    ``PSNR = 10 * log10(data_range**2 / MSE)``. The MSE is averaged over every
    element of the inputs, so a batch is scored as one image.

    Args:
        img1: Floating-point tensor.
        img2: Floating-point tensor broadcastable against ``img1``.
        data_range: Peak signal value: 1.0 for ``[0, 1]`` images (default,
            matching the previous hard-coded behavior), 255.0 for 8-bit
            images.

    Returns:
        PSNR as a Python float; ``float("inf")`` when the images are identical.
    """
    # .item() converts the 0-dim tensor explicitly and once, instead of
    # relying on implicit tensor->float coercion in the comparison and in math.
    mse = torch.mean((img1 - img2) ** 2).item()
    if mse == 0:
        return float("inf")
    return 10.0 * math.log10(data_range**2 / mse)
def classical_telegraph_step(u_n, u_n_minus_1, tau=0.2, gamma=1.0, k=0.1):
    """Advance damped telegraph diffusion by one explicit time step.

    Discretizes ``u_tt + gamma * u_t = div(c * grad u)``. Time uses a central
    second difference for ``u_tt`` and ``(u^{n+1} - u^n) / tau`` for the
    damping term. Solving for ``u^{n+1}`` gives::

        u^{n+1} = alpha * u^n + beta * u^{n-1} + lam * div^n
        alpha = (2 + gamma*tau) / (1 + gamma*tau)
        beta  = -1 / (1 + gamma*tau)
        lam   = tau**2 / (1 + gamma*tau)

    The diffusivity ``c = 1 / (1 + (|grad u^n| / k)**2)`` is evaluated on the
    current state. Spatial derivatives are central differences implemented
    as 3x3 convolutions with zero padding (``padding=1``) at the border.

    Args:
        u_n: Current state, ``(N, C, H, W)`` (or unbatched ``(C, H, W)``).
            Each channel is diffused independently.
        u_n_minus_1: Previous state, same shape as ``u_n``.
        tau: Time step.
        gamma: Damping coefficient.
        k: Edge threshold of the diffusivity. Gradients much larger than
            ``k`` are diffused little.

    Returns:
        ``u^{n+1}`` with the same shape as ``u_n``.
    """
    channels = u_n.shape[-3]
    # Build the stencils on the input's device/dtype. Previously they were
    # always float32 on the global DEVICE, which broke conv2d for other inputs.
    opts = {"device": u_n.device, "dtype": u_n.dtype}
    kx = torch.tensor(
        [[0.0, 0.0, 0.0], [-0.5, 0.0, 0.5], [0.0, 0.0, 0.0]], **opts
    ).view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    ky = torch.tensor(
        [[0.0, -0.5, 0.0], [0.0, 0.0, 0.0], [0.0, 0.5, 0.0]], **opts
    ).view(1, 1, 3, 3).repeat(channels, 1, 1, 1)

    def _diff(field, kernel):
        # Per-channel (grouped) central difference.
        # NOTE(review): zero padding treats pixels outside the image as 0.
        # Replicate padding would give a zero-flux boundary; kept as-is so
        # baseline numbers are unchanged.
        return F.conv2d(field, kernel, padding=1, groups=channels)

    grad_x = _diff(u_n, kx)
    grad_y = _diff(u_n, ky)
    # Epsilon avoids sqrt(0), whose derivative is infinite.
    grad_mag = torch.sqrt(grad_x**2 + grad_y**2 + 1e-8)
    c = 1.0 / (1.0 + (grad_mag / k) ** 2)
    divergence = _diff(c * grad_x, kx) + _diff(c * grad_y, ky)

    denom = 1 + gamma * tau
    alpha = (2 + gamma * tau) / denom
    beta = -1 / denom
    lam = (tau**2) / denom
    return alpha * u_n + beta * u_n_minus_1 + lam * divergence
def run_eval(dataset_name, sigma, n_iters=20, tau=0.2, gamma=1.0, seed=None):
    """Denoise one test set with the classical telegraph scheme and report PSNR.

    For each image, the scheme starts from ``u^0 = u^{-1} = noisy`` and runs
    ``n_iters`` steps of ``classical_telegraph_step``. The result is clamped
    to ``[0, 1]`` and scored against the clean image with ``calculate_psnr``.
    The average is printed and returned.

    Args:
        dataset_name: Test-set directory name under the FFDNet testsets root
            (e.g. "Set12", "BSD68").
        sigma: Noise standard deviation on the ``[0, 1]`` scale.
        n_iters: Number of time steps (default 20, the previous fixed value).
        tau: Time step forwarded to ``classical_telegraph_step``.
        gamma: Damping coefficient forwarded to ``classical_telegraph_step``.
        seed: If given, the global torch RNG is seeded before any noise is
            drawn, so the noisy images and the reported PSNR are reproducible.

    Returns:
        Average PSNR in dB, or ``None`` when the directory has no images.
    """
    root = _testset_root(dataset_name)
    if seed is not None:
        # Noise is drawn lazily in TestDataset.__getitem__ in this process
        # (DataLoader default num_workers=0), so seeding here fixes it.
        torch.manual_seed(seed)
    dataset = TestDataset(root, sigma)
    if len(dataset) == 0:
        print(f"[!] Skip {dataset_name}: no images in {os.path.abspath(root)}")
        return None
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
    total_psnr = 0.0
    # A fixed classical scheme needs no autograd bookkeeping.
    with torch.no_grad():
        for clean, noisy in dataloader:
            clean, noisy = clean.to(DEVICE), noisy.to(DEVICE)
            u_prev, u_curr = noisy.clone(), noisy.clone()
            for _ in range(n_iters):
                u_next = classical_telegraph_step(u_curr, u_prev, tau=tau, gamma=gamma)
                u_prev, u_curr = u_curr, u_next
            # alpha > 1 and beta < 0 make each step an extrapolation, so values
            # may leave [0, 1]; score the valid (clamped) image.
            total_psnr += calculate_psnr(clean, torch.clamp(u_curr, 0.0, 1.0))
    avg_psnr = total_psnr / len(dataset)
    sigma_int = int(round(sigma * 255.0))
    print(
        f"[+] {dataset_name} sigma={sigma_int}/255 PSNR: {avg_psnr:.2f} dB "
        f"({len(dataset)} images)"
    )
    return avg_psnr
def main():
    """Run the classical baseline on every (test set, noise level) pair.

    Each evaluation prints its own result line via ``run_eval``. A blank line
    is printed after each test set to separate the groups.
    """
    print("[*] Classical Majee 2020 baseline — Set12 & BSD68")
    benchmarks = ("Set12", "BSD68")
    for bench in benchmarks:
        for noise_level in _SIGMAS:
            run_eval(bench, noise_level)
        # Visual separator between per-dataset result groups.
        print()


# Only evaluate when executed as a script, not when imported.
if __name__ == "__main__":
    main()