| from typing import Tuple |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def resize(x_tensor, new_shape):
    """Linearly interpolate a 2-D tensor along its last dimension to length ``new_shape``.

    Args:
        x_tensor: Tensor of shape (channels, length).
        new_shape: Target length of the last dimension.

    Returns:
        Tensor of shape (channels, new_shape).
    """
    # F.interpolate with mode='linear' expects a 3-D (N, C, L) input,
    # so add and then remove a leading batch dimension.
    batched = x_tensor.unsqueeze(0)
    interpolated = F.interpolate(batched, size=new_shape, mode='linear')
    return interpolated.squeeze(0)
|
|
|
|
def resample(old: torch.Tensor, new_patch_len: int):
    """Resample patch-embedding kernels to a new patch length via a pseudo-inverse resize.

    Args:
        old: Kernel tensor of shape (d_model, patch_size).
        new_patch_len: Desired patch length.

    Returns:
        Tensor of shape (d_model, new_patch_len); ``old`` itself when the
        length already matches.
    """
    assert old.dim() == 2, "the size of input tensor should be (d_model, patch_size)"
    if old.size(1) == new_patch_len:
        return old

    kernels = old.T                      # (patch_size, d_model)
    patch_len = kernels.size(0)
    scale = new_patch_len / patch_len

    # Build the linear-interpolation matrix by resizing each basis vector of
    # the identity: rows index old positions, columns new positions.
    basis = torch.eye(patch_len, dtype=torch.get_default_dtype(), device=kernels.device)
    interp_mat = F.interpolate(basis.unsqueeze(0), size=new_patch_len,
                               mode='linear').squeeze(0)   # (patch_len, new_patch_len)

    # Pseudo-inverse maps old-length kernels to new-length ones; the sqrt(scale)
    # factor keeps the response magnitude comparable after resampling.
    pinv_mat = torch.linalg.pinv(interp_mat)               # (new_patch_len, patch_len)
    resampled = pinv_mat @ kernels * math.sqrt(scale)      # (new_patch_len, d_model)

    return resampled.T
