| import warnings |
|
|
| import torch |
| from torch import nn |
| from torch.utils import checkpoint |
|
|
| from .utils import conv_nd, apply_initialization |
| from .openaimodel import Upsample, Downsample |
|
|
|
|
| class TimeEmbedLayer(nn.Module): |
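    r"""
    A two-layer MLP (Linear -> SiLU -> Linear) that lifts a base timestep embedding
    of size `base_channels` to `time_embed_channels`.
    """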
|
|
| def __init__( |
| self, |
| base_channels, |
| time_embed_channels, |
| linear_init_mode="0" |
| ): |
        super().__init__()
| self.layer = nn.Sequential( |
| nn.Linear(base_channels, time_embed_channels), |
| nn.SiLU(), |
| nn.Linear(time_embed_channels, time_embed_channels), |
| ) |
| self.linear_init_mode = linear_init_mode |
|
|
| def forward(self, x): |
| return self.layer(x) |
|
|
| def reset_parameters(self): |
| apply_initialization(self.layer[0], linear_mode=self.linear_init_mode) |
| apply_initialization(self.layer[2], linear_mode=self.linear_init_mode) |
|
|
|
|
| class TimeEmbedResBlock(nn.Module): |
| r""" |
| Code is adapted from https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/openaimodel.py |
| |
| Modifications: |
| 1. Change GroupNorm32 to use arbitrary `num_groups`. |
| 2. Add method `self.reset_parameters()`. |
    3. Use the gradient checkpointing utilities from PyTorch instead of the Stable Diffusion implementation https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/util.py#L102. In this version, `use_checkpoint=True` is not supported yet and falls back to `False` with a warning.
    4. If no time embedding is provided (`use_embed=False`), the block degrades to a plain residual block.
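
    Examples
    --------
    A minimal usage sketch (an illustration, not from the original code; it assumes 2D
    feature maps and the default `norm_groups=32`)::

        block = TimeEmbedResBlock(channels=64, dropout=0.0, emb_channels=128, dims=2)
        x = torch.randn(2, 64, 16, 16)   # [N x C x H x W] features
        emb = torch.randn(2, 128)        # [N x emb_channels] timestep embeddings
        out = block(x, emb)              # same shape as x (out_channels defaults to channels)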
| """ |
| def __init__( |
| self, |
| channels, |
| dropout, |
| emb_channels=None, |
| out_channels=None, |
| use_conv=False, |
| use_embed=True, |
| use_scale_shift_norm=False, |
| dims=2, |
| use_checkpoint=False, |
| up=False, |
| down=False, |
| norm_groups=32, |
| ): |
| r""" |
| Parameters |
| ---------- |
| channels |
| dropout |
| emb_channels |
| out_channels |
| use_conv |
| use_embed: bool |
| include `emb` as input in `self.forward()` |
| use_scale_shift_norm: bool |
| take effect only when `use_embed == True` |
| dims |
| use_checkpoint |
| up |
| down |
| norm_groups |
| """ |
| super().__init__() |
| self.channels = channels |
| self.dropout = dropout |
| self.use_embed = use_embed |
| if use_embed: |
| assert isinstance(emb_channels, int) |
| self.emb_channels = emb_channels |
| self.out_channels = out_channels or channels |
| self.use_conv = use_conv |
| if use_checkpoint: |
| warnings.warn("use_checkpoint is not supported yet.") |
| use_checkpoint = False |
| self.use_checkpoint = use_checkpoint |
| self.use_scale_shift_norm = use_scale_shift_norm |
|
|
| self.in_layers = nn.Sequential( |
| nn.GroupNorm(num_groups=norm_groups if channels % norm_groups == 0 else channels, |
| num_channels=channels), |
| nn.SiLU(), |
| conv_nd(dims, channels, self.out_channels, 3, padding=1), |
| ) |
|
|
| self.updown = up or down |
|
|
| if up: |
| self.h_upd = Upsample(channels, False, dims) |
| self.x_upd = Upsample(channels, False, dims) |
| elif down: |
| self.h_upd = Downsample(channels, False, dims) |
| self.x_upd = Downsample(channels, False, dims) |
| else: |
| self.h_upd = self.x_upd = nn.Identity() |
|
|
| if use_embed: |
| self.emb_layers = nn.Sequential( |
| nn.SiLU(), |
| nn.Linear( |
| in_features=emb_channels, |
| out_features=2 * self.out_channels if use_scale_shift_norm else self.out_channels, |
| ), |
| ) |
| self.out_layers = nn.Sequential( |
| nn.GroupNorm(num_groups=norm_groups if self.out_channels % norm_groups == 0 else self.out_channels, |
| num_channels=self.out_channels), |
| nn.SiLU(), |
| nn.Dropout(p=dropout), |
| conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1), |
| ) |
|
|
| if self.out_channels == channels: |
| self.skip_connection = nn.Identity() |
| elif use_conv: |
| self.skip_connection = conv_nd( |
| dims, channels, self.out_channels, 3, padding=1 |
| ) |
| else: |
| self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) |
|
|
| self.reset_parameters() |
|
|
| def forward(self, x, emb=None): |
| """ |
| Apply the block to a Tensor, conditioned on a timestep embedding. |
| |
| Parameters |
| ---------- |
| x: an [N x C x ...] Tensor of features. |
        emb: an [N x emb_channels] Tensor of timestep embeddings. Required when
            `use_embed == True`, ignored otherwise.
| |
| Returns |
| ------- |
        out: an [N x out_channels x ...] Tensor of outputs.
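
        Notes
        -----
        A rough sketch of the conditioning performed here (an illustration, assuming
        `use_scale_shift_norm=True` and 2D features)::

            scale, shift = torch.chunk(self.emb_layers(emb)[..., None, None], 2, dim=1)
            h = self.out_layers[1:](self.out_layers[0](h) * (1 + scale) + shift)

        With `use_scale_shift_norm=False`, the embedding is simply broadcast and added
        to `h` before applying `self.out_layers`.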
| """ |
| if self.updown: |
| in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] |
| h = in_rest(x) |
| h = self.h_upd(h) |
| x = self.x_upd(x) |
| h = in_conv(h) |
| else: |
| h = self.in_layers(x) |
| if self.use_embed: |
| emb_out = self.emb_layers(emb).type(h.dtype) |
| while len(emb_out.shape) < len(h.shape): |
| emb_out = emb_out[..., None] |
| if self.use_scale_shift_norm: |
| out_norm, out_rest = self.out_layers[0], self.out_layers[1:] |
| scale, shift = torch.chunk(emb_out, 2, dim=1) |
| h = out_norm(h) * (1 + scale) + shift |
| h = out_rest(h) |
| else: |
| h = h + emb_out |
| h = self.out_layers(h) |
| else: |
| h = self.out_layers(h) |
| return self.skip_connection(x) + h |
|
|
| def reset_parameters(self): |
| for m in self.modules(): |
| apply_initialization(m) |
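        # Zero-initialize the final conv so the residual branch contributes nothing at
        # the start of training (the block then reduces to its skip connection), as in
        # the original openaimodel implementation.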
| for p in self.out_layers[-1].parameters(): |
| nn.init.zeros_(p) |
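

# Quick smoke test: an illustrative sketch, not part of the original module. It only
# checks output shapes on CPU with default settings for the two blocks defined above.
if __name__ == "__main__":
    time_embed = TimeEmbedLayer(base_channels=128, time_embed_channels=512)
    t_feat = torch.randn(4, 128)            # per-sample timestep features
    emb = time_embed(t_feat)                # -> [4, 512]

    block = TimeEmbedResBlock(channels=64, dropout=0.1, emb_channels=512, dims=2)
    x = torch.randn(4, 64, 32, 32)          # [N, C, H, W] feature maps
    out = block(x, emb)                     # -> [4, 64, 32, 32]
    print(emb.shape, out.shape)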
|
|