File size: 6,718 Bytes
0afe769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.diffusion_transformer import DiTBlock, DiTFinalLayer

# reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar inputs (e.g. diffusion timesteps).

    Maps a 1-D tensor of shape [n] (or a 0-d scalar, treated as n == 1) to
    shape [n, dim]. The first ``dim // 2`` features are ``sin(scale * x * f_k)``
    and the remaining ones ``cos(scale * x * f_k)``. The frequencies ``f_k`` are
    spaced geometrically from 1 down to 1/10000; when ``dim == 2`` there is a
    single frequency, 1.

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim):
        super().__init__()
        # Explicit check instead of ``assert``, so the validation survives ``python -O``.
        if dim % 2 != 0:
            raise ValueError("SinusoidalPosEmb requires dim to be even")
        self.dim = dim

    def forward(self, x, scale=1000):
        """Embed ``x`` (presumably timesteps in [0, 1], hence ``scale`` -- TODO confirm).

        Returns a float32 tensor of shape [n, dim].
        """
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when dim == 2 (half_dim == 1);
        # for every dim >= 4 the step is unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device).float() * -log_step)
        args = scale * x.unsqueeze(1) * freqs.unsqueeze(0)  # [n, half_dim]
        return torch.cat((args.sin(), args.cos()), dim=-1)

