| """
|
Adapted from https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddpm.py
|
| Pared down to simplify code.
|
|
|
| The original file acknowledges:
|
| https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
| https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
|
| https://github.com/CompVis/taming-transformers
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import numpy as np
|
| import pytorch_lightning as pl
|
| from contextlib import contextmanager
|
| from functools import partial
|
| from torchmetrics import MeanSquaredError
|
|
|
| from .utils import make_beta_schedule, extract_into_tensor, noise_like, timestep_embedding
|
| from .ema import LitEma
|
| from ..blocks.afno import PatchEmbed3d, PatchExpand3d, AFNOBlock3d
|
|
|
|
|
class LatentDiffusion(pl.LightningModule):
    """Latent diffusion model trained with PyTorch Lightning.

    A denoising network (``model``) is trained on latents produced by a
    frozen ``autoencoder``. When a ``context_encoder`` is supplied, its output
    is passed to ``model`` as the ``context`` keyword argument.

    Args:
        model: Denoising network, called as ``model(x_noisy, t, context=...)``.
        autoencoder: Autoencoder whose ``encode(y)[0]`` provides the training
            latents. Its parameters are frozen via ``requires_grad_(False)``.
        context_encoder: Optional module mapping the conditioning input to the
            context passed to ``model``. ``None`` disables conditioning.
        timesteps: Number of diffusion steps in the noise schedule.
        beta_schedule: Schedule name forwarded to ``make_beta_schedule``.
        loss_type: ``"l1"`` or ``"l2"`` (see ``get_loss``).
        use_ema: Keep an EMA copy of ``model`` weights (via ``LitEma``).
        lr: Base learning rate for AdamW.
        lr_warmup: Number of optimizer steps over which the learning rate is
            ramped linearly up to ``lr``; 0 disables warmup.
        linear_start, linear_end, cosine_s: Schedule parameters forwarded to
            ``make_beta_schedule``.
        parameterization: ``"eps"`` (model predicts the added noise) or
            ``"x0"`` (model predicts the clean latent).
    """

    def __init__(self,
        model,
        autoencoder,
        context_encoder=None,
        timesteps=1000,
        beta_schedule="linear",
        loss_type="l2",
        use_ema=True,
        lr=1e-4,
        lr_warmup=0,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        parameterization="eps",
    ):
        super().__init__()
        self.model = model
        # The autoencoder is only used for encoding targets; it is never trained here.
        self.autoencoder = autoencoder.requires_grad_(False)
        self.conditional = (context_encoder is not None)
        self.context_encoder = context_encoder
        self.lr = lr
        self.lr_warmup = lr_warmup

        # NOTE(review): this metric is not referenced anywhere else in this class;
        # kept so existing checkpoints / external references remain compatible.
        self.val_loss = MeanSquaredError()

        assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
        self.parameterization = parameterization

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self.model)

        self.register_schedule(
            beta_schedule=beta_schedule, timesteps=timesteps,
            linear_start=linear_start, linear_end=linear_end,
            cosine_s=cosine_s
        )

        self.loss_type = loss_type

    def register_schedule(self, beta_schedule="linear", timesteps=1000,
        linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Build the noise schedule and register it as float32 buffers.

        Registers ``betas``, ``alphas_cumprod``, ``alphas_cumprod_prev``,
        ``sqrt_alphas_cumprod`` and ``sqrt_one_minus_alphas_cumprod``, and sets
        ``self.num_timesteps`` to the schedule length.
        """
        betas = make_beta_schedule(
            beta_schedule, timesteps,
            linear_start=linear_start, linear_end=linear_end,
            cosine_s=cosine_s
        )
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # Shifted by one step so that index t holds the cumulative product up to t-1.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        # Buffers move with the module across devices and are saved in state_dict.
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # Precomputed coefficients used by q_sample.
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap ``self.model`` weights for their EMA copy.

        Has no effect when ``use_ema`` is False. The training weights are
        restored on exit, even if the body raises.

        Args:
            context: Optional label; when given, a message is printed on
                switching to and restoring from the EMA weights.
        """
        if self.use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def apply_model(self, x_noisy, t, cond=None, return_ids=False):
        """Run the denoiser with EMA weights (when enabled).

        Args:
            x_noisy: Noisy latent input to ``model``.
            t: Diffusion timesteps.
            cond: Raw conditioning input; encoded with ``context_encoder``
                when the model is conditional, otherwise passed through.
            return_ids: Unused; kept for interface compatibility.

        Returns:
            The output of ``self.model(x_noisy, t, context=...)``.
        """
        if self.conditional:
            cond = self.context_encoder(cond)
        with self.ema_scope():
            return self.model(x_noisy, t, context=cond)

    def q_sample(self, x_start, t, noise=None):
        """Sample from the forward diffusion process q(x_t | x_0).

        Computes ``sqrt(alphas_cumprod[t]) * x_start +
        sqrt(1 - alphas_cumprod[t]) * noise``, drawing standard normal noise
        when ``noise`` is None.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def get_loss(self, pred, target, mean=True):
        """Compute the configured elementwise loss between ``pred`` and ``target``.

        Args:
            pred: Model prediction.
            target: Regression target.
            mean: If True, reduce to a scalar mean; otherwise return the
                unreduced elementwise loss.

        Raises:
            NotImplementedError: If ``self.loss_type`` is not ``"l1"`` or ``"l2"``.
        """
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
            if mean:
                loss = loss.mean()
        elif self.loss_type == 'l2':
            reduction = 'mean' if mean else 'none'
            loss = torch.nn.functional.mse_loss(target, pred, reduction=reduction)
        else:
            # Previously the message was not an f-string and referenced an
            # undefined name, so the braces were printed literally.
            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
        return loss

    def p_losses(self, x_start, t, noise=None, context=None):
        """Diffusion training loss at timesteps ``t``.

        Noises ``x_start`` with ``q_sample``, runs ``self.model`` and compares
        its output to the noise (``"eps"``) or to ``x_start`` (``"x0"``).

        Returns:
            Scalar mean loss.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x_noisy, t, context=context)

        if self.parameterization == "eps":
            target = noise
        elif self.parameterization == "x0":
            target = x_start
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

        return self.get_loss(model_out, target, mean=False).mean()

    def forward(self, x, *args, **kwargs):
        """Compute the training loss at uniformly random timesteps.

        One timestep in ``[0, num_timesteps)`` is drawn per element along the
        first dimension of ``x``; extra arguments are forwarded to ``p_losses``.
        """
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        return self.p_losses(x, t, *args, **kwargs)

    def shared_step(self, batch):
        """Encode the target with the autoencoder and return the diffusion loss.

        ``batch`` is unpacked as ``(x, y)``: ``x`` is the conditioning input
        (used only when ``context_encoder`` is set) and ``y`` is encoded into
        the latent that is diffused.
        """
        (x, y) = batch
        y = self.autoencoder.encode(y)[0]
        context = self.context_encoder(x) if self.conditional else None
        return self(y, context=context)

    def training_step(self, batch, batch_idx):
        """Lightning training step: log and return the diffusion loss."""
        loss = self.shared_step(batch)
        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log the validation loss with both training and EMA weights.

        Logs ``val_loss`` (training weights) and ``val_loss_ema`` (EMA
        weights, or the training weights again when ``use_ema`` is False).
        ``val_loss_ema`` is the metric monitored by the LR scheduler.
        """
        loss = self.shared_step(batch)
        with self.ema_scope():
            loss_ema = self.shared_step(batch)
        log_params = {"on_step": False, "on_epoch": True, "prog_bar": True}
        self.log("val_loss", loss, **log_params)
        # Bug fix: the EMA loss was computed but `loss` was logged in its place.
        self.log("val_loss_ema", loss_ema, **log_params)

    def test_step(self, batch, batch_idx):
        """Lightning test step; reuses the validation logic."""
        return self.validation_step(batch, batch_idx)

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after every training batch (when enabled)."""
        if self.use_ema:
            self.model_ema(self.model)

    def configure_optimizers(self):
        """AdamW with ReduceLROnPlateau monitoring ``val_loss_ema``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr,
            betas=(0.5, 0.9), weight_decay=1e-3)
        reduce_lr = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, factor=0.25, verbose=True
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": reduce_lr,
                "monitor": "val_loss_ema",
                "frequency": 1,
            },
        }

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        **kwargs
    ):
        """Apply linear learning-rate warmup, then defer to Lightning.

        For the first ``lr_warmup`` global steps the learning rate of every
        parameter group is set to ``lr * (global_step + 1) / lr_warmup``.
        After warmup the learning rate is left untouched here.
        """
        if self.trainer.global_step < self.lr_warmup:
            lr_scale = (self.trainer.global_step + 1) / self.lr_warmup
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.lr

        super().optimizer_step(
            epoch, batch_idx, optimizer,
            optimizer_idx,
            **kwargs
        )
|
|
|
|
|