| from abc import abstractmethod |
| import math |
|
|
| import numpy as np |
| import torch as th |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import os |
| import math |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| from einops import repeat |
| import importlib |
|
|
class Linear(nn.Linear):
    """`nn.Linear` variant that casts its weight and bias to the input dtype on each call."""

    def forward(self, input):
        # Cast on the fly so the layer follows whatever dtype the input arrives in.
        weight = self.weight.to(input.dtype)
        bias = None if self.bias is None else self.bias.to(input.dtype)
        return F.linear(input, weight, bias)
|
|
def instantiate_from_config(config):
    """
    Build the object described by `config`.
    :param config: a mapping with a dotted-path `target` and optional `params`
                   keyword arguments, or one of the marker strings
                   '__is_first_stage__' / "__is_unconditional__".
    :return: the instantiated object, or None for a marker string.
    :raises KeyError: if `config` has no `target` and is not a marker string.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", dict())
        return factory(**kwargs)
    # No target: only the two marker values are accepted, and both mean "nothing to build".
    if config == '__is_first_stage__' or config == "__is_unconditional__":
        return None
    raise KeyError("Expected key `target` to instantiate.")
|
|
|
|
def get_obj_from_str(string, reload=False):
    """
    Resolve a dotted path such as "package.module.Name" to the named attribute.
    :param string: dotted path; everything before the last dot is the module.
    :param reload: if True, reload the module before looking the attribute up.
    :return: the attribute found on the imported module.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    module_obj = importlib.import_module(module_path, package=None)
    return getattr(module_obj, attr_name)
|
|
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    """
    Build the per-timestep beta schedule.
    :param schedule: one of "linear", "cosine", "sqrt_linear" or "sqrt".
    :param n_timestep: the number of betas to produce.
    :param linear_start: first value for the linspace-based schedules.
    :param linear_end: last value for the linspace-based schedules.
    :param cosine_s: offset added to the normalised timesteps of the cosine schedule.
    :return: a numpy float64 array holding `n_timestep` betas.
    :raises ValueError: if `schedule` is not one of the known names.
    """
    if schedule == "linear":
        # Interpolate in sqrt(beta) space, then square.
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch rather than np.clip: the result must stay a Tensor so that
        # the `.numpy()` call below is valid regardless of numpy/torch dispatch rules.
        betas = torch.clamp(betas, min=0, max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
|
|
|
|
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    """
    Choose the subset of DDPM timesteps visited by the DDIM sampler.
    :param ddim_discr_method: 'uniform' (fixed stride) or 'quad' (quadratic spacing).
    :param num_ddim_timesteps: the requested number of DDIM steps.
    :param num_ddpm_timesteps: the number of timesteps of the full process.
    :param verbose: if True, print the selected timesteps.
    :return: an integer numpy array of the selected timesteps, each shifted by +1.
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        selected = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    elif ddim_discr_method == 'quad':
        # Evenly spaced in sqrt-space over 80% of the range, then squared.
        grid = np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)
        selected = (grid ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # Shift every index by one before returning.
    steps_out = selected + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
|
|
|
|
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    """
    Compute the sigma / alpha arrays used by the DDIM sampler.
    :param alphacums: array of cumulative alphas, indexed by timestep.
    :param ddim_timesteps: array of the timesteps selected for DDIM sampling.
    :param eta: multiplier applied to the sigma schedule.
    :param verbose: if True, print the selected alphas and resulting sigmas.
    :return: a tuple (sigmas, alphas, alphas_prev).
    """
    # Alphas at the selected steps, and at the step preceding each of them
    # (alphacums[0] stands in as the predecessor of the first selected step).
    alphas = alphacums[ddim_timesteps]
    previous = [alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray(previous)

    variance = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    sigmas = eta * np.sqrt(variance)
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev
|
|
|
|
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize a continuous alpha_bar(t) function into a beta schedule.
    alpha_bar defines the cumulative product of (1-beta) over t in [0, 1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t in [0, 1] and returning the cumulative
                      product of (1-beta) up to that point of the process.
    :param max_beta: upper bound applied to every beta; values below 1
                     prevent singularities.
    :return: a numpy array with `num_diffusion_timesteps` betas.
    """
    steps = num_diffusion_timesteps
    # Each beta is one minus the ratio of alpha_bar at the two ends of its interval.
    return np.array([
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ])
|
|
|
|
def extract_into_tensor(a, t, x_shape):
    """
    Gather `a` at the indices `t` and reshape for broadcasting against `x_shape`.
    :param a: tensor to gather from along its last dimension.
    :param t: index tensor; its first dimension is used as the batch size.
    :param x_shape: target shape; only its length is used.
    :return: tensor of shape (batch, 1, ..., 1) with len(x_shape) dimensions.
    """
    batch, *_rest = t.shape
    gathered = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *singleton_dims)
|
|
def checkpoint(func, inputs, params, flag):
    """
    Run `func`, optionally without caching intermediate activations, trading
    extra compute in the backward pass for reduced memory.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing and call `func` directly.
    :return: whatever `func` returns.
    """
    if not flag:
        return func(*inputs)
    # Inputs and parameters travel as one flat tuple; the length marks where inputs end.
    flat_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)
