| import torch |
| import numpy as np |
|
|
class FrequencyAwareNoise:
    """Forward (noising) process for a DDPM-style diffusion model.

    Despite the class name, ``apply_noise`` currently implements the plain
    DDPM closed-form forward process over a linear beta schedule; no
    frequency-dependent weighting is applied.
    """

    def __init__(self, config):
        """Precompute the linear beta schedule and its cumulative products.

        Args:
            config: Object exposing ``beta_start`` and ``beta_end`` (first and
                last beta of the schedule) and ``T`` (number of steps).
        """
        self.config = config
        # T evenly spaced betas from beta_start to beta_end (inclusive).
        self.betas = torch.linspace(config.beta_start, config.beta_end, config.T)
        self.alphas = 1.0 - self.betas
        # alpha_bars[i] = alphas[0] * ... * alphas[i]
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    def apply_noise(self, x0, t, noise=None):
        """Noise ``x0`` to timestep ``t`` using the standard DDPM closed form.

        Computes ``xt = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * noise``.

        Args:
            x0: Clean input whose first dimension is the batch; any rank is
                accepted (the per-sample coefficient is broadcast over all
                trailing dimensions).
            t: Index/indices into the schedule: a Python int, a sequence of
                ints, or an integer tensor of shape ``(B,)`` (one per sample).
            noise: Optional noise tensor; drawn with ``torch.randn_like(x0)``
                when omitted.

        Returns:
            Tuple ``(xt, noise)``: the noised tensor (in ``x0``'s dtype) and
            the noise that was actually used.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        device = x0.device
        # Accept ints, lists or tensors on any device as timestep indices.
        t = torch.as_tensor(t, device=device)

        # Gather and take the square roots in the schedule's own dtype, then
        # cast: computing 1 - alpha_bar directly in a low-precision dtype
        # such as float16 can lose most of its significant digits.
        alpha_bar_t = self.alpha_bars.to(device)[t]

        # (B, 1, ..., 1) so the coefficient broadcasts per sample against x0
        # of any rank (the previous view(-1, 1, 1, 1) only suited 4-D input).
        coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
        sqrt_alpha_bar = alpha_bar_t.sqrt().to(x0.dtype).reshape(coeff_shape)
        sqrt_one_minus_alpha_bar = (
            (1.0 - alpha_bar_t).sqrt().to(x0.dtype).reshape(coeff_shape)
        )

        xt = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise
        return xt, noise

    def debug_noise_stats(self, x0, t):
        """Noise ``x0`` at ``t`` and print min/max ranges plus the noise std.

        Returns:
            The ``(xt, noise)`` tuple produced by ``apply_noise``.
        """
        xt, noise = self.apply_noise(x0, t)
        print(f"Input range: [{x0.min().item():.4f}, {x0.max().item():.4f}]")
        print(f"Noise range: [{noise.min().item():.4f}, {noise.max().item():.4f}]")
        print(f"Noisy range: [{xt.min().item():.4f}, {xt.max().item():.4f}]")
        print(f"Noise std: {noise.std().item():.4f}")
        return xt, noise
|
|