import torch
import torch.nn as nn
from einops import pack, rearrange

from .unit_nn import (
    SinusoidalPosEmb,
    TimestepEmbedding,
    ResnetBlock1D,
    Block1D,
)
from .transformer import BasicTransformerBlock


# ---------------------------------------------------------------------------
# Small helpers to keep block construction readable
# ---------------------------------------------------------------------------

def _make_resnet(dim_in, dim_out, time_emb_dim, cond_dim):
    """Build a FiLM-conditioned 1D ResNet block mapping dim_in -> dim_out channels."""
    return ResnetBlock1D(
        dim=dim_in,
        dim_out=dim_out,
        time_emb_dim=time_emb_dim,
        film_dim=cond_dim,
    )


def _make_transformer(dim, n_heads, head_dim, dropout, act_fn):
    """Build one transformer block that operates on (B, T, dim) hidden states."""
    return BasicTransformerBlock(
        dim=dim,
        num_attention_heads=n_heads,
        attention_head_dim=head_dim,
        dropout=dropout,
        activation_fn=act_fn,
    )


# ---------------------------------------------------------------------------
# Building-block containers
# ---------------------------------------------------------------------------

class DownBlock(nn.Module):
    """
    ResNet -> n Transformer blocks -> channel-mixing Conv1d.

    T is never halved; the conv is purely for channel interaction
    (kernel 3, padding 1, stride 1 preserves the time dimension).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        time_emb_dim,
        cond_dim,
        n_transformer,
        n_heads,
        head_dim,
        dropout,
        act_fn,
    ):
        super().__init__()
        self.resnet = _make_resnet(dim_in, dim_out, time_emb_dim, cond_dim)
        self.transformers = nn.ModuleList(
            [
                _make_transformer(dim_out, n_heads, head_dim, dropout, act_fn)
                for _ in range(n_transformer)
            ]
        )
        self.mix = nn.Conv1d(dim_out, dim_out, 3, padding=1)

    def forward(self, x, t, cond=None):
        """
        Args:
            x: (B, dim_in, T) features.
            t: time embedding, forwarded to the ResNet and to each
               transformer block as ``timestep``.
            cond: optional FiLM conditioning for the ResNet.

        Returns:
            (B, dim_out, T) features.
        """
        x = self.resnet(x, t, film_cond=cond)
        # Transformers expect channels-last: (B, T, C).
        x = rearrange(x, "b c t -> b t c")
        for block in self.transformers:
            x = block(hidden_states=x, attention_mask=None, timestep=t)
        x = rearrange(x, "b t c -> b c t")
        return self.mix(x)


class MidBlock(nn.Module):
    """ResNet -> n Transformer blocks. No spatial change, no channel change."""

    def __init__(
        self,
        dim,
        time_emb_dim,
        cond_dim,
        n_transformer,
        n_heads,
        head_dim,
        dropout,
        act_fn,
    ):
        super().__init__()
        self.resnet = _make_resnet(dim, dim, time_emb_dim, cond_dim)
        self.transformers = nn.ModuleList(
            [
                _make_transformer(dim, n_heads, head_dim, dropout, act_fn)
                for _ in range(n_transformer)
            ]
        )

    def forward(self, x, t, cond=None):
        """(B, dim, T) -> (B, dim, T); t/cond as in DownBlock."""
        x = self.resnet(x, t, film_cond=cond)
        x = rearrange(x, "b c t -> b t c")
        for block in self.transformers:
            x = block(hidden_states=x, attention_mask=None, timestep=t)
        return rearrange(x, "b t c -> b c t")


class UpBlock(nn.Module):
    """
    Skip-connection concat -> ResNet -> n Transformer blocks -> Conv1d.

    ``dim_in`` is the channel dim of the skip, not the doubled dim; the
    doubling is handled internally via ResnetBlock1D(dim=2*dim_in).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        time_emb_dim,
        cond_dim,
        n_transformer,
        n_heads,
        head_dim,
        dropout,
        act_fn,
    ):
        super().__init__()
        # The skip is concatenated along channels, hence 2 * dim_in going in.
        self.resnet = _make_resnet(2 * dim_in, dim_out, time_emb_dim, cond_dim)
        self.transformers = nn.ModuleList(
            [
                _make_transformer(dim_out, n_heads, head_dim, dropout, act_fn)
                for _ in range(n_transformer)
            ]
        )
        self.mix = nn.Conv1d(dim_out, dim_out, 3, padding=1)

    def forward(self, x, skip, t, cond=None):
        """
        Args:
            x: (B, dim_in, T) upsampling-path features.
            skip: (B, dim_in, T) saved features from the matching down block.
            t: time embedding; cond: optional FiLM conditioning.

        Returns:
            (B, dim_out, T) features.
        """
        # pack concatenates x and skip along the channel ("*") axis.
        x = self.resnet(pack([x, skip], "b * t")[0], t, film_cond=cond)
        x = rearrange(x, "b c t -> b t c")
        for block in self.transformers:
            x = block(hidden_states=x, attention_mask=None, timestep=t)
        x = rearrange(x, "b t c -> b c t")
        return self.mix(x)


# ---------------------------------------------------------------------------
# Full Decoder
# ---------------------------------------------------------------------------