|
|
|
|
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function that runs `run_function` without recording a graph in the
    forward pass and re-runs it with gradients enabled in the backward pass.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        # The first `length` entries of args are the explicit inputs, the rest are parameters.
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        # Snapshot the autocast state so the recomputation in backward() matches it.
        ctx.gpu_autocast_kwargs = {
            "enabled": torch.is_autocast_enabled(),
            "dtype": torch.get_autocast_gpu_dtype(),
            "cache_enabled": torch.is_autocast_cache_enabled(),
        }
        with torch.no_grad():
            result = ctx.run_function(*ctx.input_tensors)
        return result

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [t.detach().requires_grad_(True) for t in ctx.input_tensors]
        with torch.enable_grad():
            with torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
                # Hand run_function views rather than the detached tensors themselves.
                # NOTE(review): presumably guards against an in-place first op inside
                # run_function - confirm.
                views = [t.view_as(t) for t in ctx.input_tensors]
                recomputed = ctx.run_function(*views)
        grads = torch.autograd.grad(
            recomputed,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del recomputed
        # run_function and length are not differentiable, hence the two leading Nones.
        return (None, None) + grads
|
|
|
|
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoids and repeat each timestep
                        `dim` times instead.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return repeat(timesteps, 'b -> b d', d=dim)

    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Odd dim: cos/sin cover only 2 * half columns, so pad one column of zeros.
        padding = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, padding], dim=-1)
    return embedding
|
|
|
|
def zero_module(module):
    """
    Set every parameter of `module` to zero in place and return the module.
    """
    for param in module.parameters():
        detached = param.detach()
        detached.zero_()
    return module
|
|
|
|
def scale_module(module, scale):
    """
    Multiply every parameter of `module` by `scale` in place and return the module.
    """
    for param in module.parameters():
        detached = param.detach()
        detached.mul_(scale)
    return module
|
|
|
|
def mean_flat(tensor):
    """
    Average over every dimension except the first (batch) one.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
|
|
|
|
def normalization(channels):
    """
    Build the normalization layer used throughout the model.
    :param channels: number of input channels.
    :return: a GroupNorm32 module with 32 groups over `channels` channels.
    """
    num_groups = 32
    return GroupNorm32(num_groups, channels)
|
|
|
|
| |
# Module-level SiLU activation instance, constructed with inplace=True.
SiLU = nn.SiLU(inplace=True)
| |
| |
| |
|
|
|
|
class GroupNorm32(nn.GroupNorm):
    """GroupNorm subclass whose forward simply delegates to `nn.GroupNorm`."""

    def forward(self, x):
        normalized = super().forward(x)
        return normalized
|
|
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    :param dims: 1, 2 or 3, selecting nn.Conv1d, nn.Conv2d or nn.Conv3d.
    :return: the convolution built from the remaining arguments.
    :raises ValueError: if `dims` is not 1, 2 or 3.
    """
    for rank, conv_cls in ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d)):
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
def linear(*args, **kwargs):
    """
    Create a linear module (this file's dtype-casting `Linear`), forwarding all arguments.
    """
    layer = Linear(*args, **kwargs)
    return layer
|
|
|
|
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    :param dims: 1, 2 or 3, selecting nn.AvgPool1d, nn.AvgPool2d or nn.AvgPool3d.
    :return: the pooling module built from the remaining arguments.
    :raises ValueError: if `dims` is not 1, 2 or 3.
    """
    for rank, pool_cls in ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d)):
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
class HybridConditioner(nn.Module):
    """
    Holds two conditioners built from configs - one for the concat input, one
    for the cross-attention input - and applies each to its own input.
    """

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        """Return a dict with each conditioned output wrapped in a one-element list."""
        concat_out = self.concat_conditioner(c_concat)
        crossattn_out = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [concat_out], 'c_crossattn': [crossattn_out]}
|
|
|
|
def noise_like(shape, device, repeat=False):
    """
    Sample standard-normal noise of the given shape on `device`.
    :param shape: the shape of the returned tensor.
    :param device: the device to sample on.
    :param repeat: if True, draw a single sample and tile it along the first
                   dimension, so every batch element gets identical noise.
    :return: a tensor of shape `shape`.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        tiling = (1,) * (len(shape) - 1)
        return single.repeat(shape[0], *tiling)
    return torch.randn(shape, device=device)
|
|
| from inspect import isfunction |
| import math |
| import torch |
| import torch.nn.functional as F |
| from torch import nn, einsum |
| from einops import rearrange, repeat |
| from typing import Optional, Any |
|
|
|
|
# Optional dependency: record whether xformers can be imported so the attention
# implementation can be chosen accordingly.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Catch Exception rather than using a bare except, so KeyboardInterrupt and
    # SystemExit still propagate. Kept broader than ImportError on purpose: any
    # import-time failure of the package is treated as "not available".
    XFORMERS_IS_AVAILBLE = False
|
|
| |
| import os |
# Attention precision mode, read once at import time from the ATTN_PRECISION
# environment variable; defaults to "fp16" when the variable is unset.
# CrossAttention.forward compares this value against "fp32".
_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp16")
|
|
def exists(val):
    """Return True unless `val` is None."""
    return not (val is None)
|
|
|
|
def uniq(arr):
    """Return a dict-keys view of the distinct elements of `arr`, in first-seen order."""
    seen = dict.fromkeys(arr, True)
    return seen.keys()
|
|
|
|
def default(val, d):
    """
    Return `val` unless it is None; otherwise return the fallback `d`.
    If `d` is a plain Python function it is called and its result is returned.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
|
|
|
|
def max_neg_value(t):
    """Return the most negative finite value representable in the dtype of `t`."""
    limits = torch.finfo(t.dtype)
    return -limits.max
|
|
|
|
def init_(tensor):
    """
    Fill `tensor` in place with uniform samples from [-b, b], where
    b = 1 / sqrt(size of the last dimension), and return it.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
