| """ |
| @author : Hyunwoong |
| @when : 2019-12-18 |
| @homepage : https://github.com/gusdnd852 |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
|
|
|
|
class EncoderLayer(nn.Module):
    """Single Transformer encoder block.

    Self-attention followed by a position-wise feed-forward network, each
    wrapped in dropout + residual connection + layer normalization
    (post-norm arrangement).
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(EncoderLayer, self).__init__()
        # sub-layer 1: multi-head self-attention
        self.attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        # sub-layer 2: position-wise feed-forward network
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, s_mask):
        """Apply self-attention then FFN, each with dropout + residual + norm.

        :param x: input tensor [batch_size, length, d_model]
        :param s_mask: source padding mask passed to self-attention
        :return: tensor of the same shape as x
        """
        # 1. self-attention sub-layer with residual connection
        residual = x
        x = self.attention(q=x, k=x, v=x, mask=s_mask)
        x = self.norm1(self.dropout1(x) + residual)

        # 2. feed-forward sub-layer with residual connection
        residual = x
        x = self.ffn(x)
        x = self.norm2(self.dropout2(x) + residual)
        return x
|
|
|
|
class DecoderLayer(nn.Module):
    """Single Transformer decoder block.

    Masked self-attention, encoder-decoder cross-attention, and a
    position-wise feed-forward network; each sub-layer uses dropout +
    residual connection + layer normalization (post-norm).
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(DecoderLayer, self).__init__()
        # sub-layer 1: masked self-attention over the target sequence
        self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        # sub-layer 2: cross-attention over the encoder output
        self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

        # sub-layer 3: position-wise feed-forward network
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = LayerNorm(d_model=d_model)
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, dec, enc, t_mask, s_mask):
        """Run one decoder block.

        :param dec: decoder input [batch_size, trg_len, d_model]
        :param enc: encoder output, or None to skip cross-attention
        :param t_mask: target mask (padding combined with causal mask)
        :param s_mask: source padding mask for cross-attention
        :return: tensor of the same shape as dec
        """
        # 1. masked self-attention with residual connection
        residual = dec
        x = self.self_attention(q=dec, k=dec, v=dec, mask=t_mask)
        x = self.norm1(self.dropout1(x) + residual)

        # 2. cross-attention with residual connection (decoder-only mode
        #    skips this sub-layer entirely when enc is None)
        if enc is not None:
            residual = x
            x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=s_mask)
            x = self.norm2(self.dropout2(x) + residual)

        # 3. feed-forward sub-layer with residual connection
        residual = x
        x = self.ffn(x)
        x = self.norm3(self.dropout3(x) + residual)
        return x
|
|
|
|
class ScaleDotProductAttention(nn.Module):
    """
    compute scale dot product attention: softmax(Q K^T / sqrt(d)) V

    Query : given sentence that we focused on (decoder)
    Key : every sentence to check relationship with Query (encoder)
    Value : every sentence same with Key (encoder)
    """

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        """
        :param q: query tensor [batch_size, head, length, d_tensor]
        :param k: key tensor   [batch_size, head, length, d_tensor]
        :param v: value tensor [batch_size, head, length, d_tensor]
        :param mask: optional mask broadcastable to the score tensor;
                     positions where mask == 0 are suppressed
        :param e: unused; kept only for backward compatibility with callers
        :return: (attention output, attention weights)
        """
        # per-head dimensionality, used for the 1/sqrt(d) scaling
        d_tensor = k.size(-1)

        # 1. similarity scores: Q @ K^T, scaled to keep gradients stable
        k_t = k.transpose(2, 3)
        score = (q @ k_t) / math.sqrt(d_tensor)

        # 2. masking: a large negative value becomes ~0 weight after softmax
        #    (-10000 rather than -inf avoids NaN if a row is fully masked)
        if mask is not None:
            score = score.masked_fill(mask == 0, -10000)

        # 3. normalize scores into attention weights
        score = self.softmax(score)

        # 4. weighted sum of the values
        v = score @ v

        return v, score
|
|
|
|
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Applied independently at every sequence position; expands d_model to a
    larger hidden size and projects back.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Expand, apply non-linearity + dropout, then project back to d_model."""
        expanded = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(expanded)
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head attention.

    Projects q/k/v with bias-free linear layers, splits the model dimension
    into n_head sub-spaces, runs scaled dot-product attention per head,
    then merges the heads and applies an output projection.
    """

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_concat = nn.Linear(d_model, d_model, bias=False)

    def forward(self, q, k, v, mask=None):
        """Attend q over k/v; all inputs are [batch_size, length, d_model]."""
        # 1. linear projections
        q = self.w_q(q)
        k = self.w_k(k)
        v = self.w_v(v)

        # 2. split the model dimension into n_head sub-spaces
        q, k, v = map(self.split, (q, k, v))

        # 3. scaled dot-product attention within each head
        out, attention = self.attention(q, k, v, mask=mask)

        # 4. merge heads back together and apply the output projection
        return self.w_concat(self.concat(out))

    def split(self, tensor):
        """
        split tensor by number of head

        :param tensor: [batch_size, length, d_model]
        :return: [batch_size, head, length, d_tensor]
        """
        batch_size, length, d_model = tensor.size()
        d_tensor = d_model // self.n_head
        return tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)

    def concat(self, tensor):
        """
        inverse function of self.split(tensor : torch.Tensor)

        :param tensor: [batch_size, head, length, d_tensor]
        :return: [batch_size, length, d_model]
        """
        batch_size, head, length, d_tensor = tensor.size()
        d_model = head * d_tensor
        return tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension.

    Normalizes each feature vector to zero mean and unit variance (biased
    variance estimate), then applies a learnable scale (gamma) and shift
    (beta).
    """

    def __init__(self, d_model, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        # eps keeps the division stable when the variance is near zero
        self.eps = eps

    def forward(self, x):
        """Normalize x over its last dimension and apply gamma/beta."""
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        normalized = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * normalized + self.beta
|
|
|
|
class TransformerEmbedding(nn.Module):
    """
    token embedding + positional encoding
    positional encoding gives positional information to the network
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, padding_idx, learnable_pos_emb=True):
        """
        class for word embedding that includes positional information

        :param vocab_size: size of vocabulary
        :param d_model: dimensions of model
        :param max_len: maximum sequence length for the positional table
        :param drop_prob: dropout probability applied to the summed embedding
        :param padding_idx: token id whose embedding stays zero
        :param learnable_pos_emb: use a learned position table instead of
                                  fixed sinusoids
        """
        super(TransformerEmbedding, self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model, padding_idx)
        # choose between a learned position table and fixed sinusoids
        pos_cls = LearnablePositionalEncoding if learnable_pos_emb else SinusoidalPositionalEncoding
        self.pos_emb = pos_cls(d_model, max_len)
        self.drop_out = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Embed token ids, add positional encoding, and apply dropout."""
        token = self.tok_emb(x)
        # the positional table may live on a different device; align it
        position = self.pos_emb(x).to(token.device)
        return self.drop_out(token + position)
|
|
|
|
class TokenEmbedding(nn.Embedding):
    """
    Token Embedding using torch.nn
    maps token ids to dense vectors via a learned weight matrix
    """

    def __init__(self, vocab_size, d_model, padding_idx):
        """
        class for token embedding; the row at padding_idx is kept at zero
        and receives no gradient

        :param vocab_size: size of vocabulary
        :param d_model: dimensions of model
        :param padding_idx: token id used for padding
        """
        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx)
