| |
| |
|
|
| import math |
| from inspect import isfunction |
| from math import ceil, floor, log, pi, log2 |
| from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union |
| from packaging import version |
|
|
| import torch |
| import torch.nn as nn |
| from einops import rearrange, reduce, repeat |
| from einops.layers.torch import Rearrange |
| from einops_exts import rearrange_many |
| from torch import Tensor, einsum |
| from torch.backends.cuda import sdp_kernel |
| from torch.nn import functional as F |
| from dac.nn.layers import Snake1d |
|
|
| """ |
| Utils |
| """ |
|
|
|
|
class ConditionedSequential(nn.Module):
    """Sequential container that passes a shared `mapping` to every module.

    The modules may be given either as one iterable
    (``ConditionedSequential([a, b])``) or as separate positional arguments
    (``ConditionedSequential(a, b)``).
    """

    def __init__(self, *modules):
        super().__init__()
        # A single iterable argument (list, tuple, generator, ModuleList, ...)
        # is unpacked as the module collection; anything else means the
        # positional arguments themselves are the modules.
        if len(modules) == 1 and hasattr(modules[0], "__iter__"):
            modules = modules[0]
        self.module_list = nn.ModuleList(modules)

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None):
        """Run `x` through each module in order as ``module(x, mapping)``."""
        for module in self.module_list:
            x = module(x, mapping)
        return x
|
|
T = TypeVar("T")


def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return `val` unless it is None, otherwise fall back to `d`.

    When `d` is a plain Python function (as judged by `inspect.isfunction`)
    it is called with no arguments and its result is used; any other object
    is returned unchanged.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val
|
|
def exists(val: Optional[Any]) -> bool:
    """Return True when `val` is anything other than None."""
    return val is not None
|
|
def closest_power_2(x: float) -> int:
    """Return the power of two nearest to `x`.

    Only the two exponents bracketing ``log2(x)`` are considered. On a tie the
    smaller power is returned because it is examined first.
    """
    exact = log2(x)
    lower, upper = floor(exact), ceil(exact)
    nearest = lower if abs(x - 2 ** lower) <= abs(x - 2 ** upper) else upper
    return 2 ** int(nearest)
|
|
def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Partition `d` by key prefix.

    Returns:
        A pair ``(matching, remaining)``: entries whose key starts with
        `prefix`, and all other entries. Keys are left untouched.
    """
    matching: Dict = {}
    remaining: Dict = {}
    for key, value in d.items():
        bucket = matching if key.startswith(prefix) else remaining
        bucket[key] = value
    return matching, remaining
|
|
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Split `d` into prefixed and non-prefixed entries.

    Args:
        prefix: Key prefix to select on.
        d: Dictionary to split.
        keep_prefix: If False (default), `prefix` is removed from the keys
            of the first returned dictionary.

    Returns:
        ``(prefixed, rest)`` where `prefixed` holds the entries whose key
        starts with `prefix` and `rest` holds everything else.
    """
    prefixed = {key: value for key, value in d.items() if key.startswith(prefix)}
    rest = {key: value for key, value in d.items() if not key.startswith(prefix)}
    if not keep_prefix:
        cut = len(prefix)
        prefixed = {key[cut:]: value for key, value in prefixed.items()}
    return prefixed, rest
|
|
| """ |
| Convolutional Blocks |
| """ |
| import typing as tp |
|
|
| |
| |
|
|
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how many extra samples must be appended so the last window is full.

    The (possibly fractional) number of frames a strided convolution would
    produce over the last dimension of `x` is rounded up, and the input
    length needed for that many whole frames is compared with the current
    length. See `pad_for_conv1d`.
    """
    current_length = x.shape[-1]
    frames = (current_length - kernel_size + padding_total) / stride + 1
    whole_frames = math.ceil(frames)
    needed_length = (whole_frames - 1) * stride + (kernel_size - padding_total)
    return needed_length - current_length
|
|
|
|
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Right-pad `x` with zeros so that the last convolution window is full.

    The extra padding goes at the end of the last dimension. It is needed to
    be able to rebuild an output of the same length: otherwise, even with
    padding, some time steps might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    tail = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, tail))
|
|
|
|
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Thin wrapper around `F.pad` that tolerates reflect padding on short inputs.

    In 'reflect' mode, when the last dimension is not longer than the largest
    requested padding, zeros are first appended on the right so the reflection
    is possible; the same number of trailing samples is cut off the result.
    Every other mode is passed straight to `F.pad`.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)

    length = x.shape[-1]
    largest = max(left, right)
    # Number of temporary zeros required before the input can be reflected.
    extra = largest - length + 1 if length <= largest else 0
    if extra:
        x = F.pad(x, (0, extra))
    reflected = F.pad(x, paddings, mode, value)
    return reflected[..., :reflected.shape[-1] - extra]
|
|
|
|
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip `paddings = (left, right)` samples from the last dimension of `x`.

    A right padding of zero is handled correctly because the end index is
    computed explicitly rather than written as a negative slice. Only for 1d!
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total = x.shape[-1]
    assert (left + right) <= total
    return x[..., left: total - right]
|
|
|
|
class Conv1d(nn.Conv1d):
    """`nn.Conv1d` that pads its own input before convolving.

    The total padding is ``effective_kernel - stride`` (the effective kernel
    accounts for dilation), plus whatever `get_extra_padding_for_conv1d`
    reports is needed for the last window to be full.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, causal=False) -> Tensor:
        """Pad `x`, then apply the convolution.

        Args:
            x: Input tensor; padding is applied to its last dimension.
            causal: If True, all regular padding goes on the left; otherwise
                it is split between both sides, with the left side receiving
                the extra sample when the total is odd.
        """
        stride = self.stride[0]
        effective_kernel = (self.kernel_size[0] - 1) * self.dilation[0] + 1
        total = effective_kernel - stride
        extra = get_extra_padding_for_conv1d(x, effective_kernel, stride, total)
        if causal:
            pads = (total, extra)
        else:
            right = total // 2
            pads = (total - right, right + extra)
        return super().forward(pad1d(x, pads))
