| import torch |
| import numpy as np |
| from scipy.fftpack import dctn, idctn |
|
|
class FrequencyAwareNoise:
    """Forward-diffusion noising performed in a patch-wise DCT frequency space.

    Each spatial patch is transformed with an orthonormal 2-D DCT. Noise is added
    to the DCT coefficients with a per-coefficient weight that rises linearly
    from 0.1 at the DC coefficient to 1.0 at the highest frequency of the patch,
    so low frequencies are perturbed less than high frequencies.
    """

    def __init__(self, config):
        """Precompute a linear beta schedule and its cumulative products.

        Args:
            config: object exposing ``beta_start``, ``beta_end`` and ``T``
                (read here) and ``patch_size`` (read by ``apply_noise``).
        """
        self.config = config
        self.betas = torch.linspace(config.beta_start, config.beta_end, config.T)
        self.alphas = 1. - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

        # NumPy views of the schedule for the SciPy (CPU) noising path.
        self.betas_np = self.betas.numpy()
        self.alphas_np = self.alphas.numpy()
        self.alpha_bars_np = self.alpha_bars.numpy()

    @staticmethod
    def _frequency_weights(height, width):
        """Return the (height, width) grid of noise weights for one DCT patch."""
        max_freq = height + width - 2
        # (u + v) for every coefficient position, built without Python loops.
        manhattan = np.add.outer(np.arange(height), np.arange(width))
        if max_freq == 0:
            # 1x1 patch: only the DC coefficient exists; avoid dividing 0 by 0.
            return np.full((height, width), 0.1)
        return 0.1 + 0.9 * manhattan / max_freq

    def _alpha_bars_for(self, t, batch_size):
        """Look up alpha_bar for ``t`` as a float64 array of shape (n, 1, 1, 1).

        ``n`` is 1 (broadcast over the batch) or ``batch_size``.

        Raises:
            ValueError: if ``t`` holds neither 1 nor ``batch_size`` timesteps.
        """
        t_idx = torch.as_tensor(t).detach().cpu().reshape(-1).numpy()
        alpha_bars = self.alpha_bars_np[t_idx].astype(np.float64)
        if alpha_bars.shape[0] not in (1, batch_size):
            # Previously this silently fell back to t[0] for every sample.
            raise ValueError(
                f"t must contain 1 or {batch_size} timesteps, "
                f"got {alpha_bars.shape[0]}"
            )
        return alpha_bars.reshape(-1, 1, 1, 1)

    def apply_noise(self, x0, t, noise=None):
        """Add frequency-weighted noise to ``x0`` in patch-wise DCT space.

        Args:
            x0: tensor of shape (B, C, H, W). Edge patches may be smaller than
                ``config.patch_size`` when H or W is not a multiple of it.
            t: timestep index or indices into the schedule: an int or a tensor
                holding either 1 entry (shared by the batch) or B entries.
            noise: optional base noise with the same shape as ``x0``, treated as
                white noise in pixel space (e.g. ``torch.randn_like(x0)``). Each
                patch of it is moved to DCT space and frequency-weighted. When
                None, standard-normal noise is drawn from torch's global RNG,
                so ``torch.manual_seed`` makes the result reproducible.

        Returns:
            ``(xt, noise_spatial)``: the noised tensor and the frequency-weighted
            noise mapped back to pixel space, both with ``x0``'s shape, dtype
            and device. Because the orthonormal DCT is linear, per sample
            ``xt == sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise_spatial``.

        Raises:
            ValueError: if ``noise`` does not match ``x0``'s shape, or ``t``
                holds neither 1 nor B timesteps.
        """
        batch, _, height, width = x0.shape
        device = x0.device
        patch_size = self.config.patch_size

        if noise is None:
            # An orthonormal DCT maps white Gaussian noise to white Gaussian
            # noise, so drawing it in pixel space is distributionally the same
            # as drawing DCT coefficients directly. torch's RNG (unlike
            # np.random) is seeded by torch.manual_seed and per DataLoader worker.
            noise = torch.randn(x0.shape, dtype=torch.float64)
        else:
            noise = torch.as_tensor(noise)
            if noise.shape != x0.shape:
                raise ValueError(
                    f"noise shape {tuple(noise.shape)} does not match "
                    f"x0 shape {tuple(x0.shape)}"
                )

        # Convert once (detached, so tensors requiring grad are accepted).
        x_np = x0.detach().cpu().double().numpy()
        base_np = noise.detach().cpu().double().numpy()

        # Loop-invariant: the same per-sample scales apply to every patch.
        alpha_bars = self._alpha_bars_for(t, batch)
        signal_scale = np.sqrt(alpha_bars)
        noise_scale = np.sqrt(1.0 - alpha_bars)

        xt_np = np.empty_like(x_np)
        noise_np = np.empty_like(x_np)
        for i in range(0, height, patch_size):
            for j in range(0, width, patch_size):
                window = np.s_[:, :, i:i + patch_size, j:j + patch_size]
                dct = dctn(x_np[window], axes=(2, 3), norm='ortho')
                noise_dct = dctn(base_np[window], axes=(2, 3), norm='ortho')
                noise_dct *= self._frequency_weights(*dct.shape[2:])

                noisy_dct = signal_scale * dct + noise_scale * noise_dct
                xt_np[window] = idctn(noisy_dct, axes=(2, 3), norm='ortho')
                noise_np[window] = idctn(noise_dct, axes=(2, 3), norm='ortho')

        def to_torch(arr):
            return torch.from_numpy(arr).to(device=device, dtype=x0.dtype)

        return to_torch(xt_np), to_torch(noise_np)

    def debug_noise_stats(self, x0, t):
        """Noise ``x0`` at ``t``, print range/std statistics, and return the result.

        Returns:
            The ``(xt, noise)`` pair produced by ``apply_noise``.
        """
        xt, noise = self.apply_noise(x0, t)
        print(f"Input range: [{x0.min().item():.4f}, {x0.max().item():.4f}]")
        print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]")
        print(f"Noisy range: [{xt.min().item():.4f}, {xt.max().item():.4f}]")
        print(f"Noise std: {noise.std().item():.4f}")
        return xt, noise