class Decoder(nn.Module):
    """
    1D U-Net decoder for Matcha-TTS CFM.

    Time dimension T is held constant throughout (no spatial downsampling).
    The "U-Net" structure provides multi-scale channel mixing and skip
    connections for gradient flow, not resolution pyramids.

    Input channels to the first block are 2 * in_channels because mu is
    channel-concatenated with x before entering the U-Net.

    Args:
        in_channels: feat_dim (mel bins or latent dim)
        out_channels: feat_dim (same; predicts residual vector field)
        channels: channel widths at each down/up level (any sequence of ints)
        n_mid_blocks: number of mid blocks at the bottleneck
        n_transformer: transformer blocks per down/mid/up stage
        n_heads: attention heads
        head_dim: dimension per head
        dropout: dropout in transformers
        act_fn: activation in transformer FFN ('snake' or 'silu')
        cond_dim: optional semantic conditioning dim (pooled mu etc.)
        in_c: deprecated alias for ``in_channels`` (kept for old callers)
        out_c: deprecated alias for ``out_channels`` (kept for old callers)

    Raises:
        ValueError: if an alias conflicts with its canonical argument, or if
            neither form of in/out channels is supplied.
    """

    def __init__(
        self,
        in_channels: int = None,
        out_channels: int = None,
        channels: tuple = (256, 256),
        n_mid_blocks: int = 2,
        n_transformer: int = 1,
        n_heads: int = 4,
        head_dim: int = 64,
        dropout: float = 0.05,
        act_fn: str = "snakebeta",
        cond_dim: int = None,
        in_c: int = None,  # deprecated alias; see class docstring
        out_c: int = None,  # deprecated alias; see class docstring
    ):
        super().__init__()

        # Resolve deprecated aliases, rejecting contradictory values.
        if in_channels is None:
            in_channels = in_c
        elif in_c is not None and in_c != in_channels:
            raise ValueError("Received conflicting values for in_channels and in_c.")
        if out_channels is None:
            out_channels = out_c
        elif out_c is not None and out_c != out_channels:
            raise ValueError("Received conflicting values for out_channels and out_c.")
        if in_channels is None or out_channels is None:
            raise ValueError(
                "Decoder requires in_channels/out_channels (or aliases in_c/out_c)."
            )

        # Accept any sequence for `channels` (list, tuple, ...); the tuple
        # arithmetic below ((a,) + channels[:-1], slice + (c0,)) requires a tuple.
        channels = tuple(channels)

        # Time conditioning
        time_emb_dim = channels[0] * 4
        self.time_emb = SinusoidalPosEmb(channels[0])
        self.time_mlp = TimestepEmbedding(
            in_channels=channels[0],
            time_embed_dim=time_emb_dim,
            act_fn="silu",
            cond_proj_dim=cond_dim,
        )

        # mu is concatenated into x before the U-Net, so first dim_in = 2 * in_channels
        dims_in = (2 * in_channels,) + channels[:-1]
        dims_out = channels
        self.down_blocks = nn.ModuleList(
            [
                DownBlock(
                    di,
                    do,
                    time_emb_dim,
                    cond_dim,
                    n_transformer,
                    n_heads,
                    head_dim,
                    dropout,
                    act_fn,
                )
                for di, do in zip(dims_in, dims_out)
            ]
        )

        self.mid_blocks = nn.ModuleList(
            [
                MidBlock(
                    channels[-1],
                    time_emb_dim,
                    cond_dim,
                    n_transformer,
                    n_heads,
                    head_dim,
                    dropout,
                    act_fn,
                )
                for _ in range(n_mid_blocks)
            ]
        )

        # Up path: channel dims mirror the down path in reverse
        up_dims_in = channels[::-1]
        up_dims_out = channels[::-1][1:] + (channels[0],)
        self.up_blocks = nn.ModuleList(
            [
                UpBlock(
                    di,
                    do,
                    time_emb_dim,
                    cond_dim,
                    n_transformer,
                    n_heads,
                    head_dim,
                    dropout,
                    act_fn,
                )
                for di, do in zip(up_dims_in, up_dims_out)
            ]
        )

        self.final_block = Block1D(channels[0], channels[0])
        self.final_proj = nn.Conv1d(channels[0], out_channels, 1)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-init convs/linears, unit-init GroupNorms, zero all biases.

        NOTE(review): kaiming_normal_ with nonlinearity="relu" is applied to
        every Conv1d/Linear, including those feeding snake/silu activations —
        presumably an intentional simplification; confirm if retraining.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,  # (B, feat_dim, T) noisy interpolant
        mu: torch.Tensor,  # (B, feat_dim, T) encoder conditioning
        t: torch.Tensor,  # (B,) flow timestep
        cond: torch.Tensor = None,  # (B, cond_dim) optional semantic cond
    ) -> torch.Tensor:  # (B, feat_dim, T)
        """Predict the vector field at (x, t) conditioned on mu (and cond)."""
        # Time embedding: `t` is re-bound to the embedded timestep, which is
        # what the down/mid/up blocks expect as their `t` argument.
        t = t.reshape(x.shape[0])
        t = self.time_mlp(self.time_emb(t), condition=cond)

        # Condition on mu via channel concatenation
        x = pack([x, mu], "b * t")[0]  # (B, 2*feat_dim, T)

        # Down path - save hiddens for skip connections
        hiddens = []
        for block in self.down_blocks:
            x = block(x, t, cond)
            hiddens.append(x)

        # Bottleneck
        for block in self.mid_blocks:
            x = block(x, t, cond)

        # Up path - skip connections from corresponding down blocks
        # (hiddens.pop() pairs each up block with its mirrored down block)
        for block in self.up_blocks:
            x = block(x, hiddens.pop(), t, cond)

        x = self.final_block(x)
        return self.final_proj(x)