class TimestepEmbedding(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to a timestep embedding.

    Args:
        in_channels: width of the incoming embedding.
        out_channels: width of the produced embedding.
        intermediate_size: hidden width between the two Linear layers.
    """

    def __init__(self, in_channels, out_channels, intermediate_size):
        super().__init__()
        stages = [
            nn.Linear(in_channels, intermediate_size),
            nn.SiLU(inplace=True),
            nn.Linear(intermediate_size, out_channels),
        ]
        # A single Sequential named ``layer`` keeps the parameter names
        # (layer.0.*, layer.2.*) that checkpoints refer to.
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Map [..., in_channels] to [..., out_channels]."""
        mlp = self.layer
        return mlp(x)
    

class Decoder(nn.Module):
    """DiT-style decoder that predicts codec-token logits.

    Each input frame holds one index per quantizer. Their per-quantizer
    embeddings are concatenated and projected to ``hidden_size``. Text
    embeddings are prepended to the frame sequence. A stack of DiT blocks,
    conditioned on ``[time embedding; language embedding]``, processes the
    sequence, and the frame positions are scored against the shared
    ``token_embed`` table (tied input/output embeddings).
    """

    def __init__(self, configs):
        super().__init__()
        hidden_size = configs.hidden_size
        intermediate_size = configs.intermediate_size
        self.codebook_size = configs.codebook_size
        quantizers_num = configs.quantizers_num

        self.text_embed = nn.Embedding(configs.n_vocab, hidden_size)
        # One table per quantizer. Rows past codebook_size are special tokens:
        # they can be embedded as inputs but are never scored as outputs.
        self.token_embed = nn.Parameter(
            torch.empty(quantizers_num, configs.codebook_size + configs.special_codebook_size, hidden_size)
            )
        self.input_proj = nn.Linear(quantizers_num * hidden_size, hidden_size)

        # Conditioning vector = [time embedding (h // 2); language embedding (h // 2)].
        self.time_embeddings = SinusoidalPosEmb(hidden_size//2)
        self.time_mlp = TimestepEmbedding(hidden_size//2, hidden_size//2, intermediate_size//2)

        self.cfg_dropout = configs.cfg_dropout
        # Learned stand-ins for text / prompt frames in the unconditional CFG branch.
        self.text_cfg_embed = nn.Parameter(torch.empty(hidden_size))
        self.token_cfg_embed = nn.Parameter(torch.empty(hidden_size))
        # Learned offset added to prompt (non-predicted) frames.
        self.prompt_embed = nn.Parameter(torch.empty(hidden_size))
        self.lang_embed = nn.Embedding(configs.n_lang, hidden_size//2)

        self.blocks = nn.ModuleList([DiTBlock(hidden_size, intermediate_size, configs.n_heads, configs.dropout) for _ in range(configs.n_layers)])
        self.final_layer = DiTFinalLayer(hidden_size)

        self.init_weights()

    def init_weights(self):
        """Initialize all parameters in place.

        - Linear/Conv1d: Xavier-uniform weights, zero bias.
        - Embeddings: N(0, 0.02).
        - The last layer of every ``adaLN_modulation`` (all blocks and the
          final layer): zeroed.
        - Time-MLP Linear weights: re-drawn from N(0, 0.02).
        - The learned embedding parameters: N(0, 0.02).
        """
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv1d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            if isinstance(m, nn.Embedding):
                nn.init.normal_(m.weight, std=0.02)

        for block in self.blocks:
            nn.init.zeros_(block.adaLN_modulation[-1].weight)
            nn.init.zeros_(block.adaLN_modulation[-1].bias)
        nn.init.zeros_(self.final_layer.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.final_layer.adaLN_modulation[-1].bias)

        # Overrides the Xavier init applied to these Linear layers above.
        for m in self.time_mlp.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.02)

        nn.init.normal_(self.token_embed, std=0.02)
        nn.init.normal_(self.text_cfg_embed, std=0.02)
        nn.init.normal_(self.token_cfg_embed, std=0.02)
        nn.init.normal_(self.prompt_embed, std=0.02)

    def _embed_tokens(self, x_t):
        """Embed [b, l, c] codebook indices and project them to [b, l, hidden_size].

        Index ``x_t[..., k]`` is looked up in ``token_embed[k]``. The c
        embeddings of a frame are concatenated before ``input_proj``.
        """
        b, l, c = x_t.shape
        codebooks = torch.arange(c, device=x_t.device)  # broadcasts against x_t's last dim
        embed = self.token_embed[codebooks, x_t]  # [b, l, c, h]
        return self.input_proj(embed.reshape(b, l, -1))  # [b, l, h]

    def _condition(self, t, lang):
        """Build the conditioning vector passed to every block and the final layer."""
        time_emb = self.time_mlp(self.time_embeddings(t))  # [n, h // 2]
        return torch.cat([time_emb, self.lang_embed(lang)], dim=-1)  # [n, h]

    @staticmethod
    def _attention_bias(mask, dtype):
        """Turn an [n, L, 1] validity mask into an additive attention bias [n, 1, L, L].

        A pair (i, j) gets ``-finfo(dtype).max`` when ``mask[i] * mask[j] == 0``,
        and 0 otherwise.
        """
        pair_mask = mask * mask.transpose(1, 2)  # [n, L, L]
        return (
            torch.zeros_like(pair_mask)
            .masked_fill(pair_mask == 0, -torch.finfo(dtype).max)
            .unsqueeze(1)
        )

    @staticmethod
    def _duplicate_for_cfg(tensor, b):
        """Stack a per-sample tensor twice along dim 0 so it matches the 2*b CFG batch.

        Tensors whose leading dim is not ``b`` are returned unchanged: size 1
        still broadcasts, and size 2*b is already aligned.
        """
        if b > 1 and tensor.shape[0] == b:
            return torch.cat([tensor, tensor], dim=0)
        return tensor

    def _to_logits(self, x):
        """Score [b, l, h] states against every non-special codebook entry -> [b, l, c, codebook_size]."""
        return torch.einsum("blh,ckh->blck", x, self.token_embed[:, :self.codebook_size, :])

    def forward(self, t, x_t, texts, mask, pred_mask, lang):
        """Training forward pass with random classifier-free-guidance dropout.

        Args:
            t: timestep tensor (assumed one value per sample, shape [b] -- TODO confirm).
            x_t: integer token ids, shape [b, l, c] with c == quantizers_num.
            texts: integer text ids, shape [b, T].
            mask: validity mask over the concatenated text + frame sequence,
                presumably shape [b, T + l] with 1 = valid, 0 = padding.
            pred_mask: boolean mask, shape [b, l]; False marks prompt frames.
            lang: language ids, one per sample.

        Returns:
            Logits of shape [b, l, c, codebook_size] for the frame positions.
        """
        b = x_t.shape[0]
        texts_embed = self.text_embed(texts)  # [b, T, h]
        x = self._embed_tokens(x_t)  # [b, l, h]

        # Prompt frames get a learned offset so the model can tell them apart.
        x = x + (~pred_mask[:, :, None]).to(x.dtype) * self.prompt_embed

        # CFG dropout: for a random subset of samples, replace all text and the
        # prompt frames with the learned unconditional embeddings.
        # NOTE(review): this runs regardless of self.training, so eval-mode
        # calls are stochastic too -- confirm this is intended.
        cfg_mask = torch.rand(b, device=x.device) < self.cfg_dropout
        texts_embed = torch.where(cfg_mask[:, None, None], self.text_cfg_embed, texts_embed)
        x = torch.where(cfg_mask[:, None, None] & ~pred_mask[:, :, None], self.token_cfg_embed, x)

        x = torch.cat([texts_embed, x], dim=1)  # [b, T + l, h]

        mask = mask.unsqueeze(-1).to(x.dtype)  # [b, T + l, 1]
        cond = self._condition(t.to(x.dtype), lang)  # [b, h]
        attn_mask = self._attention_bias(mask, x.dtype)  # [b, 1, T + l, T + l]

        for block in self.blocks:
            x = block(x, cond, mask, attn_mask)

        x = x[:, texts.shape[-1]:, :]  # drop text positions
        x = self.final_layer(x, cond)
        return self._to_logits(x)

    def infer(self, t, x_t, prompt_l, texts, lang, mask, rescale_cfg, cfg):
        """Inference pass with classifier-free guidance (CFG) and CFG rescaling.

        Conditional and unconditional inputs are stacked along the batch
        (rows [0, b) conditional, rows [b, 2b) unconditional) and run through
        the network once. The guided output is ``x + cfg * (x - x_uncond)``.
        It is then rescaled per sample to the conditional output's standard
        deviation and blended with the unrescaled guided output by
        ``rescale_cfg``.

        Args:
            t: timestep tensor (scalar or one value per sample).
            x_t: integer token ids, shape [b, l, c].
            prompt_l: number of leading frames of x_t that are the prompt.
            texts: integer text ids, shape [b, T].
            lang: language ids (one per sample, or a single shared id).
            mask: validity mask over the text + frame sequence, shape [n, T + l]
                or [n, T + l, 1] with n in {1, b, 2*b}.
            rescale_cfg: weight of the rescaled output in the final blend.
            cfg: guidance scale.

        Returns:
            Logits of shape [b, l, c, codebook_size].
        """
        b = x_t.shape[0]
        texts_embed = self.text_embed(texts)  # [b, T, h]
        # Unconditional branch: every text position becomes the learned CFG embedding.
        texts_embed = torch.cat(
            [texts_embed, self.text_cfg_embed[None, None, :].expand_as(texts_embed)],
            dim=0,
        )  # [2b, T, h]

        x = self._embed_tokens(x_t)  # [b, l, h]
        x[:, :prompt_l, :] += self.prompt_embed

        cond = self._condition(t, lang)

        # Unconditional branch: prompt frames become the learned CFG embedding.
        x_cfg = x.clone()
        x_cfg[:, :prompt_l, :] = self.token_cfg_embed
        x = torch.cat([x, x_cfg], dim=0)  # [2b, l, h]
        x = torch.cat([texts_embed, x], dim=1)  # [2b, T + l, h]

        # Per-sample conditioning and masks must cover both CFG halves
        # (previously only b == 1 worked, via broadcasting).
        cond = self._duplicate_for_cfg(cond, b)
        if mask.dim() == 2:
            mask = mask.unsqueeze(-1)
        # Same dtype handling as forward(): the bias must be in x's dtype.
        mask = self._duplicate_for_cfg(mask.to(x.dtype), b)
        attn_mask = self._attention_bias(mask, x.dtype)

        for block in self.blocks:
            x = block(x, cond, mask, attn_mask)

        x = x[:, texts.shape[-1]:, :]  # drop text positions
        x = self.final_layer(x, cond)

        x, x_cfg = torch.split(x, [b, b], dim=0)
        # Per-sample std over all non-batch dims; identical to a global std when b == 1.
        x_std = x.std(dim=(1, 2), keepdim=True)
        x = x + cfg * (x - x_cfg)
        rescale_x = x * x_std / x.std(dim=(1, 2), keepdim=True)
        x = rescale_cfg * rescale_x + (1 - rescale_cfg) * x

        return self._to_logits(x)