| import torch |
| import torch.nn as nn |
| from einops import pack, rearrange |
|
|
| from .unit_nn import ( |
| SinusoidalPosEmb, |
| TimestepEmbedding, |
| ResnetBlock1D, |
| Block1D, |
| ) |
|
|
| from .transformer import BasicTransformerBlock |
|
|
|
|
| |
| |
| |
|
|
|
|
def _make_resnet(dim_in, dim_out, time_emb_dim, cond_dim):
    """Build a FiLM-conditioned 1D ResNet block mapping dim_in -> dim_out channels."""
    kwargs = {
        "dim": dim_in,
        "dim_out": dim_out,
        "time_emb_dim": time_emb_dim,
        "film_dim": cond_dim,
    }
    return ResnetBlock1D(**kwargs)
|
|
|
|
def _make_transformer(dim, n_heads, head_dim, dropout, act_fn):
    """Build one self-attention transformer block operating on `dim`-wide tokens."""
    kwargs = {
        "dim": dim,
        "num_attention_heads": n_heads,
        "attention_head_dim": head_dim,
        "dropout": dropout,
        "activation_fn": act_fn,
    }
    return BasicTransformerBlock(**kwargs)
|
|
|
|
| |
| |
| |
|
|
|
|
class DownBlock(nn.Module):
    """
    ResNet -> n Transformer blocks -> channel-mixing Conv1d.

    The time axis T is never halved; the trailing 3x1 conv exists only for
    channel interaction, not spatial downsampling.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        time_emb_dim,
        cond_dim,
        n_transformer,
        n_heads,
        head_dim,
        dropout,
        act_fn,
    ):
        super().__init__()
        self.resnet = _make_resnet(dim_in, dim_out, time_emb_dim, cond_dim)
        blocks = [
            _make_transformer(dim_out, n_heads, head_dim, dropout, act_fn)
            for _ in range(n_transformer)
        ]
        self.transformers = nn.ModuleList(blocks)
        self.mix = nn.Conv1d(dim_out, dim_out, 3, padding=1)

    def forward(self, x, t, cond=None):
        # ResNet operates channel-first (b, c, t); transformers token-first (b, t, c).
        h = self.resnet(x, t, film_cond=cond)
        h = rearrange(h, "b c t -> b t c")
        for tf_block in self.transformers:
            h = tf_block(hidden_states=h, attention_mask=None, timestep=t)
        return self.mix(rearrange(h, "b t c -> b c t"))
|
|
|
|
class MidBlock(nn.Module):
    """ResNet -> n Transformer blocks. No spatial change."""

    def __init__(
        self,
        dim,
        time_emb_dim,
        cond_dim,
        n_transformer,
        n_heads,
        head_dim,
        dropout,
        act_fn,
    ):
        super().__init__()
        self.resnet = _make_resnet(dim, dim, time_emb_dim, cond_dim)
        blocks = [
            _make_transformer(dim, n_heads, head_dim, dropout, act_fn)
            for _ in range(n_transformer)
        ]
        self.transformers = nn.ModuleList(blocks)

    def forward(self, x, t, cond=None):
        # Channel-first through the ResNet, token-first through attention.
        h = self.resnet(x, t, film_cond=cond)
        h = rearrange(h, "b c t -> b t c")
        for tf_block in self.transformers:
            h = tf_block(hidden_states=h, attention_mask=None, timestep=t)
        return rearrange(h, "b t c -> b c t")
|
|
|
|
class UpBlock(nn.Module):
    """
    Skip-connection concat -> ResNet -> n Transformer blocks -> Conv1d.

    `dim_in` is the channel width of the skip tensor, not the doubled width;
    the doubling from concatenation is absorbed by ResnetBlock1D(dim=2*dim_in).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        time_emb_dim,
        cond_dim,
        n_transformer,
        n_heads,
        head_dim,
        dropout,
        act_fn,
    ):
        super().__init__()
        self.resnet = _make_resnet(2 * dim_in, dim_out, time_emb_dim, cond_dim)
        blocks = [
            _make_transformer(dim_out, n_heads, head_dim, dropout, act_fn)
            for _ in range(n_transformer)
        ]
        self.transformers = nn.ModuleList(blocks)
        self.mix = nn.Conv1d(dim_out, dim_out, 3, padding=1)

    def forward(self, x, skip, t, cond=None):
        # Concatenate along the channel axis, then run the conditioned ResNet.
        h = torch.cat([x, skip], dim=1)
        h = self.resnet(h, t, film_cond=cond)
        h = rearrange(h, "b c t -> b t c")
        for tf_block in self.transformers:
            h = tf_block(hidden_states=h, attention_mask=None, timestep=t)
        return self.mix(rearrange(h, "b t c -> b c t"))
