import abc

import torch
import torch.nn as nn

# Use the legacy (non-profiling) JIT executor and allow the fuser to run on
# both CPU and GPU.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)


def get_noise(config, dtype=torch.float32):
    """Build the noise schedule selected by ``config.noise.type``."""
    if config.noise.type == 'geometric':
        return GeometricNoise(config.noise.sigma_min,
                              config.noise.sigma_max)
    elif config.noise.type == 'loglinear':
        return LogLinearNoise()
    elif config.noise.type == 'cosine':
        return CosineNoise()
    elif config.noise.type == 'cosinesqr':
        return CosineSqrNoise()
    elif config.noise.type == 'linear':
        return Linear(config.noise.sigma_min,
                      config.noise.sigma_max,
                      dtype)
    else:
        raise ValueError(f'{config.noise.type} is not a valid noise')


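# Usage sketch (illustrative, not part of the original module): get_noise only
# reads config.noise.type and, for the geometric/linear schedules,
# config.noise.sigma_min / config.noise.sigma_max, so any attribute-style
# config works; SimpleNamespace below is a stand-in for the real config class.
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(noise=SimpleNamespace(type='loglinear'))
#   noise = get_noise(config)
#   total, rate = noise(torch.rand(8))

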
def binary_discretization(z):
    """Straight-through sign discretization of ``z``."""
    z_hard = torch.sign(z)
    z_soft = z / torch.norm(z, dim=-1, keepdim=True)
    return z_soft + (z_hard - z_soft).detach()


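# Straight-through behaviour sketch (illustrative, not part of the original
# module): the returned values match sign(z) up to floating point, while
# gradients flow through z / ||z||.
#
#   z = torch.randn(4, 8, requires_grad=True)
#   out = binary_discretization(z)
#   assert torch.allclose(out, torch.sign(z))
#   out.sum().backward()  # gradient reaches z via the normalized branch

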
class Noise(abc.ABC, nn.Module):
    """
    Baseline forward method to get the total and rate of noise at a timestep.
    """
    def forward(self, t):
        return self.total_noise(t), self.rate_noise(t)

    @abc.abstractmethod
    def rate_noise(self, t):
        """
        Rate of change of noise, i.e. g(t).
        """
        pass

    @abc.abstractmethod
    def total_noise(self, t):
        r"""
        Total noise \sigma(t), i.e. \int_0^t g(s) ds + \sigma(0).
        """
        pass


class CosineNoise(Noise):
    """Cosine schedule:
    sigma(t) = -log(eps + (1 - eps) * cos(pi * t / 2)).
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        cos = (1 - self.eps) * torch.cos(t * torch.pi / 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi / 2)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        cos = torch.cos(t * torch.pi / 2)
        return -torch.log(self.eps + (1 - self.eps) * cos)


class CosineSqrNoise(Noise):
    """Squared-cosine schedule:
    sigma(t) = -log(eps + (1 - eps) * cos(pi * t / 2) ** 2).
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def rate_noise(self, t):
        cos = (1 - self.eps) * (
            torch.cos(t * torch.pi / 2) ** 2)
        sin = (1 - self.eps) * torch.sin(t * torch.pi)
        scale = torch.pi / 2
        return scale * sin / (cos + self.eps)

    def total_noise(self, t):
        cos = torch.cos(t * torch.pi / 2) ** 2
        return -torch.log(self.eps + (1 - self.eps) * cos)


class Linear(Noise):
    """Linear schedule: sigma(t) = sigma_min + t * (sigma_max - sigma_min)."""

    def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32):
        super().__init__()
        self.sigma_min = torch.tensor(sigma_min, dtype=dtype)
        self.sigma_max = torch.tensor(sigma_max, dtype=dtype)

    def rate_noise(self, t):
        return self.sigma_max - self.sigma_min

    def total_noise(self, t):
        return self.sigma_min + t * (self.sigma_max - self.sigma_min)

    def importance_sampling_transformation(self, t):
        # Warp a uniform t to the timestep whose total noise sigma_t makes
        # log1p(-exp(-sigma_t)) interpolate linearly between its values at
        # sigma_min and sigma_max.
        f_T = torch.log1p(-torch.exp(-self.sigma_max))
        f_0 = torch.log1p(-torch.exp(-self.sigma_min))
        sigma_t = -torch.log1p(-torch.exp(t * f_T + (1 - t) * f_0))
        return (sigma_t - self.sigma_min) / (
            self.sigma_max - self.sigma_min)


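# Property sketch (illustrative, not part of the original module): the warped
# time u = importance_sampling_transformation(t) is the timestep whose total
# noise sigma_t makes log1p(-exp(-sigma_t)) an exact linear interpolation
# between its values at sigma_min and sigma_max.
#
#   lin = Linear(0.1, 10.0, torch.float64)
#   t = torch.rand(4, dtype=torch.float64)
#   u = lin.importance_sampling_transformation(t)
#   f = lambda s: torch.log1p(-torch.exp(-s))
#   assert torch.allclose(f(lin.total_noise(u)),
#                         t * f(lin.sigma_max) + (1 - t) * f(lin.sigma_min))

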
class GeometricNoise(Noise):
    """Geometric schedule: sigma(t) = sigma_min**(1 - t) * sigma_max**t."""

    def __init__(self, sigma_min=1e-3, sigma_max=1):
        super().__init__()
        self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max])

    def rate_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * (
            self.sigmas[1].log() - self.sigmas[0].log())

    def total_noise(self, t):
        return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t


class LogLinearNoise(Noise):
    """Log Linear noise schedule.

    Built such that 1 - 1/e^(n(t)) interpolates between 0 and ~1
    when t varies from 0 to 1. Total noise is -log(1 - (1 - eps) * t),
    so 1 - e^(-n(t)) equals (1 - eps) * t.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps
        self.sigma_max = self.total_noise(torch.tensor(1.0))
        self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0))

    def rate_noise(self, t):
        return (1 - self.eps) / (1 - (1 - self.eps) * t)

    def total_noise(self, t):
        return -torch.log1p(-(1 - self.eps) * t)

    def importance_sampling_transformation(self, t):
        # Same warping as in Linear: map a uniform t to the timestep whose
        # total noise is sigma_t.
        f_T = torch.log1p(-torch.exp(-self.sigma_max))
        f_0 = torch.log1p(-torch.exp(-self.sigma_min))
        sigma_t = -torch.log1p(-torch.exp(t * f_T + (1 - t) * f_0))
        t = -torch.expm1(-sigma_t) / (1 - self.eps)
        return t


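if __name__ == '__main__':
    # Quick self-check (illustrative, not part of the original module): every
    # schedule above should satisfy rate_noise(t) == d/dt total_noise(t), so
    # the analytic rate is compared against an autograd derivative of the
    # total noise at a few random timesteps.
    t = torch.rand(5, dtype=torch.float64, requires_grad=True)
    schedules = (LogLinearNoise(), CosineNoise(), CosineSqrNoise(),
                 GeometricNoise(1e-3, 1.0), Linear(0.0, 10.0, torch.float64))
    for schedule in schedules:
        total, rate = schedule(t)
        grad, = torch.autograd.grad(total.sum(), t)
        # Linear.rate_noise returns a 0-dim tensor; broadcast to grad's shape.
        assert torch.allclose(grad, rate + torch.zeros_like(grad),
                              atol=1e-6), type(schedule).__name__
    print('rate_noise matches d/dt total_noise for all schedules')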