| import torch
|
|
|
|
|
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, fully attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask, (size, size), dtype torch.bool

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    # Build a lower-triangular boolean matrix via broadcasting:
    # entry (i, j) is True iff j <= i, i.e. position i may attend to j.
    # NOTE(review): arange + compare is used instead of torch.tril —
    # presumably for export (e.g. ONNX) compatibility; confirm before changing.
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask
|
|
|
|
|
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
    this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask, (size, size), dtype torch.bool

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # Vectorized: compute every row at once instead of a Python loop with
    # per-row slice assignment (which is O(size^2) interpreter work).
    pos = torch.arange(size, device=device)
    chunk_idx = pos // chunk_size  # chunk index of each query position
    # Each position may attend up to (exclusive) the end of its own chunk.
    ending = (chunk_idx + 1) * chunk_size
    ret = pos.unsqueeze(0) < ending.unsqueeze(1)
    if num_left_chunks >= 0:
        # Limit left context to num_left_chunks whole chunks.
        start = torch.clamp((chunk_idx - num_left_chunks) * chunk_size, min=0)
        ret = ret & (pos.unsqueeze(0) >= start.unsqueeze(1))
    return ret
|
|
|
|
|
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            # Decoding with full context: a single chunk covers everything.
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            # Decoding with a fixed, caller-chosen chunk size.
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # Training: sample a random chunk size so the model generalizes
            # across chunk sizes at inference time.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                # Roughly half the time, fall back to full context.
                chunk_size = max_len
            else:
                # Fold the sampled value into [1, 25].
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
    elif static_chunk_size > 0:
        chunk_size = static_chunk_size
        num_left_chunks = num_decoding_left_chunks
    else:
        # No chunking requested: keep the original padding mask unchanged.
        return masks
    # Build the (L, L) chunk attention mask once and intersect it with the
    # padding mask (was previously duplicated in two branches).
    chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                        num_left_chunks,
                                        xs.device)  # (L, L)
    chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
    chunk_masks = masks & chunk_masks  # (B, L, L)
    return chunk_masks
|
|
|
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): pad the mask to this length; if <= 0, use the
            maximum value found in ``lengths``.

    Returns:
        torch.Tensor: Mask tensor containing indices of padded part,
            (B, max_len), dtype torch.bool; True marks a padded position.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    # A position index >= the sequence's actual length means padding.
    mask = seq_range_expand >= seq_length_expand
    return mask
|
|
|