|
|
|
|
| |
| |
| |
|
|
|
|
class Decoder(nn.Module):
    """
    1D U-Net decoder for Matcha-TTS CFM.

    Time dimension T is held constant throughout (no spatial downsampling).
    The "U-Net" structure provides multi-scale channel mixing and
    skip connections for gradient flow, not resolution pyramids.

    Input channels to the first block are 2 * in_channels because
    mu is channel-concatenated with x before entering the U-Net.

    Args:
        in_channels: feat_dim (mel bins or latent dim)
        out_channels: feat_dim (same; predicts residual vector field)
        channels: channel widths at each down/up level (tuple or list)
        n_mid_blocks: number of mid blocks at the bottleneck
        n_transformer: transformer blocks per down/mid/up stage
        n_heads: attention heads
        head_dim: dimension per head
        dropout: dropout in transformers
        act_fn: activation in transformer FFN (e.g. 'snakebeta', 'snake', 'silu')
        cond_dim: optional semantic conditioning dim (pooled mu etc.)
        in_c: deprecated alias for in_channels
        out_c: deprecated alias for out_channels

    Raises:
        ValueError: if an alias conflicts with its canonical argument, or
            if in_channels/out_channels cannot be resolved at all.
    """

    def __init__(
        self,
        in_channels: int = None,
        out_channels: int = None,
        channels: tuple = (256, 256),
        n_mid_blocks: int = 2,
        n_transformer: int = 1,
        n_heads: int = 4,
        head_dim: int = 64,
        dropout: float = 0.05,
        act_fn: str = "snakebeta",
        cond_dim: int = None,
        in_c: int = None,
        out_c: int = None,
    ):
        super().__init__()

        # Accept lists as well as tuples (YAML/Hydra configs typically deliver
        # lists); the tuple concatenations below would otherwise raise TypeError.
        channels = tuple(channels)

        # Resolve the deprecated aliases, rejecting conflicting values.
        if in_channels is None:
            in_channels = in_c
        elif in_c is not None and in_c != in_channels:
            raise ValueError("Received conflicting values for in_channels and in_c.")

        if out_channels is None:
            out_channels = out_c
        elif out_c is not None and out_c != out_channels:
            raise ValueError("Received conflicting values for out_channels and out_c.")

        if in_channels is None or out_channels is None:
            raise ValueError(
                "Decoder requires in_channels/out_channels (or aliases in_c/out_c)."
            )

        # Sinusoidal timestep features projected to a wider embedding,
        # optionally fused with the semantic condition via cond_proj_dim.
        time_emb_dim = channels[0] * 4
        self.time_emb = SinusoidalPosEmb(channels[0])
        self.time_mlp = TimestepEmbedding(
            in_channels=channels[0],
            time_embed_dim=time_emb_dim,
            act_fn="silu",
            cond_proj_dim=cond_dim,
        )

        # First down block sees x concatenated with mu -> 2 * in_channels.
        dims_in = (2 * in_channels,) + channels[:-1]
        dims_out = channels

        self.down_blocks = nn.ModuleList(
            [
                DownBlock(
                    di,
                    do,
                    time_emb_dim,
                    cond_dim,
                    n_transformer,
                    n_heads,
                    head_dim,
                    dropout,
                    act_fn,
                )
                for di, do in zip(dims_in, dims_out)
            ]
        )

        self.mid_blocks = nn.ModuleList(
            [
                MidBlock(
                    channels[-1],
                    time_emb_dim,
                    cond_dim,
                    n_transformer,
                    n_heads,
                    head_dim,
                    dropout,
                    act_fn,
                )
                for _ in range(n_mid_blocks)
            ]
        )

        # Up path mirrors the down path; the last stage returns to channels[0].
        up_dims_in = channels[::-1]
        up_dims_out = channels[::-1][1:] + (channels[0],)

        self.up_blocks = nn.ModuleList(
            [
                UpBlock(
                    di,
                    do,
                    time_emb_dim,
                    cond_dim,
                    n_transformer,
                    n_heads,
                    head_dim,
                    dropout,
                    act_fn,
                )
                for di, do in zip(up_dims_in, up_dims_out)
            ]
        )

        self.final_block = Block1D(channels[0], channels[0])
        self.final_proj = nn.Conv1d(channels[0], out_channels, 1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init all convs/linears; unit-gain init for group norms."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Predict the vector field at noise level t.

        Args:
            x: noisy input, shape (B, in_channels, T)
            mu: conditioning features, same shape as x
            t: timesteps — shape (B,), (B, 1, 1), or a 0-dim scalar
               (broadcast across the batch)
            cond: optional semantic condition, shape (B, cond_dim)

        Returns:
            Tensor of shape (B, out_channels, T).
        """
        # Accept a 0-dim scalar t (common during ODE sampling) by
        # broadcasting it over the batch before flattening to (B,).
        if t.ndim == 0:
            t = t.expand(x.shape[0])
        t = t.reshape(x.shape[0])
        t = self.time_mlp(self.time_emb(t), condition=cond)

        # Channel-concatenate mu with x -> (B, 2 * in_channels, T).
        x = pack([x, mu], "b * t")[0]

        # Down path; each stage's output is saved as a skip connection.
        hiddens = []
        for block in self.down_blocks:
            x = block(x, t, cond)
            hiddens.append(x)

        for block in self.mid_blocks:
            x = block(x, t, cond)

        # Up path consumes skips in reverse (LIFO) order.
        for block in self.up_blocks:
            x = block(x, hiddens.pop(), t, cond)

        x = self.final_block(x)
        return self.final_proj(x)
|
|