| import torch |
| from model import SmoothDiffusionUNet |
| from noise_scheduler import FrequencyAwareNoise |
| from config import Config |
| from torchvision.utils import save_image, make_grid |
| from dataloader import get_dataloaders |
| import numpy as np |
|
|
def _save_grid(images, path, pad_value=0.0):
    """Save a batch of images as a two-column PNG grid.

    Args:
        images: Image batch. Values are mapped from [-1, 1] to [0, 1] and
            clamped before saving (assumes data is normalised to [-1, 1] --
            TODO confirm against the dataloader).
        path: Output PNG path.
        pad_value: Fill value for the padding between grid cells.
    """
    display = torch.clamp((images + 1) / 2, 0, 1)
    grid = make_grid(display, nrow=2, normalize=False, pad_value=pad_value)
    save_image(grid, path)


def _timestep_batch(t_val, batch_size, device):
    """Return a ``(batch_size,)`` long tensor filled with timestep ``t_val``."""
    return torch.full((batch_size,), t_val, device=device, dtype=torch.long)


def _estimate_x0(x_t, predicted_noise, alpha_bar_t, noise_scale=1.0):
    """Estimate the clean image from a noisy one and clamp it to [-1, 1].

    Computes ``(x_t - sqrt(1 - a_bar) * eps * noise_scale) / sqrt(a_bar)``.
    With ``noise_scale == 1`` this inverts
    ``x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps``; smaller values remove
    only part of the predicted noise.

    NOTE(review): assumes FrequencyAwareNoise.apply_noise follows that forward
    process; if it shapes noise differently this estimate is approximate --
    TODO confirm.
    """
    x0 = (x_t - np.sqrt(1 - alpha_bar_t) * predicted_noise * noise_scale) / np.sqrt(alpha_bar_t)
    return torch.clamp(x0, -1, 1)


def _heuristic_denoise(model, noise_scheduler, x, timesteps, noise_scale):
    """Repeatedly replace ``x`` with its clamped x0 estimate, one step per timestep.

    This is a diagnostic heuristic, not a DDPM/DDIM sampler. No noise is
    re-injected between steps; each step feeds the previous x0 estimate back
    into the model.
    """
    for t_val in timesteps:
        t_tensor = _timestep_batch(t_val, x.shape[0], x.device)
        predicted_noise = model(x, t_tensor)
        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x = _estimate_x0(x, predicted_noise, alpha_bar_t, noise_scale)
    return x


def _load_model(checkpoint_path, device):
    """Build config, model and noise scheduler, and load the trained weights."""
    # The checkpoint is passed straight to load_state_dict, so it is a plain
    # state dict; weights_only=True avoids unpickling arbitrary objects.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    config = Config()

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)
    model.load_state_dict(checkpoint)
    model.eval()
    return config, model, noise_scheduler


def _load_real_images(config, num_images, device):
    """Take up to ``num_images`` training images, report their stats and save them."""
    train_loader, _ = get_dataloaders(config)
    real_batch, _ = next(iter(train_loader))
    real_images = real_batch[:num_images].to(device)

    print(f"Real training data range: [{real_images.min():.3f}, {real_images.max():.3f}]")
    print(f"Real training data mean: {real_images.mean():.3f}, std: {real_images.std():.3f}")

    _save_grid(real_images, "real_training_images.png", pad_value=1.0)
    print("Real training images saved to real_training_images.png")
    return real_images


def _check_reconstruction(model, noise_scheduler, real_images, timesteps=(50, 200, 400)):
    """Noise real images, predict the noise, and report/save single-step reconstructions."""
    print("\n=== TESTING MODEL ON REAL DATA ===")

    for t_val in timesteps:
        # Size the timestep batch from the data: the loader may yield fewer
        # images than requested.
        t_tensor = _timestep_batch(t_val, real_images.shape[0], real_images.device)

        x_noisy, _ = noise_scheduler.apply_noise(real_images, t_tensor)
        noise_pred = model(x_noisy, t_tensor)

        alpha_bar_t = noise_scheduler.alpha_bars[t_val].item()
        x_reconstructed = _estimate_x0(x_noisy, noise_pred, alpha_bar_t)
        mse = torch.mean((x_reconstructed - real_images) ** 2).item()

        print(f"\nTimestep {t_val}:")
        print(f" Reconstruction error: {mse:.6f}")

        _save_grid(x_reconstructed, f"reconstruction_t{t_val}.png")
        print(f" Reconstruction saved to reconstruction_t{t_val}.png")


