# flow-matching/src/stage2/decoder.py
# Provenance: uploaded by sabertoaster via huggingface_hub (commit 4edc9aa, verified).
import torch
import torch.nn as nn
from einops import pack, rearrange
from .unit_nn import (
SinusoidalPosEmb,
TimestepEmbedding,
ResnetBlock1D,
Block1D,
)
from .transformer import BasicTransformerBlock
# ---------------------------------------------------------------------------
# Small helpers to keep block construction readable
# ---------------------------------------------------------------------------
def _make_resnet(dim_in, dim_out, time_emb_dim, cond_dim):
    """Build a FiLM-conditioned ResnetBlock1D mapping dim_in -> dim_out channels."""
    kwargs = {
        "dim": dim_in,
        "dim_out": dim_out,
        "time_emb_dim": time_emb_dim,
        "film_dim": cond_dim,
    }
    return ResnetBlock1D(**kwargs)
def _make_transformer(dim, n_heads, head_dim, dropout, act_fn):
    """Build one BasicTransformerBlock operating on `dim`-sized tokens."""
    kwargs = {
        "dim": dim,
        "num_attention_heads": n_heads,
        "attention_head_dim": head_dim,
        "dropout": dropout,
        "activation_fn": act_fn,
    }
    return BasicTransformerBlock(**kwargs)
# ---------------------------------------------------------------------------
# Building-block containers
# ---------------------------------------------------------------------------
class DownBlock(nn.Module):
    """
    ResNet -> n Transformer blocks -> channel-mixing Conv1d.

    The time axis T is never halved; the trailing conv only mixes channels.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        time_emb_dim,
        cond_dim,
        n_transformer,
        n_heads,
        head_dim,
        dropout,
        act_fn,
    ):
        super().__init__()
        self.resnet = _make_resnet(dim_in, dim_out, time_emb_dim, cond_dim)
        stack = [
            _make_transformer(dim_out, n_heads, head_dim, dropout, act_fn)
            for _ in range(n_transformer)
        ]
        self.transformers = nn.ModuleList(stack)
        # Kernel-3 conv with padding=1 keeps T fixed; purely channel interaction.
        self.mix = nn.Conv1d(dim_out, dim_out, 3, padding=1)

    def forward(self, x, t, cond=None):
        """(B, dim_in, T) -> (B, dim_out, T); t is the time embedding, cond the FiLM input."""
        h = self.resnet(x, t, film_cond=cond)
        # Transformers expect token-major layout (B, T, C).
        h = rearrange(h, "b c t -> b t c")
        for tfm in self.transformers:
            h = tfm(hidden_states=h, attention_mask=None, timestep=t)
        h = rearrange(h, "b t c -> b c t")
        return self.mix(h)
class MidBlock(nn.Module):
    """ResNet -> n Transformer blocks. Channel and time dims both unchanged."""

    def __init__(
        self,
        dim,
        time_emb_dim,
        cond_dim,
        n_transformer,
        n_heads,
        head_dim,
        dropout,
        act_fn,
    ):
        super().__init__()
        self.resnet = _make_resnet(dim, dim, time_emb_dim, cond_dim)
        stack = [
            _make_transformer(dim, n_heads, head_dim, dropout, act_fn)
            for _ in range(n_transformer)
        ]
        self.transformers = nn.ModuleList(stack)

    def forward(self, x, t, cond=None):
        """(B, dim, T) -> (B, dim, T); t is the time embedding, cond the FiLM input."""
        h = self.resnet(x, t, film_cond=cond)
        # Transformers expect token-major layout (B, T, C).
        h = rearrange(h, "b c t -> b t c")
        for tfm in self.transformers:
            h = tfm(hidden_states=h, attention_mask=None, timestep=t)
        return rearrange(h, "b t c -> b c t")
class UpBlock(nn.Module):
    """
    Skip-connection concat -> ResNet -> n Transformer blocks -> Conv1d.

    `dim_in` is the channel width of the skip tensor, not the doubled width;
    the concat-doubling is absorbed by building the ResNet with dim=2*dim_in.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        time_emb_dim,
        cond_dim,
        n_transformer,
        n_heads,
        head_dim,
        dropout,
        act_fn,
    ):
        super().__init__()
        # ResNet consumes the channel-concatenated [x, skip] pair.
        self.resnet = _make_resnet(2 * dim_in, dim_out, time_emb_dim, cond_dim)
        stack = [
            _make_transformer(dim_out, n_heads, head_dim, dropout, act_fn)
            for _ in range(n_transformer)
        ]
        self.transformers = nn.ModuleList(stack)
        # Kernel-3 conv with padding=1 keeps T fixed; purely channel interaction.
        self.mix = nn.Conv1d(dim_out, dim_out, 3, padding=1)

    def forward(self, x, skip, t, cond=None):
        """(B, dim_in, T) + skip (B, dim_in, T) -> (B, dim_out, T)."""
        packed, _ = pack([x, skip], "b * t")  # concat along channels
        h = self.resnet(packed, t, film_cond=cond)
        # Transformers expect token-major layout (B, T, C).
        h = rearrange(h, "b c t -> b t c")
        for tfm in self.transformers:
            h = tfm(hidden_states=h, attention_mask=None, timestep=t)
        h = rearrange(h, "b t c -> b c t")
        return self.mix(h)