|
|
|
|
| |
class GEGLU(nn.Module):
    """
    GELU-gated linear unit: one projection produces both a value half and a
    gate half, and the output is value * gelu(gate).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # A single projection yields both halves, hence twice the output width.
        self.proj = Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
|
|
|
|
class FeedForward(nn.Module):
    """
    Feed-forward block: input projection (Linear + GELU, or GEGLU when `glu`
    is set), dropout, then a linear output projection.
    :param dim: input width.
    :param dim_out: output width; defaults to `dim`.
    :param mult: the hidden width is int(dim * mult).
    :param glu: if True, use GEGLU for the input projection.
    :param dropout: dropout rate applied after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(Linear(dim, hidden), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            Linear(hidden, dim_out),
        )

    def forward(self, x):
        return self.net(x)
|
|
|
|
def zero_module(module):
    """
    Set every parameter of `module` to zero in place and return the module.
    """
    for param in module.parameters():
        detached = param.detach()
        detached.zero_()
    return module
|
|
|
|
def Normalize(in_channels):
    """Return an affine GroupNorm with 32 groups and eps=1e-6 over `in_channels` channels."""
    settings = dict(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
    return torch.nn.GroupNorm(**settings)
|
|
|
|
class SpatialSelfAttention(nn.Module):
    """
    Single-head self-attention over the spatial positions of a feature map,
    using 1x1 convolutions for q/k/v and output, with a residual connection.
    :param in_channels: number of channels of the input (and output).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # Attention weights between every pair of spatial positions.
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        attn = torch.einsum('bij,bjk->bik', q, k)
        attn = attn * (int(c)**(-0.5))
        attn = torch.nn.functional.softmax(attn, dim=2)

        # Apply the weights to the values and restore the image layout.
        v = rearrange(v, 'b c h w -> b c (h w)')
        attn = rearrange(attn, 'b i j -> b j i')
        attended = torch.einsum('bij,bjk->bik', v, attn)
        attended = rearrange(attended, 'b c (h w) -> b c h w', h=h)
        attended = self.proj_out(attended)

        return x+attended
|
|
|
|
class CrossAttention(nn.Module):
    """
    Multi-head attention of `x` over `context`; attends over `x` itself when
    no context is given.
    :param query_dim: width of the query input (and of the output).
    :param context_dim: width of the context input; defaults to `query_dim`.
    :param heads: number of attention heads.
    :param dim_head: width of each head.
    :param dropout: dropout rate applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = Linear(query_dim, inner_dim, bias=False)
        self.to_k = Linear(context_dim, inner_dim, bias=False)
        self.to_v = Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, context=None, mask=None):
        n_heads = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        def split_heads(t):
            # Fold the head axis into the batch axis.
            return rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        if _ATTN_PRECISION == "fp32":
            # Compute the similarity in float32 with autocast switched off.
            with torch.autocast(enabled=False, device_type=q.device.type):
                q, k = q.float(), k.float()
                sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        del q, k

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            fill_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=n_heads)
            # Positions where the mask is False get the most negative finite value.
            sim.masked_fill_(~mask, fill_value)

        weights = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', weights, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=n_heads)
        return self.to_out(out)