| |
class ConvTranspose1d(nn.ConvTranspose1d):
    """`nn.ConvTranspose1d` that trims ``kernel_size - stride`` samples from its output."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: Tensor, causal=False) -> Tensor:
        """Apply the transposed convolution and remove the surplus samples.

        Args:
            x: Input tensor.
            causal: If True, the whole surplus is removed from the right end;
                otherwise it is split between both ends, with the left end
                losing the extra sample when the surplus is odd.
        """
        surplus = self.kernel_size[0] - self.stride[0]
        y = super().forward(x)
        trim_right = ceil(surplus) if causal else surplus // 2
        trim_left = surplus - trim_right
        return unpad1d(y, (trim_left, trim_right))
| |
|
|
def Downsample1d(
    in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
    """Build a strided `Conv1d` that downsamples by `factor`.

    The kernel size is ``factor * kernel_multiplier + 1`` and the stride is
    `factor`. `kernel_multiplier` must be even.
    """
    assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"

    kernel_size = factor * kernel_multiplier + 1
    return Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=factor,
    )
|
|
|
|
def Upsample1d(
    in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
) -> nn.Module:
    """Build a module that upsamples by `factor`.

    * ``factor == 1``: a plain kernel-3 `Conv1d` (no resampling).
    * ``use_nearest``: nearest-neighbour `nn.Upsample` followed by a kernel-3
      `Conv1d`, wrapped in `nn.Sequential`.
    * otherwise: a `ConvTranspose1d` with kernel ``2 * factor`` and stride
      `factor`.
    """
    if factor == 1:
        return Conv1d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3
        )

    if not use_nearest:
        return ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=factor * 2,
            stride=factor,
        )

    return nn.Sequential(
        nn.Upsample(scale_factor=factor, mode="nearest"),
        Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
        ),
    )
|
|
|
|
class ConvBlock1d(nn.Module):
    """Normalisation -> optional scale/shift -> activation -> `Conv1d`.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by the convolution.
        kernel_size, stride, dilation: Forwarded to `Conv1d`.
        num_groups: Group count for `nn.GroupNorm`.
        use_norm: If False, the normalisation is replaced by `nn.Identity`.
        use_snake: If True, use `Snake1d` instead of `nn.SiLU`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        num_groups: int = 8,
        use_norm: bool = True,
        use_snake: bool = False
    ) -> None:
        super().__init__()

        if use_norm:
            self.groupnorm = nn.GroupNorm(
                num_groups=num_groups, num_channels=in_channels
            )
        else:
            self.groupnorm = nn.Identity()

        self.activation = Snake1d(in_channels) if use_snake else nn.SiLU()

        self.project = Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(
        self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False
    ) -> Tensor:
        """Apply the block.

        When `scale_shift` is given as ``(scale, shift)``, the normalised
        input is modulated as ``x * (scale + 1) + shift`` before the
        activation.
        """
        h = self.groupnorm(x)
        if scale_shift is not None:
            scale, shift = scale_shift
            h = h * (scale + 1) + shift
        h = self.activation(h)
        return self.project(h, causal=causal)
|
|
|
|
class MappingToScaleShift(nn.Module):
    """Project a mapping vector to a per-channel ``(scale, shift)`` pair.

    Args:
        features: Size of the incoming mapping vector.
        channels: Number of channels to modulate; the linear layer emits
            ``2 * channels`` values.
    """

    def __init__(
        self,
        features: int,
        channels: int,
    ):
        super().__init__()

        projection = nn.Linear(in_features=features, out_features=channels * 2)
        self.to_scale_shift = nn.Sequential(nn.SiLU(), projection)

    def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(scale, shift)``, each with a trailing singleton axis."""
        projected = self.to_scale_shift(mapping)
        # Trailing axis of size 1 so the pair broadcasts over a third dimension.
        projected = rearrange(projected, "b c -> b c 1")
        halves = projected.chunk(2, dim=1)
        scale, shift = halves
        return scale, shift
|
|
|
|
class ResnetBlock1d(nn.Module):
    """Two `ConvBlock1d`s with a residual connection and optional conditioning.

    When `context_mapping_features` is given, a `MappingToScaleShift` turns
    the mapping vector into a scale/shift pair applied inside the second
    block. If the input and output channel counts differ, the residual path
    goes through a kernel-1 `Conv1d`; otherwise it is the identity.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        kernel_size: int = 3,
        stride: int = 1,
        dilation: int = 1,
        use_norm: bool = True,
        use_snake: bool = False,
        num_groups: int = 8,
        context_mapping_features: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.use_mapping = exists(context_mapping_features)

        # Options common to both convolutional blocks.
        shared = dict(use_norm=use_norm, num_groups=num_groups, use_snake=use_snake)

        self.block1 = ConvBlock1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            **shared,
        )

        if self.use_mapping:
            assert exists(context_mapping_features)
            self.to_scale_shift = MappingToScaleShift(
                features=context_mapping_features, channels=out_channels
            )

        self.block2 = ConvBlock1d(
            in_channels=out_channels, out_channels=out_channels, **shared
        )

        if in_channels != out_channels:
            self.to_out = Conv1d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        """Apply the residual block.

        `mapping` must be provided exactly when the block was built with
        `context_mapping_features`; a mismatch in either direction trips the
        assertion.
        """
        assert_message = "context mapping required if context_mapping_features > 0"
        assert not (self.use_mapping ^ exists(mapping)), assert_message

        h = self.block1(x, causal=causal)

        scale_shift = self.to_scale_shift(mapping) if self.use_mapping else None
        h = self.block2(h, scale_shift=scale_shift, causal=causal)

        return h + self.to_out(x)
|
|
|
|
class Patcher(nn.Module):
    """`ResnetBlock1d` followed by folding `patch_size` time steps into channels.

    The residual block produces ``out_channels // patch_size`` channels; the
    rearrangement then groups every `patch_size` consecutive steps of the
    last axis into the channel axis, giving `out_channels` channels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        assert_message = f"out_channels must be divisible by patch_size ({patch_size})"
        assert out_channels % patch_size == 0, assert_message
        self.patch_size = patch_size

        block_channels = out_channels // patch_size
        self.block = ResnetBlock1d(
            in_channels=in_channels,
            out_channels=block_channels,
            num_groups=1,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        """Run the block, then move patches of the last axis into channels."""
        h = self.block(x, mapping, causal=causal)
        return rearrange(h, "b c (l p) -> b (c p) l", p=self.patch_size)
|
|
|
|
class Unpatcher(nn.Module):
    """Unfold channels back into time steps, then apply a `ResnetBlock1d`.

    This is the inverse layout change of `Patcher`: each group of
    `patch_size` channels is spread over `patch_size` consecutive positions
    of the last axis, so the block sees ``in_channels // patch_size``
    channels.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
        use_snake: bool = False
    ):
        super().__init__()
        assert_message = f"in_channels must be divisible by patch_size ({patch_size})"
        assert in_channels % patch_size == 0, assert_message
        self.patch_size = patch_size

        block_channels = in_channels // patch_size
        self.block = ResnetBlock1d(
            in_channels=block_channels,
            out_channels=out_channels,
            num_groups=1,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

    def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor:
        """Move patches from channels back to the last axis, then run the block."""
        unfolded = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size)
        return self.block(unfolded, mapping, causal=causal)