|
|
|
|
class SinusoidalPositionalEncoding(nn.Module):
    """
    compute sinusoid encoding.

    The table is registered as a non-persistent buffer so that
    Module.to()/cuda() moves it with the rest of the model (the original
    stored a plain tensor that stayed on CPU); persistent=False keeps it
    out of state_dict, preserving checkpoint compatibility.
    """

    def __init__(self, d_model, max_len):
        """
        constructor of sinusoid encoding class

        :param d_model: dimension of model
        :param max_len: max sequence length
        """
        super(SinusoidalPositionalEncoding, self).__init__()

        encoding = torch.zeros(max_len, d_model)

        # positions 0..max_len-1 as a float column vector
        pos = torch.arange(0, max_len)
        pos = pos.float().unsqueeze(dim=1)

        # even channel indices 0, 2, 4, ... used in the frequency term
        _2i = torch.arange(0, d_model, step=2).float()

        # PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        # PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))

        # buffer, not Parameter: no gradients, moves with .to()/.cuda()
        self.register_buffer('encoding', encoding, persistent=False)

    def forward(self, x):
        """
        :param x: token index tensor [batch_size, seq_len]
        :return: positional encodings [seq_len, d_model], broadcast over batch
        """
        seq_len = x.size(1)
        return self.encoding[:seq_len, :]
| |
| |
|
|
|
|
class LearnablePositionalEncoding(nn.Module):
    """
    learned positional encoding backed by an nn.Embedding table
    (one trainable vector per position).
    """

    def __init__(self, d_model, max_seq_len):
        """
        constructor of learnable positional encoding class

        :param d_model: dimension of model
        :param max_seq_len: max sequence length
        """
        super(LearnablePositionalEncoding, self).__init__()
        self.max_seq_len = max_seq_len
        self.wpe = nn.Embedding(max_seq_len, d_model)

    def forward(self, x):
        """
        :param x: token index tensor [batch_size, seq_len]
        :return: position embeddings [seq_len, d_model], broadcast over batch
        """
        device = x.device
        batch_size, seq_len = x.size()
        assert seq_len <= self.max_seq_len, f"Cannot forward sequence of length {seq_len}, max_seq_len is {self.max_seq_len}"
        # look up one embedding per position 0..seq_len-1
        positions = torch.arange(0, seq_len, dtype=torch.long, device=device)
        return self.wpe(positions)