# ---------------------------------------------------------------------------
# Full Decoder
# ---------------------------------------------------------------------------
class Decoder(nn.Module):
    """
    1D U-Net decoder for Matcha-TTS CFM.

    Time dimension T is held constant throughout (no spatial downsampling).
    The "U-Net" structure provides multi-scale channel mixing and
    skip connections for gradient flow, not resolution pyramids.

    Input channels to the first block are 2 * in_channels because
    mu is channel-concatenated with x before entering the U-Net.

    Args:
        in_channels: feat_dim (mel bins or latent dim)
        out_channels: feat_dim (same; predicts residual vector field)
        channels: channel widths at each down/up level (tuple or list)
        n_mid_blocks: number of mid blocks at the bottleneck
        n_transformer: transformer blocks per down/mid/up stage
        n_heads: attention heads
        head_dim: dimension per head
        dropout: dropout in transformers
        act_fn: activation in transformer FFN ('snake' or 'silu')
        cond_dim: optional semantic conditioning dim (pooled mu etc.)
        in_c / out_c: legacy aliases for in_channels / out_channels;
            must agree with the canonical names when both are given.

    Raises:
        ValueError: if a canonical name and its alias conflict, or if the
            channel counts are not supplied through either spelling.
    """

    def __init__(
        self,
        in_channels: int = None,
        out_channels: int = None,
        channels: tuple = (256, 256),
        n_mid_blocks: int = 2,
        n_transformer: int = 1,
        n_heads: int = 4,
        head_dim: int = 64,
        dropout: float = 0.05,
        act_fn: str = "snakebeta",
        cond_dim: int = None,
        in_c: int = None,  # legacy alias, kept for backward compatibility
        out_c: int = None,  # legacy alias, kept for backward compatibility
    ):
        super().__init__()
        # Resolve legacy aliases: prefer the canonical name, but reject
        # silently-conflicting values rather than guessing.
        if in_channels is None:
            in_channels = in_c
        elif in_c is not None and in_c != in_channels:
            raise ValueError("Received conflicting values for in_channels and in_c.")
        if out_channels is None:
            out_channels = out_c
        elif out_c is not None and out_c != out_channels:
            raise ValueError("Received conflicting values for out_channels and out_c.")
        if in_channels is None or out_channels is None:
            raise ValueError(
                "Decoder requires in_channels/out_channels (or aliases in_c/out_c)."
            )

        # Normalize to tuple so list-valued configs (e.g. from YAML/JSON) work:
        # the dims arithmetic below concatenates with tuple literals, which
        # raises TypeError when `channels` is a list.
        channels = tuple(channels)

        # Time conditioning: sinusoidal embedding -> MLP, widened 4x as is
        # conventional for diffusion/flow time embeddings.
        time_emb_dim = channels[0] * 4
        self.time_emb = SinusoidalPosEmb(channels[0])
        self.time_mlp = TimestepEmbedding(
            in_channels=channels[0],
            time_embed_dim=time_emb_dim,
            act_fn="silu",
            cond_proj_dim=cond_dim,
        )

        # mu is concatenated into x before the U-Net, so first dim_in = 2 * in_channels
        dims_in = (2 * in_channels,) + channels[:-1]
        dims_out = channels
        self.down_blocks = nn.ModuleList(
            [
                DownBlock(
                    di,
                    do,
                    time_emb_dim,
                    cond_dim,
                    n_transformer,
                    n_heads,
                    head_dim,
                    dropout,
                    act_fn,
                )
                for di, do in zip(dims_in, dims_out)
            ]
        )
        self.mid_blocks = nn.ModuleList(
            [
                MidBlock(
                    channels[-1],
                    time_emb_dim,
                    cond_dim,
                    n_transformer,
                    n_heads,
                    head_dim,
                    dropout,
                    act_fn,
                )
                for _ in range(n_mid_blocks)
            ]
        )

        # Up path: channel dims mirror the down path in reverse; the last
        # up block returns to channels[0] for the final projection.
        up_dims_in = channels[::-1]
        up_dims_out = channels[::-1][1:] + (channels[0],)
        self.up_blocks = nn.ModuleList(
            [
                UpBlock(
                    di,
                    do,
                    time_emb_dim,
                    cond_dim,
                    n_transformer,
                    n_heads,
                    head_dim,
                    dropout,
                    act_fn,
                )
                for di, do in zip(up_dims_in, up_dims_out)
            ]
        )
        self.final_block = Block1D(channels[0], channels[0])
        self.final_proj = nn.Conv1d(channels[0], out_channels, 1)
        self._init_weights()

    def _init_weights(self):
        """Kaiming init for convs/linears, unit-gain init for GroupNorms."""
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,  # (B, feat_dim, T) noisy interpolant
        mu: torch.Tensor,  # (B, feat_dim, T) encoder conditioning
        t: torch.Tensor,  # (B,) flow timestep
        cond: torch.Tensor = None,  # (B, cond_dim) optional semantic cond
    ) -> torch.Tensor:  # (B, feat_dim, T)
        """Predict the vector field for the flow-matching ODE at timestep t."""
        # Time embedding (reshape guards against (B, 1)-shaped timesteps)
        t = t.reshape(x.shape[0])
        t = self.time_mlp(self.time_emb(t), condition=cond)

        # Condition on mu via channel concatenation
        x = pack([x, mu], "b * t")[0]  # (B, 2*feat_dim, T)

        # Down path - save hiddens for skip connections
        hiddens = []
        for block in self.down_blocks:
            x = block(x, t, cond)
            hiddens.append(x)

        # Bottleneck
        for block in self.mid_blocks:
            x = block(x, t, cond)

        # Up path - skip connections popped in reverse (deepest first)
        for block in self.up_blocks:
            x = block(x, hiddens.pop(), t, cond)

        x = self.final_block(x)
        return self.final_proj(x)