| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Text encoder implementation in PyTorch.""" |
|
|
| import typing as t |
|
|
| import numpy as np |
| import sentencepiece as spm |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
|
|
class Tokenizer(object):
    """A simple tokenizer using SentencePiece.

    Wraps a pretrained SentencePiece model and produces fixed-length,
    zero-padded id matrices plus a matching padding indicator.
    """

    def __init__(self, tokenizer_path: str):
        """Loads the SentencePiece model stored at `tokenizer_path`."""
        self.sp = spm.SentencePieceProcessor(model_file=tokenizer_path)
        # Clear extra encode options so nothing is added/reversed implicitly.
        self.sp.SetEncodeExtraOptions("")
        self._add_bos = False
        self._add_eos = False

    def tokenize(self, input_texts, max_len: int = 64):
        """Encodes text(s) into a fixed-length, zero-padded id matrix.

        Args:
            input_texts: A string or a list of strings. Each text is
                lowercased before encoding.
            max_len: Output sequence length; ids are truncated or
                zero-padded to this length.

        Returns:
            A pair `(tokens, is_padding)`:
                tokens: int64 array of shape [batch, max_len].
                is_padding: int32 array of the same shape, 1 where `tokens`
                    is 0. NOTE(review): this treats id 0 as padding, which
                    would also flag a real token with id 0 — confirm id 0 is
                    reserved in the vocabulary.
        """
        if isinstance(input_texts, str):
            input_texts = [input_texts]
        # `text` (not `t`) avoids shadowing the module-level `typing as t`.
        batch_ids = [
            self.sp.encode(
                text.lower(), add_bos=self._add_bos, add_eos=self._add_eos
            )
            for text in input_texts
        ]
        tokens = np.zeros((len(batch_ids), max_len), dtype=np.int64)
        for i, ids in enumerate(batch_ids):
            length = min(len(ids), max_len)
            tokens[i, :length] = ids[:length]
        is_padding = (tokens == 0).astype(np.int32)
        return tokens, is_padding
|
|
|
|
class PositionalEmbedding(nn.Module):
    """Generates sinusoidal position embeddings for a given 1-d sequence.

    Attributes:
        min_timescale: Start of the geometric index. Determines the
            periodicity of the added signal.
        max_timescale: End of the geometric index. Determines the frequency
            of the added signal.
        embedding_dim: Dimension of the embedding to be generated.
    """

    def __init__(
        self,
        embedding_dim: int,
        min_timescale: int = 1,
        max_timescale: int = 10_000,
    ):
        """Initializes the embedder.

        Args:
            embedding_dim: Output feature dimension D.
            min_timescale: Smallest period of the sinusoids (default 1,
                matching the previous hard-coded class attribute).
            max_timescale: Largest period of the sinusoids (default 10_000,
                matching the previous hard-coded class attribute).
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale

    def forward(
        self,
        seq_length: t.Optional[int] = None,
        position: t.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Generates a torch.Tensor of sinusoids with different frequencies.

        Defined as `forward` (instead of overriding `__call__`) so that
        nn.Module's hook dispatch stays intact; call sites are unchanged.

        Args:
            seq_length: an optional Python int defining the output sequence
                length; required when `position` is not specified.
            position: [B, seq_length], optional position for each token in
                the sequence, only required when the sequence is packed.

        Returns:
            [B, seqlen, D] if `position` is specified, else [1, seqlen, D].

        Raises:
            ValueError: If neither `seq_length` nor `position` is given, or
                if `position` is not rank 2.
        """
        if position is None:
            if seq_length is None:
                raise ValueError(
                    'Either `seq_length` or `position` must be specified.'
                )
            position = torch.arange(seq_length, dtype=torch.float32)[None, :]
        elif position.ndim != 2:
            raise ValueError(f'`position` must be rank 2, got {position.shape}')

        num_timescales = self.embedding_dim // 2
        # Geometric progression of inverse timescales from min to max.
        log_timescale_increment = torch.log(
            torch.tensor(float(self.max_timescale) / float(self.min_timescale))
        ) / torch.maximum(
            torch.tensor(num_timescales, dtype=torch.float32) - 1, torch.tensor(1)
        )
        inv_timescales = self.min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32)
            * -log_timescale_increment
        )
        # [B, L, 1] * [1, 1, T] -> [B, L, T]
        scaled_time = position[:, :, None] * inv_timescales[None, None, :]
        # First half of the features are sines, second half cosines.
        signal = torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=2)
        # Pad one zero feature column when embedding_dim is odd.
        signal = F.pad(signal, (0, self.embedding_dim % 2, 0, 0, 0, 0))
        return signal
|
|
|
|
| class MlpBlockWithMask(nn.Module): |
| """Transformer MLP / feed-forward block that supports masking.""" |
|
|
| def __init__( |
| self, |
| mlp_dim: int, |
| d_model: int, |
| use_bias: bool = True, |
| dtype: torch.dtype = torch.float32, |
| activation_fn: nn.Module = nn.GELU, |
| ): |
| super().__init__() |
|
|
| self.mlp_dim = mlp_dim |
| self.d_model = d_model |
| self.use_bias = use_bias |
| self.dtype = dtype |
| self.activation_fn = activation_fn |
|
|
| self.c_fc = nn.Linear( |
| in_features=self.d_model, |
| out_features=self.mlp_dim, |
| dtype=self.dtype, |
| bias=self.use_bias, |
| ) |
| self.c_proj = nn.Linear( |
| in_features=self.mlp_dim, |
| out_features=self.d_model, |
| dtype=self.dtype, |
| bias=self.use_bias, |
| ) |
|
|
| def __call__( |
| self, inputs: torch.Tensor, mlp_mask: torch.Tensor |
| ) -> torch.Tensor: |
| """Applies Transformer MlpBlock with mask module.""" |
| x = self.c_fc(inputs) |
| x = self.activation_fn()(x) |
| x = x * mlp_mask[..., None] |
| x = self.c_proj(x) |
| x = x * mlp_mask[..., None] |
| return x |
|
|
|
|
class ResidualAttentionBlock(nn.Module):
    """Transformer residual attention block.

    Pre-LayerNorm self-attention followed by a masked MLP, each wrapped in a
    residual connection. Operates on sequence-first tensors ([seq, batch, D]),
    matching nn.MultiheadAttention's default batch_first=False layout.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_dim: int,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.mlp_dim = mlp_dim
        self.dtype = dtype

        self.attn = nn.MultiheadAttention(d_model, n_head, dtype=self.dtype)
        self.ln_1 = nn.LayerNorm(d_model, dtype=self.dtype)
        self.mlp = MlpBlockWithMask(
            self.mlp_dim,
            d_model,
            use_bias=True,
            dtype=self.dtype,
            activation_fn=nn.ReLU,
        )
        self.ln_2 = nn.LayerNorm(d_model, dtype=self.dtype)

    def attention(self, x: torch.Tensor, mask: torch.Tensor):
        """Self-attention with an additive padding mask.

        Args:
            x: [seq, batch, d_model] activations.
            mask: [batch, seq] float mask; 1 marks a valid token, 0 padding
                (as produced by `(paddings == 0)` in TextEncoder).

        Returns:
            [seq, batch, d_model] attention output.
        """
        # Expand [batch, seq] -> [batch * n_head, tgt_seq, src_seq], the 3-d
        # attn_mask layout nn.MultiheadAttention expects.
        attn_mask = (
            mask[:, None, None, :]
            .repeat(1, self.n_head, x.shape[0], 1)
            .flatten(0, 1)
        )
        # Convert the 0/1 mask to additive form: -inf blocks a key, 0 keeps
        # it. In-place writes; the two conditions are disjoint.
        # NOTE(review): a fully-padded sequence produces an all -inf row,
        # which yields NaNs in softmax — confirm callers never pass one.
        attn_mask[attn_mask == 0] = float('-inf')
        attn_mask[attn_mask == 1] = 0
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """Runs the block.

        Args:
            x: [seq, batch, d_model] activations.
            mask: [seq, batch] float mask; 1 = valid token, 0 = padding.

        Returns:
            Tuple (updated activations, unchanged mask) so that blocks can be
            chained through SequentialMultiInput.
        """
        # attention() wants the mask batch-first, hence the permute; the MLP
        # consumes it seq-first like x.
        x = x + self.attention(self.ln_1(x), mask.permute(1, 0))
        x = x + self.mlp(self.ln_2(x), mask)
        return x, mask
|
|
|
|
class SequentialMultiInput(nn.Sequential):
    """nn.Sequential variant whose layers may pass along multiple values.

    Whenever the running value is a tuple, its elements are splatted into
    the next layer's positional arguments; otherwise the single value is
    forwarded as-is. The initial varargs tuple is handled the same way.
    """

    def forward(self, *inputs):
        current = inputs
        for layer in self:
            if isinstance(current, tuple):
                current = layer(*current)
            else:
                current = layer(current)
        return current
|
|
|
|
class Transformer(nn.Module):
    """Stack of residual attention blocks sharing one width/head config."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_dim: int,
        dtype: torch.dtype = torch.float32,
    ):
        """Builds `layers` identical ResidualAttentionBlocks.

        Args:
            width: Model feature dimension.
            layers: Number of stacked blocks.
            heads: Attention heads per block.
            mlp_dim: Hidden width of each block's feed-forward layer.
            dtype: Parameter dtype.
        """
        super().__init__()
        self.width = width
        self.layers = layers
        self.heads = heads
        self.mlp_dim = mlp_dim
        self.dtype = dtype

        blocks = [
            ResidualAttentionBlock(self.width, self.heads, self.mlp_dim, self.dtype)
            for _ in range(self.layers)
        ]
        self.resblocks = SequentialMultiInput(*blocks)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """Applies every block and returns activations only (mask dropped)."""
        outputs, _ = self.resblocks(x, mask)
        return outputs
|
|
|
|
class GlobalAvgPooling(nn.Module):
    """Performs a simple global pooling over the input with optional paddings.

    Attributes:
        pooling_dims: A list of dims to perform pooling over.
        epsilon: Small constant added to the valid-position count to avoid
            division by zero when every pooled position is padded.
    """

    def __init__(
        self, pooling_dims: t.Sequence[int], epsilon: float = 1e-8
    ):
        super().__init__()
        self.pooling_dims = pooling_dims
        self.epsilon = epsilon

        if not all(p_dim >= 0 for p_dim in self.pooling_dims):
            raise ValueError('pooling_dims must be non-negative integers.')

    def forward(
        self,
        inputs: torch.Tensor,
        compatible_paddings: torch.Tensor,
    ) -> torch.Tensor:
        """Applies global average spatial pooling to inputs.

        Defined as `forward` (instead of overriding `__call__`) so that
        nn.Module's hook dispatch stays intact; call sites are unchanged.

        Args:
            inputs: An input tensor.
            compatible_paddings: paddings of inputs with shapes compatible
                with inputs, e.g. compatible_paddings with shape [B, 1] for
                inputs with shape [B, D]. Nonzero marks a padded position.

        Returns:
            `inputs` averaged over `pooling_dims` counting only non-padded
            positions, with the pooled dims squeezed out.
        """
        pooling_dims = list(self.pooling_dims)
        # Zero out padded positions so they don't contribute to the sum.
        # (Replaces the redundant zeros_like * ones_like construction.)
        padded_value = torch.zeros_like(inputs)
        inputs = torch.where(compatible_paddings > 0, padded_value, inputs)
        # Count valid positions; epsilon guards the all-padded case.
        # Canonical torch kwargs (dim/keepdim) instead of numpy-style
        # axis/keepdims aliases.
        valid_inputs = (
            torch.sum(
                1.0 - compatible_paddings,
                dim=pooling_dims,
                keepdim=True,
                dtype=inputs.dtype,
            )
            + self.epsilon
        )
        inputs_sum = torch.sum(inputs, dim=pooling_dims, keepdim=True)
        outputs = torch.divide(inputs_sum, valid_inputs).type(inputs.dtype)
        # Squeeze pooled dims highest-first so remaining indices stay valid;
        # avoids torch.squeeze(dim=<list>), which requires torch >= 2.0.
        for dim in sorted(pooling_dims, reverse=True):
            outputs = torch.squeeze(outputs, dim)
        return outputs
|
|
|
|
class TextEncoder(nn.Module):
    """Text encoder implementation.

    Embeds token ids, adds sinusoidal position embeddings, runs a
    Transformer over the sequence, and mean-pools the final-layer-normed
    features over non-padded positions.
    """

    def __init__(
        self,
        config: t.Dict[str, int],
        vocab_size: int,
        dtype: torch.dtype = torch.float32,
        scale_sqrt_depth: bool = True,
    ):
        """Builds the encoder.

        Args:
            config: Dict with keys 'num_layers', 'hidden_size', 'mlp_dim'
                and 'num_heads'.
            vocab_size: Size of the token-embedding vocabulary.
            dtype: Parameter dtype.
            scale_sqrt_depth: If True, scales token embeddings by
                sqrt(hidden_size) before adding position embeddings.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.dtype = dtype
        self.scale_sqrt_depth = scale_sqrt_depth

        self.transformer_layers = config['num_layers']
        self.embedding_dim = config['hidden_size']
        self.transformer_width = config['hidden_size']
        self.mlp_dim = config['mlp_dim']
        self.transformer_heads = config['num_heads']

        self.token_embedding = nn.Embedding(
            self.vocab_size, self.embedding_dim, dtype=self.dtype
        )
        self.pos_embedder = PositionalEmbedding(embedding_dim=self.embedding_dim)
        self.transformer = Transformer(
            width=self.transformer_width,
            layers=self.transformer_layers,
            heads=self.transformer_heads,
            mlp_dim=self.mlp_dim,
            dtype=self.dtype,
        )
        # Pool over the sequence dim (dim 1 of the [B, L, D] activations).
        self.pooling = GlobalAvgPooling(pooling_dims=[1])
        self.ln_final = nn.LayerNorm(self.transformer_width, dtype=self.dtype)

    def forward(
        self,
        ids: torch.Tensor,
        paddings: torch.Tensor,
    ) -> torch.Tensor:
        """Applies TextEncoder module.

        Defined as `forward` (instead of overriding `__call__`) so that
        nn.Module's hook dispatch stays intact; call sites are unchanged.

        Args:
            ids: [B, L] integer token ids.
            paddings: [B, L] tensor, nonzero where a position is padding.

        Returns:
            [B, hidden_size] pooled text embeddings.
        """
        _, seq_length = ids.shape
        # 1 = valid token, 0 = padding; transposed to [L, B] for the
        # sequence-first Transformer.
        mask = (paddings == 0).type(torch.float32)
        mask = mask.permute(1, 0)
        x = self.token_embedding(ids)
        if self.scale_sqrt_depth:
            # Keeps embedding magnitude comparable to the position signal.
            x = x * (self.embedding_dim**0.5)
        x = x + self.pos_embedder(seq_length=seq_length).to(x.device)
        x = x.permute(1, 0, 2)  # [B, L, D] -> [L, B, D]
        x = self.transformer(x, mask)
        x = x.permute(1, 0, 2)  # [L, B, D] -> [B, L, D]
        x = self.ln_final(x)
        x = self.pooling(x, compatible_paddings=paddings[:, :, None])
        return x
|
|