def _interpolation_sampling(model, noise_scheduler, real_images, num_frames):
    """Blend the first two real images and run the heuristic denoiser on the blends."""
    print("\n=== TRYING INTERPOLATION SAMPLING ===")

    if real_images.shape[0] < 2:
        print("Skipping interpolation sampling: need at least 2 real images")
        return

    x1 = real_images[0:1]
    x2 = real_images[1:2]

    # (num_frames, 1, 1, 1) weights broadcast against the single-image
    # endpoints, giving one blended image per weight.
    weights = torch.linspace(0, 1, num_frames, device=real_images.device).view(-1, 1, 1, 1)
    x_interp = weights * x1 + (1 - weights) * x2

    print("Starting from real image interpolation...")
    print(f"Interpolation range: [{x_interp.min():.3f}, {x_interp.max():.3f}]")

    x = _heuristic_denoise(
        model, noise_scheduler, x_interp,
        timesteps=(100, 80, 60, 40, 25, 15, 8, 3, 1),
        noise_scale=0.3,
    )

    print(f"Interpolation result range: [{x.min():.3f}, {x.max():.3f}]")

    _save_grid(x, "interpolation_sampling.png")
    print("Interpolation sampling saved to interpolation_sampling.png")


def _minimal_noise_sampling(model, noise_scheduler, sample_shape, device):
    """Start from low-amplitude Gaussian noise and run a short heuristic denoise."""
    print("\n=== TRYING MINIMAL NOISE SAMPLING ===")

    # The shape comes from the real data rather than a hard-coded 3x64x64,
    # so this matches whatever resolution the model was trained on.
    x_minimal = torch.randn(sample_shape, device=device) * 0.1

    x_minimal = _heuristic_denoise(
        model, noise_scheduler, x_minimal,
        timesteps=(50, 30, 15, 5, 1),
        noise_scale=0.5,
    )

    print(f"Minimal noise result range: [{x_minimal.min():.3f}, {x_minimal.max():.3f}]")
    print(f"Minimal noise result std: {x_minimal.std():.3f}")

    _save_grid(x_minimal, "minimal_noise_sampling.png")
    print("Minimal noise sampling saved to minimal_noise_sampling.png")


def _print_summary():
    """List the PNG files written by the diagnostics."""
    print("\n=== SUMMARY ===")
    print("Generated files:")
    print("- real_training_images.png (what we want to achieve)")
    print("- reconstruction_t*.png (model's denoising ability)")
    print("- interpolation_sampling.png (interpolation approach)")
    print("- minimal_noise_sampling.png (light noise approach)")


def diagnose_and_fix(checkpoint_path='model_final.pth', num_images=4):
    """Final diagnosis and alternative sampling approach.

    Loads the trained model and runs several diagnostics, each writing PNG
    grids to the current directory:
      1. saves a few real training images,
      2. checks single-step reconstructions of noised real images,
      3. runs a heuristic denoise starting from blends of two real images,
      4. runs a heuristic denoise starting from low-amplitude noise.

    Args:
        checkpoint_path: Path to the saved model state dict.
        num_images: How many real images to use, and how many interpolation
            frames / minimal-noise samples to generate.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config, model, noise_scheduler = _load_model(checkpoint_path, device)

    print("=== FINAL DIAGNOSIS ===")

    real_images = _load_real_images(config, num_images, device)
    sample_shape = (num_images, *real_images.shape[1:])

    with torch.no_grad():
        _check_reconstruction(model, noise_scheduler, real_images)
        _interpolation_sampling(model, noise_scheduler, real_images, num_images)
        _minimal_noise_sampling(model, noise_scheduler, sample_shape, device)

    _print_summary()
|
|
if __name__ == "__main__":
    # Run the (slow, file-writing) diagnostics only when executed as a script,
    # never as a side effect of importing this module.
    diagnose_and_fix()
|
|