|
|
|
|
def RoPE(query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Position Embedding (RoPE) to the query and key tensors.

    Args:
        query (torch.Tensor): Query tensor with shape (bs, head, max_len, output_dim).
        key (torch.Tensor): Key tensor with shape (bs, head, max_len, output_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after applying RoPE.
    """
    bs, heads, seq_len, dim = query.shape

    # Interleaved (sin, cos) table with one entry per position and feature pair.
    pos_emb = sinusoidal_position_embedding(bs, heads, seq_len, dim, query.device, factor=1)

    # Odd indices of the table hold cos terms, even hold sin terms; duplicate
    # each so the tables line up with the full feature dimension.
    cos_table = torch.repeat_interleave(pos_emb[..., 1::2], 2, dim=-1)
    sin_table = torch.repeat_interleave(pos_emb[..., ::2], 2, dim=-1)

    def _rotate_pairs(t: torch.Tensor) -> torch.Tensor:
        # Pairwise rotation: (t0, t1, t2, t3, ...) -> (-t1, t0, -t3, t2, ...)
        return torch.stack((-t[..., 1::2], t[..., ::2]), dim=-1).reshape(t.shape)

    rotated_query = query * cos_table + _rotate_pairs(query) * sin_table
    rotated_key = key * cos_table + _rotate_pairs(key) * sin_table

    return rotated_query, rotated_key
|
|
|
|
def RoPE_decoder(query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Position Embedding (RoPE) to the query and key tensors in the decoder.

    A single table spanning k_max_len + q_max_len positions is built; keys take
    the first k_max_len entries and queries the last q_max_len, so queries are
    rotated as if they follow the keys in time.

    Args:
        query (torch.Tensor): Query tensor with shape (bs, head, q_max_len, output_dim).
        key (torch.Tensor): Key tensor with shape (bs, head, k_max_len, output_dim).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Query and key tensors after applying RoPE.
    """
    bs, heads, q_len, dim = query.shape
    k_len = key.shape[2]

    pos_emb = sinusoidal_position_embedding(bs, heads, k_len + q_len, dim,
                                            query.device, factor=1)

    # Odd indices of the table are cos terms, even are sin terms; duplicate each
    # so the tables line up with the full feature dimension.
    cos_table = torch.repeat_interleave(pos_emb[..., 1::2], 2, dim=-1)
    sin_table = torch.repeat_interleave(pos_emb[..., ::2], 2, dim=-1)

    def _rotate_pairs(t: torch.Tensor) -> torch.Tensor:
        # Pairwise rotation: (t0, t1, t2, t3, ...) -> (-t1, t0, -t3, t2, ...)
        return torch.stack((-t[..., 1::2], t[..., ::2]), dim=-1).reshape(t.shape)

    query = (query * cos_table[:, :, -q_len:, :]
             + _rotate_pairs(query) * sin_table[:, :, -q_len:, :])
    key = (key * cos_table[:, :, :k_len, :]
           + _rotate_pairs(key) * sin_table[:, :, :k_len, :])

    return query, key
|
|
|
|
def sinusoidal_position_embedding(
        batch_size: int,
        num_heads: int,
        max_len: int,
        output_dim: int,
        device: torch.device,
        factor: float = 1.0
) -> torch.Tensor:
    """
    Generate sinusoidal position embeddings.

    Features alternate sin/cos per frequency: index 2i is sin, 2i+1 is cos.

    Args:
        batch_size (int): Batch size.
        num_heads (int): Number of attention heads.
        max_len (int): Maximum sequence length.
        output_dim (int): Output dimension.
        device (torch.device): Device type.
        factor (float, optional): Scaling factor. Defaults to 1.0. For factor > 1
            a denser position grid is generated and then subsampled back to
            max_len entries.

    Returns:
        torch.Tensor: Sinusoidal position embedding tensor with shape (bs, head, max_len, output_dim).
    """
    # With factor == 1 this is simply positions 0..max_len-1; larger factors
    # produce a finer fractional grid over a wider range (subsampled below).
    positions = torch.arange(0, max_len * factor, 1 / factor, dtype=torch.float).unsqueeze(-1)

    # Standard transformer frequencies: 10000^(-2i / d) for i in [0, d/2).
    freq_ids = torch.arange(0, output_dim // 2, dtype=torch.float)
    inv_freq = torch.pow(10000, -2 * freq_ids / output_dim)

    angles = positions * inv_freq                                       # (P, d/2)
    table = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1)  # (P, d/2, 2)

    # Broadcast over batch and heads, then flatten the (d/2, 2) pairs into an
    # interleaved sin/cos feature dimension of size output_dim.
    table = table.repeat((batch_size, num_heads, *([1] * table.dim())))
    table = torch.reshape(table, (batch_size, num_heads, -1, output_dim)).to(device)

    # Subsample the dense grid back down to exactly max_len positions.
    if factor > 1.0:
        keep = torch.linspace(0, table.shape[2] - 1, max_len).long()
        table = table[:, :, keep, :]

    return table
|
|
|
|
def causal_attention_mask(seq_length):
    """Build an additive causal mask of shape (1, 1, seq_length, seq_length).

    Entries strictly above the diagonal are -inf (blocking attention to future
    positions); entries on and below the diagonal are 0.
    """
    neg_inf = torch.full((seq_length, seq_length), float('-inf'))
    mask = torch.triu(neg_inf, diagonal=1)
    # Leading singleton dims broadcast over (batch, heads).
    return mask[None, None, :, :]
|
|
|
|
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` for use inside ``nn.Sequential``.

    Args:
        *dims: The two dimension indices to swap, passed through to
            ``Tensor.transpose``.
        contiguous: When True, force the result into contiguous memory.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        swapped = x.transpose(*self.dims)
        # Transposing only changes strides; materialize a copy when requested.
        return swapped.contiguous() if self.contiguous else swapped
|