|
|
|
|
| """ |
| Attention Components |
| """ |
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Build a two-layer MLP: Linear -> GELU -> Linear.

    The hidden width is ``features * multiplier``; input and output widths
    are both `features`.
    """
    hidden = features * multiplier
    expand = nn.Linear(in_features=features, out_features=hidden)
    contract = nn.Linear(in_features=hidden, out_features=features)
    return nn.Sequential(expand, nn.GELU(), contract)
|
|
def add_mask(sim: Tensor, mask: Tensor) -> Tensor:
    """Fill masked-out similarity entries with the most negative finite value.

    A 3-D mask ``(b, n, m)`` gets a singleton axis inserted at position 1; a
    2-D mask ``(n, m)`` is additionally repeated over the batch size of
    `sim`. Masks of any other rank are used as given. Positions where `mask`
    is False are overwritten.
    """
    batch = sim.shape[0]
    rank = mask.ndim
    if rank == 3:
        mask = rearrange(mask, "b n m -> b 1 n m")
    elif rank == 2:
        mask = repeat(mask, "n m -> b 1 n m", b=batch)
    most_negative = -torch.finfo(sim.dtype).max
    return sim.masked_fill(~mask, most_negative)
|
|
def causal_mask(q: Tensor, k: Tensor) -> Tensor:
    """Build a boolean causal mask of shape ``(batch, n_queries, n_keys)``.

    Entry ``[b, i, j]`` is True when ``j - i <= n_keys - n_queries``, i.e. the
    strictly-upper part (offset by the length difference) is False. Sizes are
    read from the second-to-last axes of `q` and `k`.
    """
    batch = q.shape[0]
    n_queries, n_keys = q.shape[-2], k.shape[-2]
    blocked = torch.ones(
        (n_queries, n_keys), dtype=torch.bool, device=q.device
    ).triu(n_keys - n_queries + 1)
    return repeat(~blocked, "n m -> b n m", b=batch)
|
|
class AttentionBase(nn.Module):
    """Multi-head attention core: split heads, attend, merge heads, project.

    Queries, keys and values are expected to be already projected to
    ``num_heads * head_features`` features. When CUDA is available and the
    installed torch is at least 2.0.0, `F.scaled_dot_product_attention` is
    used; otherwise attention is computed explicitly with einsum.

    Args:
        features: Default size of the output projection.
        head_features: Feature size of each head.
        num_heads: Number of attention heads.
        out_features: Output size of the final linear layer; defaults to
            `features`.
    """

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
    ):
        super().__init__()
        self.scale = head_features**-0.5
        self.num_heads = num_heads
        mid_features = head_features * num_heads
        out_features = default(out_features, features)

        self.to_out = nn.Linear(
            in_features=mid_features, out_features=out_features
        )

        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')

        if not self.use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        # The tuple is unpacked positionally into `sdp_kernel(...)` in
        # `forward` (assumed order: flash, math, mem-efficient — confirm
        # against the torch docs).
        if device_properties.major == 8 and device_properties.minor == 0:
            # Compute capability 8.0: enable only the first backend.
            self.sdp_kernel_config = (True, False, False)
        else:
            # Any other device: disable the first backend, enable the rest.
            self.sdp_kernel_config = (False, True, True)

    def forward(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False
    ) -> Tensor:
        """Attend `q` over `k`/`v` and project the merged heads.

        Args:
            q, k, v: Tensors laid out as ``(b, n, h * d)``.
            mask: Optional boolean attention mask (True = keep).
            is_causal: Request causal attention. On the explicit path a causal
                mask is generated only when no `mask` was supplied.
        """
        # Split heads
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)

        if not self.use_flash:
            # `not mask` would evaluate the truthiness of a Tensor, which
            # raises for any mask with more than one element; test for None.
            if is_causal and not exists(mask):
                # Mask out future tokens for causal attention
                mask = causal_mask(q, k)

            # Compute similarity matrix and add eventual mask
            sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale
            sim = add_mask(sim, mask) if exists(mask) else sim

            # Get attention matrix with softmax
            attn = sim.softmax(dim=-1, dtype=torch.float32)

            # Compute values
            out = einsum("... n m, ... m d -> ... n d", attn, v)
        else:
            with sdp_kernel(*self.sdp_kernel_config):
                out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal)

        # Merge heads and project
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
|
|
class Attention(nn.Module):
    """Layer-normed self- or cross-attention built on `AttentionBase`.

    Queries come from `x`; keys and values come from `context`, which
    defaults to `x` itself when not given.

    Args:
        features: Feature size of `x`.
        head_features: Feature size of each head.
        num_heads: Number of attention heads.
        out_features: Output size, forwarded to `AttentionBase`.
        context_features: Feature size of `context`; defaults to `features`.
            When set, `forward` requires a context.
        causal: If True, attention is always causal regardless of the
            per-call flag.
    """

    def __init__(
        self,
        features: int,
        *,
        head_features: int,
        num_heads: int,
        out_features: Optional[int] = None,
        context_features: Optional[int] = None,
        causal: bool = False,
    ):
        super().__init__()
        self.context_features = context_features
        self.causal = causal
        mid_features = head_features * num_heads
        kv_in_features = default(context_features, features)

        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(kv_in_features)
        self.to_q = nn.Linear(
            in_features=features, out_features=mid_features, bias=False
        )
        # Keys and values share one projection and are split in `forward`.
        self.to_kv = nn.Linear(
            in_features=kv_in_features, out_features=mid_features * 2, bias=False
        )
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
            out_features=out_features,
        )

    def forward(
        self,
        x: Tensor,
        context: Optional[Tensor] = None,
        context_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False,
    ) -> Tensor:
        """Attend `x` over `context` (or over itself).

        Args:
            x: Query source.
            context: Key/value source; falls back to `x` when None.
            context_mask: Optional ``(b, m)`` mask. It is multiplied into the
                keys and values; it is not passed on as an attention mask.
            causal: Per-call causal flag, OR-ed with the constructor's flag.
        """
        assert_message = "You must provide a context when using context_features"
        assert not self.context_features or exists(context), assert_message

        # Self-attention when no context is supplied
        if not exists(context):
            context = x

        x = self.norm(x)
        context = self.norm_context(context)

        q = self.to_q(x)
        k, v = torch.chunk(self.to_kv(context), chunks=2, dim=-1)

        if exists(context_mask):
            # Scale keys and values by the mask, expanded over the feature axis
            expanded = repeat(context_mask, "b m -> b m d", d=v.shape[-1])
            k = k * expanded
            v = v * expanded

        use_causal = self.causal or causal
        return self.attention(q, k, v, is_causal=use_causal)
|
|
|
|
# NOTE(review): this redefines the identical `FeedForward` declared earlier in
# this file; the later definition is the one bound at import time.
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Build a two-layer MLP: Linear -> GELU -> Linear.

    The hidden width is ``features * multiplier``; input and output widths
    are both `features`.
    """
    hidden_features = features * multiplier
    layers = [
        nn.Linear(in_features=features, out_features=hidden_features),
        nn.GELU(),
        nn.Linear(in_features=hidden_features, out_features=features),
    ]
    return nn.Sequential(*layers)
|
|
| """ |
| Transformer Blocks |
| """ |
|
|
|
|
class TransformerBlock(nn.Module):
    """Self-attention, optional cross-attention, and feed-forward, each residual.

    Cross-attention is created only when `context_features` is given and
    greater than zero.
    """

    def __init__(
        self,
        features: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        context_features: Optional[int] = None,
    ):
        super().__init__()

        self.use_cross_attention = exists(context_features) and context_features > 0

        # Arguments shared by the self- and cross-attention layers.
        attention_args = dict(
            features=features, num_heads=num_heads, head_features=head_features
        )

        self.attention = Attention(**attention_args)

        if self.use_cross_attention:
            self.cross_attention = Attention(
                context_features=context_features, **attention_args
            )

        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor:
        """Apply the block.

        `causal` affects only the self-attention layer; `context` and
        `context_mask` are used only by the cross-attention layer.
        """
        x = x + self.attention(x, causal=causal) if False else self.attention(x, causal=causal) + x
        if self.use_cross_attention:
            attended = self.cross_attention(x, context=context, context_mask=context_mask)
            x = attended + x
        return self.feed_forward(x) + x
|
|
|
|
| """ |
| Transformers |
| """ |
|
|
|
|
class Transformer1d(nn.Module):
    """Stack of `TransformerBlock`s operating on a ``(b, c, t)`` tensor.

    The input is group-normalised (32 groups), passed through a kernel-1
    `Conv1d`, and transposed to ``(b, t, c)`` for the transformer blocks; the
    output is transposed back and passed through another kernel-1 `Conv1d`.
    """

    def __init__(
        self,
        num_layers: int,
        channels: int,
        num_heads: int,
        head_features: int,
        multiplier: int,
        context_features: Optional[int] = None,
    ):
        super().__init__()

        self.to_in = nn.Sequential(
            nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True),
            Conv1d(in_channels=channels, out_channels=channels, kernel_size=1),
            Rearrange("b c t -> b t c"),
        )

        layers = []
        for _ in range(num_layers):
            layers.append(
                TransformerBlock(
                    features=channels,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    context_features=context_features,
                )
            )
        self.blocks = nn.ModuleList(layers)

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            Conv1d(in_channels=channels, out_channels=channels, kernel_size=1),
        )

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor:
        """Run the input projection, every transformer block, then the output projection.

        `context`, `context_mask` and `causal` are forwarded to each block.
        """
        h = self.to_in(x)
        for layer in self.blocks:
            h = layer(h, context=context, context_mask=context_mask, causal=causal)
        return self.to_out(h)