|
|
|
|
class MemoryEfficientCrossAttention(nn.Module):
    """
    Multi-head attention of `x` over `context` (over `x` itself when no context
    is given), computed with xformers' memory-efficient attention and falling
    back to a plain bmm/softmax implementation if that call raises
    NotImplementedError or RuntimeError.
    :param query_dim: width of the query input (and of the output).
    :param context_dim: width of the context input; defaults to `query_dim`.
    :param heads: number of attention heads.
    :param dim_head: width of each head.
    :param dropout: dropout rate applied after the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
              f"{heads} heads.")
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = Linear(query_dim, inner_dim, bias=False)
        self.to_k = Linear(context_dim, inner_dim, bias=False)
        self.to_v = Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(Linear(inner_dim, query_dim), nn.Dropout(dropout))
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        """
        :param x: query input; indexed as (batch, tokens, query_dim).
        :param context: optional context input; defaults to `x`.
        :param mask: not supported - must be None.
        :raises NotImplementedError: if `mask` is given.
        """
        # Masks are unsupported: fail before doing any work rather than after
        # the whole attention has already been computed.
        if exists(mask):
            raise NotImplementedError

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape

        def split_heads(t):
            # (b, n, heads * dim_head) -> (b * heads, n, dim_head)
            return (
                t.reshape(b, t.shape[1], self.heads, self.dim_head)
                .permute(0, 2, 1, 3)
                .reshape(b * self.heads, t.shape[1], self.dim_head)
                .contiguous()
            )

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        try:
            out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
        except (NotImplementedError, RuntimeError):
            # Fallback: standard scaled dot-product attention.
            scale = self.dim_head ** -0.5
            attn_weights = torch.bmm(q * scale, k.transpose(-2, -1))
            attn_weights = torch.softmax(attn_weights, dim=-1)
            out = torch.bmm(attn_weights, v)

        # (b * heads, n, dim_head) -> (b, n, heads * dim_head)
        out = (
            out.reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)
|
|
|
|
class BasicTransformerBlock(nn.Module):
    """
    Transformer block made of attention (attn1), attention over `context`
    (attn2) and a feed-forward layer, each preceded by LayerNorm and wrapped
    in a residual connection.
    :param dim: width of the token features.
    :param n_heads: number of attention heads.
    :param d_head: width of each head.
    :param dropout: dropout rate for the attention and feed-forward layers.
    :param context_dim: width of the context for attn2 (and for attn1 when
                        `disable_self_attn` is set).
    :param gated_ff: if True, the feed-forward layer uses GEGLU.
    :param checkpoint: if True, run the block under gradient checkpointing.
    :param disable_self_attn: if True, attn1 attends over `context` instead of `x`.
    """
    ATTENTION_MODES = {
        "softmax": CrossAttention,
        "softmax-xformers": MemoryEfficientCrossAttention
    }

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        # Use the xformers-backed attention whenever the package was importable.
        if XFORMERS_IS_AVAILBLE:
            attn_mode = "softmax-xformers"
        else:
            attn_mode = "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.disable_self_attn = disable_self_attn
        attn1_context_dim = context_dim if self.disable_self_attn else None
        self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                              context_dim=attn1_context_dim)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
                              heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # `checkpoint` here is the module-level helper; `self.checkpoint` is the on/off flag.
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        attn1_context = context if self.disable_self_attn else None
        x = self.attn1(self.norm1(x), context=attn1_context) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
|
|
|
|
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    :param in_channels: channels of the input (and output) feature map.
    :param n_heads: number of attention heads.
    :param d_head: width of each head; the inner width is n_heads * d_head.
    :param depth: number of transformer blocks.
    :param dropout: dropout rate passed to every block.
    :param context_dim: context width, either one value shared by all blocks or
                        a list with one entry per block; None means no context.
    :param disable_self_attn: passed to every block.
    :param use_linear: use Linear projections on the token layout instead of
                       1x1 convolutions on the image layout.
    :param use_checkpoint: enable gradient checkpointing inside the blocks.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        if context_dim is None:
            # No context configured: every block gets context_dim=None.
            context_dim = [None] * depth
        elif len(context_dim) == 1 and depth > 1:
            # A single context width is shared by all blocks.
            context_dim = context_dim * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                   disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
             for d in range(depth)]
        )
        # proj_out is zero-initialised, so the residual in forward() returns the
        # input unchanged until these weights are updated.
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                                  in_channels,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0))
        else:
            # Maps inner_dim back to in_channels (the arguments were previously
            # reversed, which only worked when the two widths were equal).
            self.proj_out = zero_module(Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """
        :param x: feature map, unpacked as (b, c, h, w).
        :param context: one context shared by all blocks, or a list with one
                        entry per block; None is allowed.
        :return: tensor with the same shape as `x`.
        """
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            # A single context is reused by every block instead of indexing past the list.
            block_context = context[0] if len(context) == 1 else context[i]
            x = block(x, context=block_context)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
|
|
|
|
class BasicTransformerBlock3D(BasicTransformerBlock):
    """
    Transformer block whose first attention runs over the tokens of all
    `num_frames` frames jointly, while the second attention and the
    feed-forward layer run per frame.
    """

    def forward(self, x, context=None, num_frames=1):
        # Gradient checkpointing is always disabled here, whatever self.checkpoint says.
        args = (x, context, num_frames)
        return checkpoint(self._forward, args, self.parameters(), False)

    def _forward(self, x, context=None, num_frames=1):
        # Merge the frames of each sample into one long token sequence for attn1.
        tokens = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
        attn1_context = context if self.disable_self_attn else None
        tokens = self.attn1(self.norm1(tokens), context=attn1_context) + tokens
        # Split back to one sequence per frame for the remaining layers.
        tokens = rearrange(tokens, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
        tokens = self.attn2(self.norm2(tokens), context=context) + tokens
        tokens = self.ff(self.norm3(tokens)) + tokens
        return tokens
|
|
|
|
class SpatialTransformer3D(nn.Module):
    """
    3D self-attention: a spatial transformer whose blocks are
    BasicTransformerBlock3D and receive `num_frames`.
    :param in_channels: channels of the input (and output) feature map.
    :param n_heads: number of attention heads.
    :param d_head: width of each head; the inner width is n_heads * d_head.
    :param depth: number of transformer blocks.
    :param dropout: dropout rate passed to every block.
    :param context_dim: context width, either one value shared by all blocks or
                        a list with one entry per block; None means no context.
    :param disable_self_attn: passed to every block.
    :param use_linear: use Linear projections on the token layout instead of
                       1x1 convolutions on the image layout.
    :param use_checkpoint: passed to every block.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None,
                 disable_self_attn=False, use_linear=False,
                 use_checkpoint=True):
        super().__init__()
        if exists(context_dim) and not isinstance(context_dim, list):
            context_dim = [context_dim]
        if context_dim is None:
            # No context configured: every block gets context_dim=None.
            context_dim = [None] * depth
        elif len(context_dim) == 1 and depth > 1:
            # A single context width is shared by all blocks.
            context_dim = context_dim * depth
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock3D(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
                                     disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
             for d in range(depth)]
        )
        # proj_out is zero-initialised, so the residual in forward() returns the
        # input unchanged until these weights are updated.
        if not use_linear:
            self.proj_out = zero_module(nn.Conv2d(inner_dim,
                                                  in_channels,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0))
        else:
            # Maps inner_dim back to in_channels (the arguments were previously
            # reversed, which only worked when the two widths were equal).
            self.proj_out = zero_module(Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None, num_frames=1):
        """
        :param x: feature map, unpacked as (b, c, h, w).
        :param context: one context shared by all blocks, or a list with one
                        entry per block; None is allowed.
        :param num_frames: number of frames, forwarded to every block.
        :return: tensor with the same shape as `x`.
        """
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            # A single context is reused by every block instead of indexing past the list.
            block_context = context[0] if len(context) == 1 else context[i]
            x = block(x, context=block_context, num_frames=num_frames)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
|
|
| |
def convert_module_to_f16(x):
    """No-op placeholder: leaves `x` untouched and returns None."""
    return None