| |
| |
|
|
|
|
class Encoder(nn.Module):
    """Transformer encoder: embedding followed by a stack of EncoderLayers."""

    def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, padding_idx, learnable_pos_emb=True):
        super().__init__()
        self.emb = TransformerEmbedding(vocab_size=enc_voc_size,
                                        d_model=d_model,
                                        max_len=max_len,
                                        drop_prob=drop_prob,
                                        padding_idx=padding_idx,
                                        learnable_pos_emb=learnable_pos_emb)

        # n_layers identical encoder blocks applied in sequence
        self.layers = nn.ModuleList([EncoderLayer(d_model=d_model,
                                                  ffn_hidden=ffn_hidden,
                                                  n_head=n_head,
                                                  drop_prob=drop_prob)
                                     for _ in range(n_layers)])

    def forward(self, x, s_mask):
        """Embed source token ids and run them through every encoder layer.

        :param x: source token ids [batch_size, src_len]
        :param s_mask: source padding mask
        :return: encoder output [batch_size, src_len, d_model]
        """
        out = self.emb(x)
        for layer in self.layers:
            out = layer(out, s_mask)
        return out
|
|
class Decoder(nn.Module):
    """Transformer decoder: embedding, a stack of DecoderLayers, and a final
    linear projection onto the target vocabulary."""

    def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, padding_idx, learnable_pos_emb=True):
        super().__init__()
        self.emb = TransformerEmbedding(vocab_size=dec_voc_size,
                                        d_model=d_model,
                                        max_len=max_len,
                                        drop_prob=drop_prob,
                                        padding_idx=padding_idx,
                                        learnable_pos_emb=learnable_pos_emb)

        # n_layers identical decoder blocks applied in sequence
        self.layers = nn.ModuleList([DecoderLayer(d_model=d_model,
                                                  ffn_hidden=ffn_hidden,
                                                  n_head=n_head,
                                                  drop_prob=drop_prob)
                                     for _ in range(n_layers)])

        # projection from d_model to vocabulary logits
        self.linear = nn.Linear(d_model, dec_voc_size)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode target tokens against the encoder output.

        :param trg: target token ids [batch_size, trg_len]
        :param enc_src: encoder output [batch_size, src_len, d_model]
        :param trg_mask: target mask (padding + causal)
        :param src_mask: source padding mask for cross-attention
        :return: vocabulary logits [batch_size, trg_len, dec_voc_size]
        """
        out = self.emb(trg)
        for layer in self.layers:
            out = layer(out, enc_src, trg_mask, src_mask)
        return self.linear(out)
|
|
class Transformer(nn.Module):
    """Full encoder-decoder Transformer for sequence-to-sequence modeling."""

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, n_head, max_len,
                 ffn_hidden, n_layers, drop_prob, learnable_pos_emb=True):
        super().__init__()
        # pad ids are kept so forward() can build its own masks
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.encoder = Encoder(enc_voc_size=enc_voc_size,
                               max_len=max_len,
                               d_model=d_model,
                               ffn_hidden=ffn_hidden,
                               n_head=n_head,
                               n_layers=n_layers,
                               drop_prob=drop_prob,
                               padding_idx=src_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)

        self.decoder = Decoder(dec_voc_size=dec_voc_size,
                               max_len=max_len,
                               d_model=d_model,
                               ffn_hidden=ffn_hidden,
                               n_head=n_head,
                               n_layers=n_layers,
                               drop_prob=drop_prob,
                               padding_idx=trg_pad_idx,
                               learnable_pos_emb=learnable_pos_emb)

    def get_device(self):
        """Device of the model parameters (assumes all on one device)."""
        return next(self.parameters()).device

    def forward(self, src, trg):
        """Encode src, then decode trg against it; returns target logits.

        :param src: source token ids [batch_size, src_len]
        :param trg: target token ids [batch_size, trg_len]
        :return: vocabulary logits [batch_size, trg_len, dec_voc_size]
        """
        device = self.get_device()
        # source self-attention: hide padded source positions
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx).to(device)
        # cross-attention: target queries over source keys
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx).to(device)
        # target self-attention: padding mask combined with the causal mask
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx).to(device) * \
                   self.make_no_peak_mask(trg, trg).to(device)

        enc_src = self.encoder(src, src_mask)
        return self.decoder(trg, enc_src, trg_mask, src_trg_mask)

    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        """Boolean padding mask [batch, 1, len_q, len_k].

        Entry (i, j) is True only when query token i and key token j are
        both real (non-pad) tokens.
        """
        len_q, len_k = q.size(1), k.size(1)

        # key validity broadcast along the query axis: (batch, 1, len_q, len_k)
        k_mask = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2).repeat(1, 1, len_q, 1)
        # query validity broadcast along the key axis: (batch, 1, len_q, len_k)
        q_mask = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, len_k)

        return k_mask & q_mask

    def make_no_peak_mask(self, q, k):
        """Lower-triangular causal mask [len_q, len_k]; True means visible."""
        len_q, len_k = q.size(1), k.size(1)
        return torch.tril(torch.ones(len_q, len_k)).bool()
|
|
|
|
def make_pad_mask(x, pad_idx):
    """Self-attention padding mask for a batch of token ids.

    :param x: token id tensor [batch_size, seq_len]
    :param pad_idx: token id used for padding
    :return: boolean mask [batch_size, 1, seq_len, seq_len]; entry (i, j)
             is True only when positions i and j both hold real tokens
    """
    # True at real (non-pad) token positions, shape (batch, seq_len)
    valid = x.ne(pad_idx)

    # broadcast key validity (columns) against query validity (rows);
    # & expands both views to (batch, 1, seq_len, seq_len)
    key_mask = valid.unsqueeze(1).unsqueeze(2)    # (batch, 1, 1, seq_len)
    query_mask = valid.unsqueeze(1).unsqueeze(3)  # (batch, 1, seq_len, 1)

    return key_mask & query_mask
|
|
|
|
from torch.nn.utils.rnn import pad_sequence

def pad_seq_v2(sequences, batch_first=True, padding_value=0.0, prepadding=True):
    """Pad a list of variable-length tensors into one batch tensor.

    :param sequences: list of tensors whose first dimensions may differ
    :param batch_first: if False, the result is transposed to [seq, batch, ...]
    :param padding_value: fill value for padded positions
    :param prepadding: if True, padding goes before each sequence instead
                       of after it
    :return: padded batch tensor
    """
    lengths = [seq.shape[0] for seq in sequences]
    # always pad batch-first here; transpose at the very end if requested
    padded = pad_sequence(sequences, batch_first=True, padding_value=padding_value)
    if prepadding:
        # rolling each row left by its true length moves the real tokens to
        # the end, leaving the padding at the front
        for idx, length in enumerate(lengths):
            padded[idx] = padded[idx].roll(-length)
    if not batch_first:
        padded = padded.transpose(0, 1)
    return padded
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build a small Transformer and run one forward pass on CPU.
    import torch
    import random
    import numpy as np

    rand_seed = 10

    device = 'cpu'

    # model hyper-parameters
    # (batch_size is declared but not used by this smoke test)
    batch_size = 128
    max_len = 256
    d_model = 512
    n_layers = 3
    n_heads = 16
    ffn_hidden = 2048
    drop_prob = 0.1

    # optimizer / training settings — declared but unused here;
    # presumably copied from a training script for reference
    init_lr = 1e-5
    factor = 0.9
    adam_eps = 5e-9
    patience = 10
    warmup = 100
    epoch = 1000
    clip = 1.0
    weight_decay = 5e-4
    inf = float('inf')

    # distinct pad ids for the source and target vocabularies
    src_pad_idx = 2
    trg_pad_idx = 3

    enc_voc_size = 37
    dec_voc_size = 15
    model = Transformer(src_pad_idx=src_pad_idx,
                        trg_pad_idx=trg_pad_idx,
                        d_model=d_model,
                        enc_voc_size=enc_voc_size,
                        dec_voc_size=dec_voc_size,
                        max_len=max_len,
                        ffn_hidden=ffn_hidden,
                        n_head=n_heads,
                        n_layers=n_layers,
                        drop_prob=drop_prob
                        ).to(device)

    # NOTE(review): seeds are set AFTER the model is constructed, so the
    # parameter initialization above is NOT reproducible via rand_seed —
    # confirm whether this ordering is intentional
    random.seed(rand_seed)
    np.random.seed(rand_seed)
    torch.manual_seed(rand_seed)

    # three toy sequences of different lengths, each shaped [len, 1]
    x_list = [
        torch.tensor([[1, 1]]).transpose(0, 1),
        torch.tensor([[1, 1, 1, 1, 1, 1, 1]]).transpose(0, 1),
        torch.tensor([[1, 1, 1]]).transpose(0, 1)
    ]

    src_pad_idx = model.src_pad_idx
    trg_pad_idx = model.trg_pad_idx

    # pad to a common length (post-padding) and drop the trailing unit dim
    # so src/trg are [batch_size, seq_len] token-id tensors
    src = pad_seq_v2(x_list, padding_value=src_pad_idx, prepadding=False).squeeze(2)
    trg = pad_seq_v2(x_list, padding_value=trg_pad_idx, prepadding=False).squeeze(2)
    out = model(src, trg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|