|
|
|
|
| """ |
| Time Embeddings |
| """ |
|
|
|
|
class SinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal embedding of scalar positions.

    Uses ``dim // 2`` frequencies spaced geometrically from 1 down to
    1/10000; the output concatenates the sines and the cosines, so its last
    dimension is ``2 * (dim // 2)``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Embed `x`.

        Args:
            x: Tensor of positions; a new trailing axis is appended for the
                embedding, so a 1-D input ``(n,)`` gives ``(n, 2 * (dim // 2))``.
        """
        device, half_dim = x.device, self.dim // 2
        # With a single frequency (dim of 2 or 3) `half_dim - 1` is zero; clamp
        # the denominator so that case no longer raises ZeroDivisionError.
        step = log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)
        # Broadcast positions against frequencies along a new last axis.
        angles = x.unsqueeze(-1) * freqs
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
|
|
|
class LearnedPositionalEmbedding(nn.Module):
    """Used for continuous time.

    Holds ``dim // 2`` learnable frequencies (initialised from a standard
    normal). The output concatenates the raw input with the sines and cosines
    of ``input * weights * 2 * pi``, so it has ``dim + 1`` features.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        """Embed a 1-D tensor ``(b,)`` into ``(b, dim + 1)``."""
        column = rearrange(x, "b -> b 1")
        phases = column * rearrange(self.weights, "d -> 1 d") * 2 * pi
        waves = torch.cat((phases.sin(), phases.cos()), dim=-1)
        # Keep the raw value alongside its Fourier features.
        return torch.cat((column, waves), dim=-1)
|
|
|
|
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a `LearnedPositionalEmbedding` followed by a linear projection.

    The linear layer takes ``dim + 1`` inputs because the embedding prepends
    the raw value to its `dim` Fourier features.
    """
    embedding = LearnedPositionalEmbedding(dim)
    projection = nn.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(embedding, projection)
|
|
|
|
| """ |
| Encoder/Decoder Components |
| """ |
|
|
|
|
class DownsampleBlock1d(nn.Module):
    """Encoder stage: downsample, residual blocks, optional transformer.

    The downsampling convolution runs before the residual blocks when
    `use_pre_downsample` is True and after them otherwise; the blocks work
    at `out_channels` or `in_channels` accordingly. Only the first residual
    block receives the additional `context_channels` input channels.

    When `num_transformer_blocks > 0`, at least one of `attention_heads` /
    `attention_features` plus `attention_multiplier` must be given; the
    missing one of heads/features is derived from the channel count.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        factor: int,
        num_groups: int,
        num_layers: int,
        kernel_multiplier: int = 2,
        use_pre_downsample: bool = True,
        use_skip: bool = False,
        use_snake: bool = False,
        extract_channels: int = 0,
        context_channels: int = 0,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
    ):
        super().__init__()
        self.use_pre_downsample = use_pre_downsample
        self.use_skip = use_skip
        self.use_transformer = num_transformer_blocks > 0
        self.use_extract = extract_channels > 0
        self.use_context = context_channels > 0

        channels = out_channels if use_pre_downsample else in_channels

        self.downsample = Downsample1d(
            in_channels=in_channels,
            out_channels=out_channels,
            factor=factor,
            kernel_multiplier=kernel_multiplier,
        )

        resnet_blocks = []
        for index in range(num_layers):
            # Only the first block sees the concatenated context channels.
            block_in = channels + context_channels if index == 0 else channels
            resnet_blocks.append(
                ResnetBlock1d(
                    in_channels=block_in,
                    out_channels=channels,
                    num_groups=num_groups,
                    context_mapping_features=context_mapping_features,
                    use_snake=use_snake,
                )
            )
        self.blocks = nn.ModuleList(resnet_blocks)

        if self.use_transformer:
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )

            # Derive whichever of heads / per-head features was left out.
            if not exists(attention_features) and exists(attention_heads):
                attention_features = channels // attention_heads

            if not exists(attention_heads) and exists(attention_features):
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        if self.use_extract:
            self.to_extracted = ResnetBlock1d(
                in_channels=out_channels,
                out_channels=extract_channels,
                num_groups=min(num_groups, extract_channels),
                use_snake=use_snake,
            )

    def forward(
        self,
        x: Tensor,
        *,
        mapping: Optional[Tensor] = None,
        channels: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]:
        """Run the stage.

        Args:
            x: Input tensor.
            mapping: Conditioning vector forwarded to every residual block.
            channels: Extra channels concatenated to `x` on dim 1 when the
                block was built with `context_channels > 0`.
            embedding, embedding_mask: Forwarded to the transformer as its
                context and context mask.
            causal: Forwarded to the residual blocks and the transformer.

        Returns:
            ``(x, extracted)`` when `extract_channels > 0` (skips are not
            returned in that case); otherwise ``(x, skips)`` when `use_skip`
            is set, else just ``x``.
        """
        if self.use_pre_downsample:
            x = self.downsample(x)

        if self.use_context and exists(channels):
            x = torch.cat([x, channels], dim=1)

        skips = []
        for block in self.blocks:
            x = block(x, mapping=mapping, causal=causal)
            if self.use_skip:
                skips.append(x)

        if self.use_transformer:
            x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)
            if self.use_skip:
                skips.append(x)

        if not self.use_pre_downsample:
            x = self.downsample(x)

        if self.use_extract:
            return x, self.to_extracted(x)

        if self.use_skip:
            return x, skips
        return x