|
|
def convert_module_to_f32(x):
    """No-op placeholder: leaves `x` untouched and returns None."""
    return None
|
|
|
|
| |
class AttentionPool2d(nn.Module):
    """
    Attention pooling over the spatial positions of a feature map.
    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
    :param spacial_dim: side length of the feature map; the positional
                        embedding covers spacial_dim ** 2 positions plus one.
    :param embed_dim: number of channels of the input.
    :param num_heads_channels: channels per head (heads = embed_dim // this).
    :param output_dim: number of output channels; defaults to `embed_dim`.
    """

    def __init__(
        self,
        spacial_dim: int,
        embed_dim: int,
        num_heads_channels: int,
        output_dim: int = None,
    ):
        super().__init__()
        num_positions = spacial_dim ** 2 + 1
        self.positional_embedding = nn.Parameter(th.randn(embed_dim, num_positions) / embed_dim ** 0.5)
        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
        self.num_heads = embed_dim // num_heads_channels
        self.attention = QKVAttention(self.num_heads)

    def forward(self, x):
        b, c, *_spatial = x.shape
        flat = x.reshape(b, c, -1)
        # Prepend the spatial mean as an extra position; its output is the pooled result.
        pooled_token = flat.mean(dim=-1, keepdim=True)
        flat = th.cat([pooled_token, flat], dim=-1)
        flat = flat + self.positional_embedding[None, :, :].to(flat.dtype)
        flat = self.qkv_proj(flat)
        flat = self.attention(flat)
        flat = self.c_proj(flat)
        return flat[:, :, 0]
|
|
|
|
class TimestepBlock(nn.Module):
    """
    Base class for modules whose forward() receives the timestep embeddings
    as its second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x`, conditioned on the timestep embeddings `emb`.
        """
|
|
|
|
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential container that hands each child the extra inputs it supports:
    timestep embeddings for TimestepBlock children, context (plus num_frames
    for the 3D variant) for spatial transformers, and nothing extra otherwise.
    """

    def forward(self, x, emb, context=None, num_frames=1):
        for child in self:
            if isinstance(child, TimestepBlock):
                x = child(x, emb)
                continue
            if isinstance(child, SpatialTransformer3D):
                x = child(x, context, num_frames=num_frames)
                continue
            if isinstance(child, SpatialTransformer):
                x = child(x, context)
                continue
            x = child(x)
        return x
|
|
|
|
class Upsample(nn.Module):
    """
    A 2x nearest-neighbour upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels of the convolution; defaults to `channels`.
    :param padding: padding of the 3-wide convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # Keep dimension 2 as is and double only the last two dimensions.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target_size, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
|
|
class TransposedUpsample(nn.Module):
    """
    Learned 2x upsampling without padding, using a stride-2 transposed convolution.
    :param channels: input channels.
    :param out_channels: output channels; defaults to `channels`.
    :param ks: kernel size of the transposed convolution.
    """

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.up = nn.ConvTranspose2d(self.channels, self.out_channels, kernel_size=ks, stride=2)

    def forward(self, x):
        upsampled = self.up(x)
        return upsampled
|
|
|
|
class Downsample(nn.Module):
    """
    A 2x downsampling layer: a strided convolution, or average pooling when
    no convolution is requested.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels; must equal `channels` when pooling.
    :param padding: padding of the 3-wide convolution.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3D input the first of the three dimensions is not strided.
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        downsampled = self.op(x)
        return downsampled
|
|
|
|
class ResBlock(TimestepBlock):
    """
    A residual block, conditioned on timestep embeddings, that can optionally
    change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding provides a scale and a
        shift applied after the output normalization instead of being added.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # The same resampling is applied to the hidden path and to the skip path.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            identity = nn.Identity()
            self.h_upd = identity
            self.x_upd = identity

        # Scale-shift conditioning needs two values per channel.
        if use_scale_shift_norm:
            emb_out_channels = 2 * self.out_channels
        else:
            emb_out_channels = self.out_channels
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(emb_channels, emb_out_channels),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        inputs = (x, emb)
        return checkpoint(self._forward, inputs, self.parameters(), self.use_checkpoint)

    def _forward(self, x, emb):
        if not self.updown:
            h = self.in_layers(x)
        else:
            # Resample between the norm/activation and the convolution.
            pre_conv, conv = self.in_layers[:-1], self.in_layers[-1]
            h = pre_conv(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = conv(h)

        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton dimensions until emb_out has as many as h.
        for _ in range(len(h.shape) - len(emb_out.shape)):
            emb_out = emb_out[..., None]

        if self.use_scale_shift_norm:
            norm, remainder = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = norm(h) * (1 + scale) + shift
            h = remainder(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
|
|
|
|
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    :param channels: channels of the input (and output).
    :param num_heads: number of heads, used when `num_head_channels` is -1.
    :param num_head_channels: channels per head; when not -1 it must divide
                              `channels` and determines the number of heads.
    :param use_checkpoint: stored on the module (see the note in forward()).
    :param use_new_attention_order: if True use QKVAttention, otherwise
                                    QKVAttentionLegacy.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels != -1:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        else:
            self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        # The two attention classes differ in how they split the qkv channels.
        attention_cls = QKVAttention if use_new_attention_order else QKVAttentionLegacy
        self.attention = attention_cls(self.num_heads)

        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        # NOTE(review): checkpointing is hard-coded to True here, so
        # self.use_checkpoint is not consulted - confirm whether that is intended.
        inputs = (x,)
        return checkpoint(self._forward, inputs, self.parameters(), True)

    def _forward(self, x):
        b, c, *spatial = x.shape
        flat = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(flat))
        attended = self.attention(qkv)
        attended = self.proj_out(attended)
        return (flat + attended).reshape(b, c, *spatial)
|
|
|
|
def count_flops_attn(model, _x, y):
    """
    A counter for the `thop` package to count the operations in an
    attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    :param model: object whose `total_ops` is incremented in place.
    :param _x: unused.
    :param y: sequence whose first element is the [N x C x ...] output tensor.
    """
    batch, channels, *spatial_dims = y[0].shape
    positions = int(np.prod(spatial_dims))
    # NOTE(review): the factor 2 presumably accounts for the two matmuls of an
    # attention op (weights, then weighted values) - confirm.
    matmul_ops = 2 * batch * (positions ** 2) * channels
    model.total_ops += th.DoubleTensor([matmul_ops])
|
|
|
|
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention, splitting the input per head first
    and then into q, k and v. Matches legacy QKVAttention + input/output heads shaping.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        per_head = qkv.reshape(bs * self.n_heads, ch * 3, length)
        q, k, v = per_head.split(ch, dim=1)
        # Scale q and k each by ch ** -0.25, so their product carries the usual 1 / sqrt(ch).
        scale = 1 / math.sqrt(math.sqrt(ch))
        logits = th.einsum("bct,bcs->bts", q * scale, k * scale)
        # Softmax in float32, then cast back to the dtype of the logits.
        weight = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = th.einsum("bts,bcs->bct", weight, v)
        return attended.reshape(bs, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
|
|
|
|
class QKVAttention(nn.Module):
    """
    A module which performs QKV attention and splits in a different order.

    The channel axis is first cut into three equal Q, K and V chunks and each
    chunk is then folded into per-head groups.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, seq_len = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        head_dim = width // (3 * heads)
        query, key, value = qkv.chunk(3, dim=1)
        # Each operand is scaled by ch ** -0.25, so their product carries 1 / sqrt(ch).
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        folded = (batch * heads, head_dim, seq_len)
        scaled_q = (query * scale).view(*folded)
        scaled_k = (key * scale).view(*folded)
        scores = th.einsum("bct,bcs->bts", scaled_q, scaled_k)
        # Softmax is evaluated in float32 and cast back to the score dtype.
        probs = th.softmax(scores.float(), dim=-1).type(scores.dtype)
        mixed = th.einsum("bts,bcs->bct", probs, value.reshape(*folded))
        return mixed.reshape(batch, -1, seq_len)

    @staticmethod
    def count_flops(model, _x, y):
        """Forward to ``count_flops_attn`` (hook for the `thop` profiler)."""
        return count_flops_attn(model, _x, y)
|
|
|
|
class Timestep(nn.Module):
    """Module wrapper that calls ``timestep_embedding`` with a fixed width ``dim``."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Return ``timestep_embedding(t, self.dim)``."""
        width = self.dim
        return timestep_embedding(t, width)
|
|
|
|
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param image_size: stored as ``self.image_size``; not otherwise read by this class.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample. Either an int
        (used for every level) or a list/tuple with one entry per ``channel_mult`` level.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified, the model is class-conditional. An int builds an
        ``nn.Embedding`` with that many classes, ``"continuous"`` builds a
        ``Linear(1, time_embed_dim)`` and ``"sequential"`` builds an MLP over
        ``adm_in_channels`` input features. Any other non-None value raises ValueError.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param use_fp16: if True, ``self.dtype`` is float16 (unless ``use_bf16`` is set).
    :param use_bf16: if True, ``self.dtype`` is bfloat16; takes precedence over ``use_fp16``.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated. Only passed to ``AttentionBlock``
        layers of the output path.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    :param use_spatial_transformer: build ``SpatialTransformer`` layers instead of
        ``AttentionBlock`` layers; requires ``context_dim``.
    :param transformer_depth: passed as ``depth`` to every ``SpatialTransformer``.
    :param context_dim: passed as ``context_dim`` to every ``SpatialTransformer``;
        an omegaconf ``ListConfig`` is converted to a plain list first.
    :param n_embed: if not None, an ``id_predictor`` head with ``n_embed`` output
        channels is built and ``forward`` returns its output instead of ``self.out``.
    :param legacy: if True, the per-head width handed to attention layers is
        ``ch // num_heads`` for spatial transformers and ``num_head_channels`` otherwise.
    :param disable_self_attentions: optional per-level sequence (same length as
        ``channel_mult``); each entry is passed as ``disable_self_attn`` to the
        ``SpatialTransformer`` layers of that level.
    :param num_attention_blocks: optional per-level sequence; at a level only blocks
        whose index is below the entry receive an attention layer. Has lower priority
        than ``attention_resolutions``.
    :param disable_middle_self_attn: passed as ``disable_self_attn`` to the middle
        ``SpatialTransformer``.
    :param use_linear_in_transformer: passed as ``use_linear`` to every ``SpatialTransformer``.
    :param adm_in_channels: input width of the label MLP when ``num_classes == "sequential"``.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        use_bf16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,
        transformer_depth=1,
        context_dim=None,
        n_embed=None,
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if isinstance(context_dim, ListConfig):
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        # At least one of the two head settings must be provided.
        assert num_heads != -1 or num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # One entry per level of the UNet.
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                n_res >= n_attn
                for n_res, n_attn in zip(self.num_res_blocks, num_attention_blocks)
            )
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        # bf16 wins over fp16 when both flags are set.
        self.dtype = th.float16 if use_fp16 else th.float32
        self.dtype = th.bfloat16 if use_bf16 else self.dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError(
                    f"num_classes must be an int, 'continuous' or 'sequential', got {num_classes!r}"
                )

        def make_resblock(in_ch, **extra):
            # ResBlock with the settings shared by every residual block of this model.
            return ResBlock(
                in_ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                **extra,
            )

        def make_attention(width, block_heads, transformer_heads, head_dim, disable_self_attn):
            # Either a plain AttentionBlock or a SpatialTransformer, per configuration.
            if not use_spatial_transformer:
                return AttentionBlock(
                    width,
                    use_checkpoint=use_checkpoint,
                    num_heads=block_heads,
                    num_head_channels=head_dim,
                    use_new_attention_order=use_new_attention_order,
                )
            return SpatialTransformer(
                width, transformer_heads, head_dim, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            )

        # ---- downsampling path -------------------------------------------------
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        last_level = len(channel_mult) - 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [make_resblock(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_heads is deliberately carried over between iterations.
                    num_heads, dim_head = self._resolve_heads(
                        ch, num_heads, num_head_channels, legacy, use_spatial_transformer
                    )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            make_attention(ch, num_heads, num_heads, dim_head, disabled_sa)
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != last_level:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        make_resblock(ch, out_channels=out_ch, down=True)
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- middle block ------------------------------------------------------
        num_heads, dim_head = self._resolve_heads(
            ch, num_heads, num_head_channels, legacy, use_spatial_transformer
        )
        self.middle_block = TimestepEmbedSequential(
            make_resblock(ch),
            make_attention(ch, num_heads, num_heads, dim_head, disable_middle_self_attn),
            make_resblock(ch),
        )
        self._feature_size += ch

        # ---- upsampling path ---------------------------------------------------
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                # Each output block consumes one skip connection from the input path.
                ich = input_block_chans.pop()
                layers = [make_resblock(ch + ich, out_channels=model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    num_heads, dim_head = self._resolve_heads(
                        ch, num_heads, num_head_channels, legacy, use_spatial_transformer
                    )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        # AttentionBlock uses num_heads_upsample here, the transformer uses num_heads.
                        layers.append(
                            make_attention(ch, num_heads_upsample, num_heads, dim_head, disabled_sa)
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        make_resblock(ch, out_channels=out_ch, up=True)
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # ---- output heads ------------------------------------------------------
        # The convolutions take `ch` inputs (what the preceding norm emits). This equals
        # model_channels whenever channel_mult[0] == 1 and avoids a channel mismatch
        # in forward() otherwise.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, ch, n_embed, 1),
            )

    @staticmethod
    def _resolve_heads(ch, num_heads, num_head_channels, legacy, use_spatial_transformer):
        """
        Work out the head count and per-head width for an attention layer of width ``ch``.

        :param ch: channel width of the attention layer.
        :param num_heads: current head count; used as-is when ``num_head_channels == -1``.
        :param num_head_channels: fixed per-head width, or -1 to derive it from ``num_heads``.
        :param legacy: if True, override the per-head width with ``ch // num_heads`` for
                       spatial transformers and with ``num_head_channels`` otherwise.
        :param use_spatial_transformer: selects which legacy override applies.
        :return: a ``(num_heads, dim_head)`` tuple.
        """
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        return num_heads, dim_head

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        for part in (self.input_blocks, self.middle_block, self.output_blocks):
            part.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        for part in (self.input_blocks, self.middle_block, self.output_blocks):
            part.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :param kwargs: accepted and ignored.
        :return: an [N x C x ...] Tensor of outputs (the ``id_predictor`` output
                 instead when ``n_embed`` was given at construction).
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            # Keep every activation for the skip connections of the output path.
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        # Return to the caller's dtype before the output head.
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        return self.out(h)
|
|
|
|
class MultiViewUNetModel(nn.Module):
    """
    The full multi-view UNet model with attention and timestep embedding.

    :param image_size: stored as ``self.image_size``; not otherwise read by this class.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample. Either an int
        (used for every level) or a list/tuple with one entry per ``channel_mult`` level.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified, the model is class-conditional. An int builds an
        ``nn.Embedding`` with that many classes, ``"continuous"`` builds a
        ``Linear(1, time_embed_dim)`` and ``"sequential"`` builds an MLP over
        ``adm_in_channels`` input features. Any other non-None value raises ValueError.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param use_fp16: if True, ``self.dtype`` is float16 (unless ``use_bf16`` is set).
    :param use_bf16: if True, ``self.dtype`` is bfloat16; takes precedence over ``use_fp16``.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated. Only passed to ``AttentionBlock``
        layers of the output path.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    :param use_spatial_transformer: build ``SpatialTransformer3D`` layers instead of
        ``AttentionBlock`` layers; requires ``context_dim``.
    :param transformer_depth: passed as ``depth`` to every ``SpatialTransformer3D``.
    :param context_dim: passed as ``context_dim`` to every ``SpatialTransformer3D``;
        an omegaconf ``ListConfig`` is converted to a plain list first.
    :param n_embed: if not None, an ``id_predictor`` head with ``n_embed`` output
        channels is built. NOTE: ``forward`` of this class never applies it.
    :param legacy: if True, the per-head width handed to attention layers is
        ``ch // num_heads`` for spatial transformers and ``num_head_channels`` otherwise.
    :param disable_self_attentions: optional per-level sequence (same length as
        ``channel_mult``); each entry is passed as ``disable_self_attn`` to the
        ``SpatialTransformer3D`` layers of that level.
    :param num_attention_blocks: optional per-level sequence; at a level only blocks
        whose index is below the entry receive an attention layer. Has lower priority
        than ``attention_resolutions``.
    :param disable_middle_self_attn: passed as ``disable_self_attn`` to the middle
        ``SpatialTransformer3D``.
    :param use_linear_in_transformer: passed as ``use_linear`` to every ``SpatialTransformer3D``.
    :param adm_in_channels: input width of the label MLP when ``num_classes == "sequential"``.
    :param camera_dim: dimensionality of camera input. NOTE: accepted but currently
        unused; this class builds no camera embedding layer from it.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        use_bf16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,
        transformer_depth=1,
        context_dim=None,
        n_embed=None,
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        camera_dim=None,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if isinstance(context_dim, ListConfig):
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        # At least one of the two head settings must be provided.
        assert num_heads != -1 or num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # One entry per level of the UNet.
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                n_res >= n_attn
                for n_res, n_attn in zip(self.num_res_blocks, num_attention_blocks)
            )
            # The message names this class (it previously said "UNetModel").
            print(f"Constructor of MultiViewUNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        # bf16 wins over fp16 when both flags are set.
        self.dtype = th.float16 if use_fp16 else th.float32
        self.dtype = th.bfloat16 if use_bf16 else self.dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError(
                    f"num_classes must be an int, 'continuous' or 'sequential', got {num_classes!r}"
                )

        def make_resblock(in_ch, **extra):
            # ResBlock with the settings shared by every residual block of this model.
            return ResBlock(
                in_ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                **extra,
            )

        def make_attention(width, block_heads, transformer_heads, head_dim, disable_self_attn):
            # Either a plain AttentionBlock or a SpatialTransformer3D, per configuration.
            if not use_spatial_transformer:
                return AttentionBlock(
                    width,
                    use_checkpoint=use_checkpoint,
                    num_heads=block_heads,
                    num_head_channels=head_dim,
                    use_new_attention_order=use_new_attention_order,
                )
            return SpatialTransformer3D(
                width, transformer_heads, head_dim, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            )

        # ---- downsampling path -------------------------------------------------
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        last_level = len(channel_mult) - 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [make_resblock(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_heads is deliberately carried over between iterations.
                    num_heads, dim_head = self._resolve_heads(
                        ch, num_heads, num_head_channels, legacy, use_spatial_transformer
                    )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            make_attention(ch, num_heads, num_heads, dim_head, disabled_sa)
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != last_level:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        make_resblock(ch, out_channels=out_ch, down=True)
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- middle block ------------------------------------------------------
        num_heads, dim_head = self._resolve_heads(
            ch, num_heads, num_head_channels, legacy, use_spatial_transformer
        )
        self.middle_block = TimestepEmbedSequential(
            make_resblock(ch),
            make_attention(ch, num_heads, num_heads, dim_head, disable_middle_self_attn),
            make_resblock(ch),
        )
        self._feature_size += ch

        # ---- upsampling path ---------------------------------------------------
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                # Each output block consumes one skip connection from the input path.
                ich = input_block_chans.pop()
                layers = [make_resblock(ch + ich, out_channels=model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    num_heads, dim_head = self._resolve_heads(
                        ch, num_heads, num_head_channels, legacy, use_spatial_transformer
                    )
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        # AttentionBlock uses num_heads_upsample here, the transformer uses num_heads.
                        layers.append(
                            make_attention(ch, num_heads_upsample, num_heads, dim_head, disabled_sa)
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        make_resblock(ch, out_channels=out_ch, up=True)
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # ---- output heads ------------------------------------------------------
        # The convolutions take `ch` inputs (what the preceding norm emits). This equals
        # model_channels whenever channel_mult[0] == 1 and avoids a channel mismatch
        # in forward() otherwise.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, ch, n_embed, 1),
            )

    @staticmethod
    def _resolve_heads(ch, num_heads, num_head_channels, legacy, use_spatial_transformer):
        """
        Work out the head count and per-head width for an attention layer of width ``ch``.

        :param ch: channel width of the attention layer.
        :param num_heads: current head count; used as-is when ``num_head_channels == -1``.
        :param num_head_channels: fixed per-head width, or -1 to derive it from ``num_heads``.
        :param legacy: if True, override the per-head width with ``ch // num_heads`` for
                       spatial transformers and with ``num_head_channels`` otherwise.
        :param use_spatial_transformer: selects which legacy override applies.
        :return: a ``(num_heads, dim_head)`` tuple.
        """
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        return num_heads, dim_head

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        for part in (self.input_blocks, self.middle_block, self.output_blocks):
            part.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        for part in (self.input_blocks, self.middle_block, self.output_blocks):
            part.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, num_frames=1, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn; may be None.
        :param y: an [N] Tensor of labels, if class-conditional.
        :param num_frames: an integer indicating the number of frames for tensor reshaping.
        :param kwargs: accepted and ignored.
        :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
                 The result stays in ``self.dtype``; it is not cast back to ``x.dtype``.
        """
        assert x.shape[0] % num_frames == 0, "[UNet] input batch size must be dividable by num_frames!"
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            # nn.Embedding needs integer indices, so only the non-embedding label
            # paths get their input cast to the compute dtype.
            if not isinstance(self.label_emb, nn.Embedding):
                y = y.to(self.dtype)
            emb = emb + self.label_emb(y)

        h = x.to(self.dtype)
        # context defaults to None (e.g. when no cross-attention is configured).
        if context is not None:
            context = context.to(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context, num_frames=num_frames)
            # Keep every activation for the skip connections of the output path.
            hs.append(h)
        h = self.middle_block(h, emb, context, num_frames=num_frames)
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, num_frames=num_frames)
        return self.out(h)