|
|
|
|
class UpsampleBlock1d(nn.Module):
    """Decoder stage: skip-fed residual blocks, optional transformer, upsample.

    The upsampling module runs before the residual blocks when
    `use_pre_upsample` is True and after them otherwise; the blocks work at
    `out_channels` or `in_channels` accordingly. Every residual block expects
    `skip_channels` extra input channels.

    When `num_transformer_blocks > 0`, at least one of `attention_heads` /
    `attention_features` plus `attention_multiplier` must be given; the
    missing one of heads/features is derived from the channel count.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        *,
        factor: int,
        num_layers: int,
        num_groups: int,
        use_nearest: bool = False,
        use_pre_upsample: bool = False,
        use_skip: bool = False,
        use_snake: bool = False,
        skip_channels: int = 0,
        use_skip_scale: bool = False,
        extract_channels: int = 0,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
    ):
        super().__init__()

        self.use_extract = extract_channels > 0
        self.use_pre_upsample = use_pre_upsample
        self.use_transformer = num_transformer_blocks > 0
        self.use_skip = use_skip
        # Skips are multiplied by 1/sqrt(2) when scaling is requested.
        self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0

        channels = out_channels if use_pre_upsample else in_channels

        resnet_blocks = []
        for _ in range(num_layers):
            resnet_blocks.append(
                ResnetBlock1d(
                    in_channels=channels + skip_channels,
                    out_channels=channels,
                    num_groups=num_groups,
                    context_mapping_features=context_mapping_features,
                    use_snake=use_snake,
                )
            )
        self.blocks = nn.ModuleList(resnet_blocks)

        if self.use_transformer:
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )

            # Derive whichever of heads / per-head features was left out.
            if not exists(attention_features) and exists(attention_heads):
                attention_features = channels // attention_heads

            if not exists(attention_heads) and exists(attention_features):
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        self.upsample = Upsample1d(
            in_channels=in_channels,
            out_channels=out_channels,
            factor=factor,
            use_nearest=use_nearest,
        )

        if self.use_extract:
            self.to_extracted = ResnetBlock1d(
                in_channels=out_channels,
                out_channels=extract_channels,
                num_groups=min(num_groups, extract_channels),
                use_snake=use_snake,
            )

    def add_skip(self, x: Tensor, skip: Tensor) -> Tensor:
        """Concatenate the scaled `skip` to `x` along dim 1."""
        scaled_skip = skip * self.skip_scale
        return torch.cat([x, scaled_skip], dim=1)

    def forward(
        self,
        x: Tensor,
        *,
        skips: Optional[List[Tensor]] = None,
        mapping: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Union[Tuple[Tensor, Tensor], Tensor]:
        """Run the stage.

        Args:
            x: Input tensor.
            skips: Optional list of skip tensors. One is popped from the end
                (mutating the caller's list) before each residual block.
            mapping: Conditioning vector forwarded to every residual block.
            embedding, embedding_mask: Forwarded to the transformer as its
                context and context mask.
            causal: Forwarded to the residual blocks and the transformer.

        Returns:
            ``(x, extracted)`` when `extract_channels > 0`, otherwise ``x``.
        """
        if self.use_pre_upsample:
            x = self.upsample(x)

        for block in self.blocks:
            if exists(skips):
                x = self.add_skip(x, skip=skips.pop())
            x = block(x, mapping=mapping, causal=causal)

        if self.use_transformer:
            x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal)

        if not self.use_pre_upsample:
            x = self.upsample(x)

        if not self.use_extract:
            return x

        return x, self.to_extracted(x)
|
|
|
|
class BottleneckBlock1d(nn.Module):
    """Bottleneck stage: residual block, optional transformer, residual block.

    The channel count is the same throughout. When
    `num_transformer_blocks > 0`, at least one of `attention_heads` /
    `attention_features` plus `attention_multiplier` must be given; the
    missing one of heads/features is derived from `channels`.
    """

    def __init__(
        self,
        channels: int,
        *,
        num_groups: int,
        num_transformer_blocks: int = 0,
        attention_heads: Optional[int] = None,
        attention_features: Optional[int] = None,
        attention_multiplier: Optional[int] = None,
        context_mapping_features: Optional[int] = None,
        context_embedding_features: Optional[int] = None,
        use_snake: bool = False,
    ):
        super().__init__()
        self.use_transformer = num_transformer_blocks > 0

        # Both residual blocks are configured identically.
        resnet_args = dict(
            in_channels=channels,
            out_channels=channels,
            num_groups=num_groups,
            context_mapping_features=context_mapping_features,
            use_snake=use_snake,
        )

        self.pre_block = ResnetBlock1d(**resnet_args)

        if self.use_transformer:
            assert (
                (exists(attention_heads) or exists(attention_features))
                and exists(attention_multiplier)
            )

            # Derive whichever of heads / per-head features was left out.
            if not exists(attention_features) and exists(attention_heads):
                attention_features = channels // attention_heads

            if not exists(attention_heads) and exists(attention_features):
                attention_heads = channels // attention_features

            self.transformer = Transformer1d(
                num_layers=num_transformer_blocks,
                channels=channels,
                num_heads=attention_heads,
                head_features=attention_features,
                multiplier=attention_multiplier,
                context_features=context_embedding_features,
            )

        self.post_block = ResnetBlock1d(**resnet_args)

    def forward(
        self,
        x: Tensor,
        *,
        mapping: Optional[Tensor] = None,
        embedding: Optional[Tensor] = None,
        embedding_mask: Optional[Tensor] = None,
        causal: Optional[bool] = False
    ) -> Tensor:
        """Run the stage.

        `mapping` conditions both residual blocks; `embedding` and
        `embedding_mask` are forwarded to the transformer as its context and
        context mask; `causal` is forwarded to all three sub-modules.
        """
        h = self.pre_block(x, mapping=mapping, causal=causal)
        if self.use_transformer:
            h = self.transformer(h, context=embedding, context_mask=embedding_mask, causal=causal)
        return self.post_block(h, mapping=mapping, causal=causal)
|
|
|
|
| """ |
| UNet |
| """ |
|
|
|
|
| class UNet1d(nn.Module): |
| def __init__( |
| self, |
| in_channels: int, |
| channels: int, |
| multipliers: Sequence[int], |
| factors: Sequence[int], |
| num_blocks: Sequence[int], |
| attentions: Sequence[int], |
| patch_size: int = 1, |
| resnet_groups: int = 8, |
| use_context_time: bool = True, |
| kernel_multiplier_downsample: int = 2, |
| use_nearest_upsample: bool = False, |
| use_skip_scale: bool = True, |
| use_snake: bool = False, |
| use_stft: bool = False, |
| use_stft_context: bool = False, |
| out_channels: Optional[int] = None, |
| context_features: Optional[int] = None, |
| context_features_multiplier: int = 4, |
| context_channels: Optional[Sequence[int]] = None, |
| context_embedding_features: Optional[int] = None, |
| **kwargs, |
| ): |
| super().__init__() |
| out_channels = default(out_channels, in_channels) |
| context_channels = list(default(context_channels, [])) |
| num_layers = len(multipliers) - 1 |
| use_context_features = exists(context_features) |
| use_context_channels = len(context_channels) > 0 |
| context_mapping_features = None |
|
|
| attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) |
|
|
| self.num_layers = num_layers |
| self.use_context_time = use_context_time |
| self.use_context_features = use_context_features |
| self.use_context_channels = use_context_channels |
| self.use_stft = use_stft |
| self.use_stft_context = use_stft_context |
|
|
| self.context_features = context_features |
| context_channels_pad_length = num_layers + 1 - len(context_channels) |
| context_channels = context_channels + [0] * context_channels_pad_length |
| self.context_channels = context_channels |
| self.context_embedding_features = context_embedding_features |
|
|
| if use_context_channels: |
| has_context = [c > 0 for c in context_channels] |
| self.has_context = has_context |
| self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] |
|
|
| assert ( |
| len(factors) == num_layers |
| and len(attentions) >= num_layers |
| and len(num_blocks) == num_layers |
| ) |
|
|
| if use_context_time or use_context_features: |
| context_mapping_features = channels * context_features_multiplier |
|
|
| self.to_mapping = nn.Sequential( |
| nn.Linear(context_mapping_features, context_mapping_features), |
| nn.GELU(), |
| nn.Linear(context_mapping_features, context_mapping_features), |
| nn.GELU(), |
| ) |
|
|
| if use_context_time: |
| assert exists(context_mapping_features) |
| self.to_time = nn.Sequential( |
| TimePositionalEmbedding( |
| dim=channels, out_features=context_mapping_features |
| ), |
| nn.GELU(), |
| ) |
|
|
| if use_context_features: |
| assert exists(context_features) and exists(context_mapping_features) |
| self.to_features = nn.Sequential( |
| nn.Linear( |
| in_features=context_features, out_features=context_mapping_features |
| ), |
| nn.GELU(), |
| ) |
|
|
| if use_stft: |
| stft_kwargs, kwargs = groupby("stft_", kwargs) |
| assert "num_fft" in stft_kwargs, "stft_num_fft required if use_stft=True" |
| stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 |
| in_channels *= stft_channels |
| out_channels *= stft_channels |
| context_channels[0] *= stft_channels if use_stft_context else 1 |
| assert exists(in_channels) and exists(out_channels) |
| self.stft = STFT(**stft_kwargs) |
|
|
| assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" |
|
|
| self.to_in = Patcher( |
| in_channels=in_channels + context_channels[0], |
| out_channels=channels * multipliers[0], |
| patch_size=patch_size, |
| context_mapping_features=context_mapping_features, |
| use_snake=use_snake |
| ) |
|
|
| self.downsamples = nn.ModuleList( |
| [ |
| DownsampleBlock1d( |
| in_channels=channels * multipliers[i], |
| out_channels=channels * multipliers[i + 1], |
| context_mapping_features=context_mapping_features, |
| context_channels=context_channels[i + 1], |
| context_embedding_features=context_embedding_features, |
| num_layers=num_blocks[i], |
| factor=factors[i], |
| kernel_multiplier=kernel_multiplier_downsample, |
| num_groups=resnet_groups, |
| use_pre_downsample=True, |
| use_skip=True, |
| use_snake=use_snake, |
| num_transformer_blocks=attentions[i], |
| **attention_kwargs, |
| ) |
| for i in range(num_layers) |
| ] |
| ) |
|
|
| self.bottleneck = BottleneckBlock1d( |
| channels=channels * multipliers[-1], |
| context_mapping_features=context_mapping_features, |
| context_embedding_features=context_embedding_features, |
| num_groups=resnet_groups, |
| num_transformer_blocks=attentions[-1], |
| use_snake=use_snake, |
| **attention_kwargs, |
| ) |
|
|
| self.upsamples = nn.ModuleList( |
| [ |
| UpsampleBlock1d( |
| in_channels=channels * multipliers[i + 1], |
| out_channels=channels * multipliers[i], |
| context_mapping_features=context_mapping_features, |
| context_embedding_features=context_embedding_features, |
| num_layers=num_blocks[i] + (1 if attentions[i] else 0), |
| factor=factors[i], |
| use_nearest=use_nearest_upsample, |
| num_groups=resnet_groups, |
| use_skip_scale=use_skip_scale, |
| use_pre_upsample=False, |
| use_skip=True, |
| use_snake=use_snake, |
| skip_channels=channels * multipliers[i + 1], |
| num_transformer_blocks=attentions[i], |
| **attention_kwargs, |
| ) |
| for i in reversed(range(num_layers)) |
| ] |
| ) |
|
|
| self.to_out = Unpatcher( |
| in_channels=channels * multipliers[0], |
| out_channels=out_channels, |
| patch_size=patch_size, |
| context_mapping_features=context_mapping_features, |
| use_snake=use_snake |
| ) |
|
|
def get_channels(
    self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0
) -> Optional[Tensor]:
    """Gets context channels at `layer` and checks that shape is correct.

    Args:
        channels_list: Context tensors; entries are indexed through
            `self.channels_ids`, so layers without context channels do not
            occupy a slot.
        layer: Index of the layer the context is requested for.

    Returns:
        The context tensor for `layer` (passed through `self.stft.encode1d`
        when `self.use_stft_context` is set), or None when the layer takes
        no context channels.

    Raises:
        AssertionError: If the context is missing or its channel count
            (dimension 1) differs from `self.context_channels[layer]`.
    """
    # `has_context` is only defined when context channels are enabled, so the
    # short-circuit on `use_context_channels` must come first.
    if not (self.use_context_channels and self.has_context[layer]):
        return None
    assert channels_list is not None, "Missing context"

    # Position of this layer's context inside the (compacted) list.
    channels_id = self.channels_ids[layer]
    channels = channels_list[channels_id]
    message = f"Missing context for layer {layer} at index {channels_id}"
    assert channels is not None, message

    # Encode BEFORE validating: the constructor multiplies
    # `context_channels[0]` by the STFT channel factor, so the expected count
    # describes the encoded tensor, not the raw one. Checking the raw tensor
    # (as was done previously) can never succeed for an STFT context.
    # NOTE(review): only index 0 is scaled in the constructor, so STFT context
    # on deeper layers presumably still mismatches - confirm intended usage.
    if self.use_stft_context:
        channels = self.stft.encode1d(channels)

    num_channels = self.context_channels[layer]
    message = f"Expected context with {num_channels} channels at idx {channels_id}"
    assert channels.shape[1] == num_channels, message
    return channels
|
|
def get_mapping(
    self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
) -> Optional[Tensor]:
    """Combines context time features and features into mapping.

    Each enabled source (`use_context_time`, `use_context_features`) is
    projected by its own module, the projections are summed, and the sum is
    passed through `self.to_mapping`. Returns None when neither source is
    enabled.
    """
    parts = []

    if self.use_context_time:
        assert exists(time), "use_context_time=True but no time features provided"
        parts.append(self.to_time(time))

    if self.use_context_features:
        assert exists(features), "context_features exists but no features provided"
        parts.append(self.to_features(features))

    # Each enabled flag contributes exactly one entry, so an empty list means
    # no conditioning source is active.
    if not parts:
        return None

    summed = reduce(torch.stack(parts), "n b m -> b m", "sum")
    return self.to_mapping(summed)
|
|
def forward(
    self,
    x: Tensor,
    time: Optional[Tensor] = None,
    *,
    features: Optional[Tensor] = None,
    channels_list: Optional[Sequence[Tensor]] = None,
    embedding: Optional[Tensor] = None,
    embedding_mask: Optional[Tensor] = None,
    causal: Optional[bool] = False,
) -> Tensor:
    """Runs the UNet: input patcher, downsamples, bottleneck, upsamples, output.

    `time` and `features` are combined into a mapping via `get_mapping` and
    handed to every stage; `channels_list` supplies optional per-layer context
    via `get_channels`; `embedding` / `embedding_mask` / `causal` are forwarded
    to the down, bottleneck and up blocks.
    """
    # Layer-0 context is concatenated to the (optionally STFT-encoded) input.
    context = self.get_channels(channels_list, layer=0)
    if self.use_stft:
        x = self.stft.encode1d(x)
    if exists(context):
        x = torch.cat([x, context], dim=1)

    mapping = self.get_mapping(time, features)
    x = self.to_in(x, mapping, causal=causal)
    skips_stack = [x]

    # Downsample path: context for block i lives at layer i + 1.
    for layer, downsample in enumerate(self.downsamples, start=1):
        context = self.get_channels(channels_list, layer=layer)
        x, skips = downsample(
            x,
            mapping=mapping,
            channels=context,
            embedding=embedding,
            embedding_mask=embedding_mask,
            causal=causal,
        )
        skips_stack.append(skips)

    x = self.bottleneck(
        x,
        mapping=mapping,
        embedding=embedding,
        embedding_mask=embedding_mask,
        causal=causal,
    )

    # Upsample path consumes the skips in reverse order of creation.
    for upsample in self.upsamples:
        x = upsample(
            x,
            skips=skips_stack.pop(),
            mapping=mapping,
            embedding=embedding,
            embedding_mask=embedding_mask,
            causal=causal,
        )

    # The remaining entry is the output of `to_in` (residual around the UNet).
    x += skips_stack.pop()
    x = self.to_out(x, mapping, causal=causal)
    if self.use_stft:
        x = self.stft.decode1d(x)
    return x
|
|
|
|
| """ Conditioning Modules """ |
|
|
|
|
class FixedEmbedding(nn.Module):
    """Learned embedding indexed by position, repeated for every batch item."""

    def __init__(self, max_length: int, features: int):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(max_length, features)

    def forward(self, x: Tensor) -> Tensor:
        """Returns the embeddings of positions `0..x.shape[1]-1` for each batch item.

        Only the first two dimensions and the device of `x` are used; its
        values are ignored.
        """
        batch_size, length = x.shape[0:2]
        assert length <= self.max_length, "Input sequence length must be <= max_length"
        positions = torch.arange(length, device=x.device)
        per_position = self.embedding(positions)
        return repeat(per_position, "n d -> b n d", b=batch_size)
|
|
|
|
def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
    """Returns a boolean tensor of `shape` whose entries are True with probability `proba`.

    The degenerate probabilities 1 and 0 return all-True / all-False tensors
    directly, without drawing random numbers.
    """
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    probabilities = torch.full(shape, proba, device=device)
    return torch.bernoulli(probabilities).to(torch.bool)
|
|
|
|
class UNetCFG1d(UNet1d):
    """UNet1d with Classifier-Free Guidance"""

    def __init__(
        self,
        context_embedding_max_length: int,
        context_embedding_features: int,
        use_xattn_time: bool = False,
        **kwargs,
    ):
        """Builds the UNet plus the learned "fixed" (unconditional) embedding.

        Args:
            context_embedding_max_length: Maximum sequence length of the
                conditioning embedding.
            context_embedding_features: Feature size of the conditioning
                embedding.
            use_xattn_time: If True, a time embedding is appended to the
                conditioning sequence as one extra token.
            **kwargs: Forwarded to `UNet1d`; `channels` is read from here when
                `use_xattn_time` is set.
        """
        super().__init__(
            context_embedding_features=context_embedding_features, **kwargs
        )

        self.use_xattn_time = use_xattn_time

        if use_xattn_time:
            assert exists(context_embedding_features)
            self.to_time_embedding = nn.Sequential(
                TimePositionalEmbedding(
                    dim=kwargs["channels"], out_features=context_embedding_features
                ),
                nn.GELU(),
            )

            # Reserve one extra position for the appended time token.
            context_embedding_max_length += 1

        self.fixed_embedding = FixedEmbedding(
            max_length=context_embedding_max_length, features=context_embedding_features
        )

    def forward(  # type: ignore
        self,
        x: Tensor,
        time: Tensor,
        *,
        embedding: Tensor,
        embedding_mask: Optional[Tensor] = None,
        embedding_scale: float = 1.0,
        embedding_mask_proba: float = 0.0,
        batch_cfg: bool = False,
        rescale_cfg: bool = False,
        scale_phi: float = 0.4,
        negative_embedding: Optional[Tensor] = None,
        negative_embedding_mask: Optional[Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """Runs the UNet with optional classifier-free guidance.

        Args:
            x: Input passed to `UNet1d.forward`.
            time: Time conditioning passed to `UNet1d.forward`.
            embedding: Conditioning embedding sequence.
            embedding_mask: Optional mask accompanying `embedding`.
            embedding_scale: Guidance scale. When different from 1.0 the
                result is `out_masked + (out - out_masked) * embedding_scale`.
            embedding_mask_proba: Per-batch-item probability of replacing
                `embedding` with the learned fixed embedding.
            batch_cfg: If True, the guided and unguided passes are evaluated
                as one doubled batch instead of two separate calls.
            rescale_cfg: If True, blends the guided output with a version
                rescaled to the standard deviation (over dim 1) of the
                conditional output.
            scale_phi: Blend weight used by `rescale_cfg`.
            negative_embedding: Optional embedding used instead of the fixed
                embedding for the unguided half; only consulted when
                `batch_cfg` is True.
            negative_embedding_mask: Optional mask; where it is False the
                negative embedding falls back to the fixed embedding.
            **kwargs: Forwarded to `UNet1d.forward`.
        """
        b, device = embedding.shape[0], embedding.device

        if self.use_xattn_time:
            # Append the time embedding as an additional conditioning token.
            embedding = torch.cat(
                [embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1
            )

            if embedding_mask is not None:
                # Create the extra (always-visible) column in the mask's own
                # dtype. A default float `ones` would not match a boolean mask,
                # and the concatenation would no longer produce a boolean mask.
                time_token_mask = torch.ones(
                    (b, 1), device=device, dtype=embedding_mask.dtype
                )
                embedding_mask = torch.cat([embedding_mask, time_token_mask], dim=1)

        fixed_embedding = self.fixed_embedding(embedding)

        if embedding_mask_proba > 0.0:
            # Randomly mask embedding
            batch_mask = rand_bool(
                shape=(b, 1, 1), proba=embedding_mask_proba, device=device
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        if embedding_scale != 1.0:
            if batch_cfg:
                batch_x = torch.cat([x, x], dim=0)
                batch_time = torch.cat([time, time], dim=0)

                if negative_embedding is not None:
                    # NOTE(review): negative_embedding is not extended with the
                    # time token, so with use_xattn_time its sequence length
                    # presumably differs from `embedding` - confirm with callers.
                    if negative_embedding_mask is not None:
                        negative_embedding_mask = negative_embedding_mask.to(
                            torch.bool
                        ).unsqueeze(2)

                        negative_embedding = torch.where(
                            negative_embedding_mask, negative_embedding, fixed_embedding
                        )

                    batch_embed = torch.cat([embedding, negative_embedding], dim=0)

                else:
                    batch_embed = torch.cat([embedding, fixed_embedding], dim=0)

                batch_mask = None
                if embedding_mask is not None:
                    batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0)

                # `features` / `channels_list` are removed from kwargs because
                # their doubled versions are passed explicitly below.
                batch_features = None
                features = kwargs.pop("features", None)
                if self.use_context_features:
                    batch_features = torch.cat([features, features], dim=0)

                batch_channels = None
                channels_list = kwargs.pop("channels_list", None)
                if self.use_context_channels:
                    batch_channels = [
                        torch.cat([channels, channels], dim=0)
                        for channels in channels_list
                    ]

                # Compute both normal and fixed embedding outputs
                batch_out = super().forward(
                    batch_x,
                    batch_time,
                    embedding=batch_embed,
                    embedding_mask=batch_mask,
                    features=batch_features,
                    channels_list=batch_channels,
                    **kwargs,
                )
                out, out_masked = batch_out.chunk(2, dim=0)

            else:
                # Compute both normal and fixed embedding outputs
                out = super().forward(
                    x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs
                )
                out_masked = super().forward(
                    x,
                    time,
                    embedding=fixed_embedding,
                    embedding_mask=embedding_mask,
                    **kwargs,
                )

            out_cfg = out_masked + (out - out_masked) * embedding_scale

            if rescale_cfg:
                out_std = out.std(dim=1, keepdim=True)
                out_cfg_std = out_cfg.std(dim=1, keepdim=True)

                return (
                    scale_phi * (out_cfg * (out_std / out_cfg_std))
                    + (1 - scale_phi) * out_cfg
                )

            else:
                return out_cfg

        else:
            return super().forward(
                x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs
            )
|
|
|
|
class UNetNCCA1d(UNet1d):
    """UNet1d with Noise Channel Conditioning Augmentation"""

    def __init__(self, context_features: int, **kwargs):
        """Builds the UNet and the embedder used for the augmentation scales."""
        super().__init__(context_features=context_features, **kwargs)
        self.embedder = NumberEmbedder(features=context_features)

    def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
        """Converts `x` to a tensor (if needed) and expands it to `shape`."""
        x = x if torch.is_tensor(x) else torch.tensor(x)
        return x.expand(shape)

    def forward(  # type: ignore
        self,
        x: Tensor,
        time: Tensor,
        *,
        channels_list: Sequence[Tensor],
        channels_augmentation: Union[
            bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
        ] = False,
        channels_scale: Union[
            float, Sequence[float], Sequence[Sequence[float]], Tensor
        ] = 0,
        **kwargs,
    ) -> Tensor:
        """Noises the context channels and conditions the UNet on the noise scale.

        Args:
            x: Input passed to `UNet1d.forward`.
            time: Time conditioning passed to `UNet1d.forward`.
            channels_list: Context tensors; each is blended with Gaussian noise.
                The sequence itself is left untouched.
            channels_augmentation: Whether to augment, expandable to
                `(batch, len(channels_list))`.
            channels_scale: Noise scale, expandable to
                `(batch, len(channels_list))`.
            **kwargs: Forwarded to `UNet1d.forward`.
        """
        b, n = x.shape[0], len(channels_list)
        channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
        channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)

        # Augmentation (for each channel list item). Results are collected in a
        # new list: assigning into `channels_list` would mutate the caller's
        # list and fail outright for immutable sequences such as tuples.
        augmented_channels = []
        for i, item in enumerate(channels_list):
            scale = channels_scale[:, i] * channels_augmentation[:, i]
            scale = rearrange(scale, "b -> b 1 1")
            augmented_channels.append(
                torch.randn_like(item) * scale + item * (1 - scale)
            )

        # Scale embedding (sum across channel list items)
        channels_scale_emb = self.embedder(channels_scale)
        channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")

        return super().forward(
            x=x,
            time=time,
            channels_list=augmented_channels,
            features=channels_scale_emb,
            **kwargs,
        )
|
|
|
|
class UNetAll1d(UNetCFG1d, UNetNCCA1d):
    """UNet1d combining the `UNetCFG1d` and `UNetNCCA1d` variants."""

    def __init__(self, *args, **kwargs):
        # Cooperative initialisation along the MRO (UNetCFG1d, then UNetNCCA1d).
        super().__init__(*args, **kwargs)

    def forward(self, *args, **kwargs):
        # Enter through the CFG forward explicitly; its own `super().forward`
        # then resolves to the next class in this class's MRO.
        cfg_forward = UNetCFG1d.forward
        return cfg_forward(self, *args, **kwargs)
|
|
|
|
def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
    """Builds the UNet variant named by `type`, forwarding `kwargs` to its constructor.

    Supported values are "base", "all", "cfg" and "ncca"; anything else raises
    ValueError.
    """
    if type == "base":
        return UNet1d(**kwargs)
    if type == "all":
        return UNetAll1d(**kwargs)
    if type == "cfg":
        return UNetCFG1d(**kwargs)
    if type == "ncca":
        return UNetNCCA1d(**kwargs)
    raise ValueError(f"Unknown XUNet1d type: {type}")
|
|
class NumberEmbedder(nn.Module):
    """Embeds every number of the input into a vector of size `features`."""

    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        """Returns embeddings with the shape of `x` plus a trailing `features` axis."""
        if not torch.is_tensor(x):
            # Place plain Python numbers on the device of the embedding weights.
            param_device = next(self.embedding.parameters()).device
            x = torch.tensor(x, device=param_device)
        assert isinstance(x, Tensor)
        original_shape = x.shape
        # Flatten, embed each scalar, then restore the leading shape.
        flat = rearrange(x, "... -> (...)")
        embedded = self.embedding(flat)
        return embedded.view(*original_shape, self.features)
|
|
|
|
| """ |
| Audio Transforms |
| """ |
|
|
|
|
class STFT(nn.Module):
    """Helper for torch stft and istft"""

    def __init__(
        self,
        num_fft: int = 1023,
        hop_length: Optional[int] = 256,
        window_length: Optional[int] = None,
        length: Optional[int] = None,
        use_complex: bool = False,
    ):
        """
        Args:
            num_fft: FFT size passed to `torch.stft` / `torch.istft`.
            hop_length: Hop between frames; `num_fft // 4` when None.
            window_length: Hann window length; `num_fft` when None.
            length: Output length used by `decode`; when None the closest
                power of two to `frames * hop_length` is used.
            use_complex: If True the two encoded tensors are (real, imag),
                otherwise (magnitude, phase).
        """
        super().__init__()
        self.num_fft = num_fft
        self.hop_length = hop_length if hop_length is not None else num_fft // 4
        self.window_length = window_length if window_length is not None else num_fft
        self.length = length
        self.register_buffer("window", torch.hann_window(self.window_length))
        self.use_complex = use_complex

    def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
        """Encodes `wave` of shape (b, c, t) into two tensors of shape (b, c, f, l)."""
        b, c, t = wave.shape
        # Fold channels into the batch: (b, c, t) -> (b * c, t)
        wave = wave.reshape(b * c, t)

        stft = torch.stft(
            wave,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            return_complex=True,
            normalized=True,
        )

        if self.use_complex:
            # Returns real and imaginary
            stft_a, stft_b = stft.real, stft.imag
        else:
            # Returns magnitude and phase matrices
            stft_a, stft_b = torch.abs(stft), torch.angle(stft)

        # Unfold channels again: (b * c, f, l) -> (b, c, f, l). A real tuple is
        # returned so the result matches the annotation and can be reused.
        f, l = stft.shape[-2], stft.shape[-1]
        return stft_a.reshape(b, c, f, l), stft_b.reshape(b, c, f, l)

    def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor:
        """Inverts `encode`: two (b, c, f, l) tensors -> waveform (b, c, t)."""
        b, c, f, l = stft_a.shape
        length = (
            self.length
            if self.length is not None
            else closest_power_2(l * self.hop_length)
        )

        # Fold channels into the batch: (b, c, f, l) -> (b * c, f, l)
        stft_a = stft_a.reshape(b * c, f, l)
        stft_b = stft_b.reshape(b * c, f, l)

        if self.use_complex:
            real, imag = stft_a, stft_b
        else:
            magnitude, phase = stft_a, stft_b
            real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)

        # torch.istft requires a complex tensor; a real tensor with a trailing
        # (real, imag) axis is rejected by current PyTorch releases.
        stft = torch.complex(real, imag)

        wave = torch.istft(
            stft,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            length=length,
            normalized=True,
        )

        # Unfold channels: (b * c, t) -> (b, c, t)
        return wave.reshape(b, c, wave.shape[-1])

    def encode1d(
        self, wave: Tensor, stacked: bool = True
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Encodes `wave` and merges channels with frequencies: (b, c * f, l).

        With `stacked=True` both tensors are concatenated along dim 1,
        otherwise they are returned as a pair.
        """
        stft_a, stft_b = self.encode(wave)
        # (b, c, f, l) -> (b, c * f, l)
        stft_a, stft_b = stft_a.flatten(1, 2), stft_b.flatten(1, 2)
        return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b)

    def decode1d(self, stft_pair: Tensor) -> Tensor:
        """Inverts `encode1d(..., stacked=True)` back to a waveform (b, c, t)."""
        f = self.num_fft // 2 + 1
        stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1)
        # (b, c * f, l) -> (b, c, f, l); fails if dim 1 is not a multiple of f
        b, cf, l = stft_a.shape
        stft_a = stft_a.reshape(b, cf // f, f, l)
        stft_b = stft_b.reshape(b, cf // f, f, l)
        return self.decode(stft_a, stft_b)
|
|