Text Ranking
sentence-transformers
Safetensors
cross-encoder
reranker
Generated from Trainer
dataset_size:3190
loss:ListNetLoss
custom_code
Instructions to use Pranjal2002/jina_finance_v2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use Pranjal2002/jina_finance_v2 with sentence-transformers:
```python
from sentence_transformers import CrossEncoder

model = CrossEncoder("Pranjal2002/jina_finance_v2", trust_remote_code=True)

query = "Which planet is known as the Red Planet?"
passages = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

scores = model.predict([(query, passage) for passage in passages])
print(scores)
```
- Notebooks
- Google Colab
- Kaggle
| # Copyright (c) 2023, Tri Dao. | |
| # Adapted from https://github.com/Dao-AILab/flash-attention/pull/556 | |
| import math | |
| from functools import partial | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange, repeat | |
| try: | |
| from flash_attn import ( | |
| flash_attn_kvpacked_func, | |
| flash_attn_qkvpacked_func, | |
| flash_attn_varlen_kvpacked_func, | |
| flash_attn_varlen_qkvpacked_func, | |
| flash_attn_with_kvcache, | |
| ) | |
| except ImportError: | |
| flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None | |
| flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None | |
| flash_attn_with_kvcache = None | |
| try: | |
| from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear | |
| except ImportError: | |
| FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None | |
class FlashSelfAttention(nn.Module):
    """Self-attention over a packed QKV tensor, delegated to the FlashAttention kernels.

    Arguments
    ---------
        causal: default causal flag; `forward` can override it per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        window_size: forwarded unchanged to the kernels as `window_size`.
        deterministic: forwarded unchanged to the kernels as `deterministic`.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        # Both kernel entry points are set to None when the flash_attn import failed.
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        # Only the probability `p` of this module is used; the kernel applies the dropout.
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value.
                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.

        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        if causal is None:
            causal = self.causal

        variable_length = cu_seqlens is not None
        if variable_length:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)

        # Dropout is active only while the module is in training mode.
        dropout_p = self.drop.p if self.training else 0.0
        kernel_options = dict(
            softmax_scale=self.softmax_scale,
            causal=causal,
            alibi_slopes=None,
            window_size=self.window_size,
            deterministic=self.deterministic,
        )
        if variable_length:
            return flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, max_seqlen, dropout_p, **kernel_options
            )
        return flash_attn_qkvpacked_func(qkv, dropout_p, **kernel_options)
class FlashCrossAttention(nn.Module):
    """Cross-attention (separate q and packed kv) delegated to the FlashAttention kernels.

    Arguments
    ---------
        causal: default causal flag; `forward` can override it per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
        window_size: forwarded unchanged to the kernels as `window_size`.
        deterministic: forwarded unchanged to the kernels as `deterministic`.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        # Both kernel entry points are set to None when the flash_attn import failed.
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        # Only the probability `p` of this module is used; the kernel applies the dropout.
        self.drop = nn.Dropout(attention_dropout)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into q.
            max_seqlen: int. Maximum sequence length in the batch of q.
            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into kv.
            max_seqlen_k: int. Maximum sequence length in the batch of k and v.

        When cu_seqlens is given, cu_seqlens_k and both maximum lengths are required and the
        variable-length kernel is used; otherwise the padded (batched) kernel is used.

        Raises:
        -------
            AssertionError: if q is not fp16/bf16, if q or kv is not on CUDA, or if the
                variable-length arguments are missing or of the wrong type.
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            # Validate the key-side maximum length (the query-side one is checked above).
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            # Padded path: q and kv must agree on batch size and head dimension.
            batch_size = q.shape[0]
            assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
            return flash_attn_kvpacked_func(
                q,
                kv,
                self.drop.p if self.training else 0.0,
                causal=causal,
                softmax_scale=self.softmax_scale,
                alibi_slopes=None,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
class SelfAttention(nn.Module):
    """Plain PyTorch scaled dot-product self-attention over a packed QKV tensor.

    Arguments
    ---------
        causal: default causal flag; `forward` can override it per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, qkv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, S)

        Returns:
        --------
            out: (B, S, H, D)
        """
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        if causal is None:
            causal = self.causal
        query, key, value = qkv.unbind(dim=2)
        # A falsy softmax_scale (None) falls back to 1/sqrt(head_dim).
        scale = self.softmax_scale or 1.0 / math.sqrt(query.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", query, key * scale)

        if key_padding_mask is not None:
            # Additive mask: 0 where the key is kept, -10000 where it is padding.
            additive_mask = torch.full(
                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
            )
            additive_mask.masked_fill_(key_padding_mask, 0.0)
            # Broadcast (B, S) -> (B, 1, 1, S) over heads and query positions.
            scores = scores + additive_mask[:, None, None, :]

        if causal:
            # The triangular mask is built in the default float dtype and cast afterwards,
            # because triu is not implemented for bfloat16 on CUDA.
            future_mask = torch.triu(
                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
            )
            scores = scores + future_mask.to(dtype=scores.dtype)

        probs = torch.softmax(scores, dim=-1, dtype=value.dtype)
        return torch.einsum("bhts,bshd->bthd", self.drop(probs), value)
class CrossAttention(nn.Module):
    """Plain PyTorch scaled dot-product cross-attention (separate q, packed kv).

    Arguments
    ---------
        causal: default causal flag; `forward` can override it per call.
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q: The tensor containing the query. (B, Sq, H, D)
            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
            causal: if passed, will override self.causal
            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                False means to mask out. (B, Sk)

        Returns:
        --------
            out: (B, Sq, H, D)
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        if causal is None:
            causal = self.causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]

        query_heads, kv_heads = q.shape[2], kv.shape[3]
        if kv_heads != query_heads:  # MQA/GQA
            # Duplicate every kv head so that each query head has a matching kv head.
            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=query_heads // kv_heads)
        key, value = kv.unbind(dim=2)

        # A falsy softmax_scale (None) falls back to 1/sqrt(head_dim).
        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, key * scale)

        if key_padding_mask is not None:
            # Additive mask: 0 where the key is kept, -10000 where it is padding.
            additive_mask = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            additive_mask.masked_fill_(key_padding_mask, 0.0)
            # Broadcast (B, Sk) -> (B, 1, 1, Sk) over heads and query positions.
            scores = scores + additive_mask[:, None, None, :]

        if causal:
            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
            rows = torch.arange(seqlen_q, device=q.device, dtype=torch.long)[:, None]
            cols = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            if key_padding_mask is None:
                valid_keys = seqlen_k
            else:
                # Per-sample count of kept keys, shaped (B, 1, 1, 1) for broadcasting.
                valid_keys = key_padding_mask.sum(-1)[:, None, None, None]
            blocked = cols > rows + valid_keys - seqlen_q
            scores = scores.masked_fill(blocked, -10000.0)

        probs = torch.softmax(scores, dim=-1, dtype=value.dtype)
        return torch.einsum("bhts,bshd->bthd", self.drop(probs), value)
class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

    def forward(self, input: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Apply the linear layer and hand back the untouched input alongside the result.

        Returns:
            A pair `(projected, input)`, where `projected` is the output of
            `nn.Linear.forward` and `input` is the very same tensor that was passed in.
        """
        return super().forward(input), input
def _update_kv_cache(kv, inference_params, layer_idx):
    """Write `kv` into the per-layer cache and return the cache contents up to the new end.

    kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)

    The cache for `layer_idx` lives in `inference_params.key_value_memory_dict`; it is
    allocated (uninitialised) on first use with shape
    (max_batch_size, max_seqlen, 2, nheads, head_dim). The write position is given by
    `inference_params.batch_size_offset` and `inference_params.seqlen_offset`.

    Returns the cache slice covering this batch window and sequence positions
    [0, seqlen_offset + seqlen).
    """
    num_heads, head_dim = kv.shape[-2:]
    cache_store = inference_params.key_value_memory_dict
    if layer_idx in cache_store:
        kv_cache = cache_store[layer_idx]
    else:
        # Pre-allocate memory for key-values for inference.
        kv_cache = cache_store[layer_idx] = torch.empty(
            inference_params.max_batch_size,
            inference_params.max_seqlen,
            2,
            num_heads,
            head_dim,
            dtype=kv.dtype,
            device=kv.device,
        )

    # Window of the cache that this call writes to.
    batch_lo = inference_params.batch_size_offset
    batch_hi = batch_lo + kv.shape[0]
    seq_lo = inference_params.seqlen_offset
    seq_hi = seq_lo + kv.shape[1]
    assert batch_hi <= kv_cache.shape[0]
    assert seq_hi <= kv_cache.shape[1]
    assert kv_cache is not None

    kv_cache[batch_lo:batch_hi, seq_lo:seq_hi, ...] = kv
    return kv_cache[batch_lo:batch_hi, :seq_hi, ...]
class MHA(nn.Module):
    """Multi-head self-attention and cross-attention"""

    def __init__(
        self,
        embed_dim,
        num_heads,
        num_heads_kv=None,
        cross_attn=False,
        qkv_proj_bias=True,
        out_proj_bias=True,
        dropout=0.0,
        softmax_scale=None,
        causal=False,
        layer_idx=None,
        dwconv=False,
        window_size=(-1, -1),
        fused_bias_fc=False,
        use_flash_attn=False,
        return_residual=False,
        checkpointing=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
        layer_idx: key under which this layer's KV cache is stored during generation.
        window_size: anything other than (-1, -1) requires use_flash_attn.
        fused_bias_fc: use FusedDense instead of nn.Linear for the projections.
        return_residual: whether to return the input x along with the output. This is for
            performance reason: for post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        checkpointing: run the inner attention through torch.utils.checkpoint when no
            inference_params are given.

        Raises ImportError if fused_bias_fc is requested but fused_dense is not installed.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.cross_attn = cross_attn
        self.causal = causal
        self.layer_idx = layer_idx
        self.dwconv = dwconv
        self.use_flash_attn = use_flash_attn
        self.return_residual = return_residual
        self.checkpointing = checkpointing
        # This module builds no rotary embedding, but forward() consults rotary_emb_dim when
        # choosing the decoding path. Define it so that lookup cannot raise AttributeError.
        self.rotary_emb_dim = 0
        if window_size != (-1, -1):
            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"

        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
        assert (
            self.num_heads % self.num_heads_kv == 0
        ), "num_heads must be divisible by num_heads_kv"
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        # Width of the fused projection: one q per head plus one k and one v per kv head.
        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
        kv_dim = 2 * self.head_dim * self.num_heads_kv

        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc else FusedDense
        linear_resid_cls = (
            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
        )
        wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
        inner_attn_cls = (
            partial(FlashSelfAttention, window_size=window_size)
            if use_flash_attn
            else SelfAttention
        )
        inner_cross_attn_cls = (
            partial(FlashCrossAttention, window_size=window_size)
            if use_flash_attn
            else CrossAttention
        )
        if not self.cross_attn:
            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
        else:
            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
            self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
        if self.dwconv:
            if self.num_heads_kv == self.num_heads:
                self.dwconv_qkv = nn.Conv1d(
                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
                )
            else:
                self.dwconv_q = nn.Conv1d(
                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                )
                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
        self.inner_attn = inner_attn_cls(
            causal=causal,
            softmax_scale=softmax_scale,
            attention_dropout=dropout,
        )
        self.inner_cross_attn = inner_cross_attn_cls(
            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
        )
        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
        """Return an uninitialised KV cache of shape
        (batch_size, max_seqlen, 2, num_heads_kv, head_dim) on the device of out_proj.weight.

        dtype defaults to the dtype of out_proj.weight.
        """
        dtype = self.out_proj.weight.dtype if dtype is None else dtype
        device = self.out_proj.weight.device
        return torch.empty(
            batch_size,
            max_seqlen,
            2,
            self.num_heads_kv,
            self.head_dim,
            dtype=dtype,
            device=device,
        )

    def _update_kv_cache(self, kv, inference_params):
        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
        assert not self.dwconv, "Generation does not support dwconv yet"
        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
        return _update_kv_cache(kv, inference_params, self.layer_idx)

    def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
        """
        Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
        q: (batch_size, seqlen_q, nheads, head_dim)
        kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)

        Note: no rotary cos/sin tables are passed to the kernel here, so only the cache
        update and the attention are actually performed.
        """
        assert inference_params is not None and inference_params.seqlen_offset > 0
        assert self.use_flash_attn
        batch = q.shape[0]
        kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
        cache_seqlens = (
            inference_params.lengths_per_sample[:batch]
            if inference_params.lengths_per_sample is not None
            else inference_params.seqlen_offset
        )
        context = flash_attn_with_kvcache(
            q,
            kv_cache[:, :, 0],
            kv_cache[:, :, 1],
            kv[:, :, 0],
            kv[:, :, 1],
            cache_seqlens=cache_seqlens,
            softmax_scale=self.inner_cross_attn.softmax_scale,
            causal=self.inner_cross_attn.causal,
            rotary_interleaved=False,
            alibi_slopes=None,
        )
        return context

    def _update_kvcache_attention(self, q, kv, inference_params):
        """Write kv to inference_params, then do attention"""
        if (
            inference_params.seqlen_offset == 0
            or flash_attn_with_kvcache is None
            or not self.use_flash_attn
        ):
            # TODO: this only uses seqlen_offset and not lengths_per_sample.
            kv = self._update_kv_cache(kv, inference_params)
            return self.inner_cross_attn(q, kv)
        else:
            # The kernel appends kv to the cache and attends in a single call.
            batch = q.shape[0]
            kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
            cache_seqlens = (
                inference_params.lengths_per_sample[:batch]
                if inference_params.lengths_per_sample is not None
                else inference_params.seqlen_offset
            )
            return flash_attn_with_kvcache(
                q,
                kv_cache[:, :, 0],
                kv_cache[:, :, 1],
                kv[:, :, 0],
                kv[:, :, 1],
                cache_seqlens=cache_seqlens,
                softmax_scale=self.inner_cross_attn.softmax_scale,
                causal=self.inner_cross_attn.causal,
                alibi_slopes=None,
            )

    def _use_standard_attention_path(self, inference_params):
        """Return True unless the fused rotary + kv-cache decoding kernel should be used."""
        return (
            inference_params is None
            or inference_params.seqlen_offset == 0
            or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
            or not self.use_flash_attn
        )

    def _run_inner_attention(self, inner, *tensors, **kwargs):
        """Call `inner(*tensors, **kwargs)`, through activation checkpointing if enabled."""
        if not self.checkpointing:
            return inner(*tensors, **kwargs)
        from torch.utils.checkpoint import checkpoint

        # Bind the keyword arguments to the callable instead of handing them to checkpoint(),
        # which rejects extra keyword arguments in its re-entrant mode.
        return checkpoint(partial(inner, **kwargs), *tensors)

    def forward(
        self,
        x,
        x_kv=None,
        key_padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
        mixer_subset=None,
        inference_params=None,
        **kwargs,
    ):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
                is the sum of the sequence lengths in the batch.
            x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into x. Only applicable when using
                FlashAttention.
            max_seqlen: int. Maximum sequence length in the batch.
            key_padding_mask: boolean mask, True means to keep, False means to mask out.
                (batch, seqlen). Only applicable when not using FlashAttention.
            mixer_subset: for cross-attention only. If not None, will take a subset of x
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
            inference_params: for generation. Adapted from Megatron-LM (and Apex)
            https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
            **kwargs: forwarded to the inner attention module.

        Returns:
            The output projection of the attention context, or the pair (out, x) when
            return_residual is set.
        """
        if cu_seqlens is not None:
            assert max_seqlen is not None
            assert key_padding_mask is None
            assert self.use_flash_attn
            assert not self.dwconv
        if key_padding_mask is not None:
            assert cu_seqlens is None
            assert max_seqlen is None
            assert not self.use_flash_attn
        if inference_params is not None:
            assert key_padding_mask is None
            assert cu_seqlens is None and max_seqlen is None
            assert not self.dwconv

        # The flash modules take cu_seqlens/max_seqlen; the PyTorch ones take key_padding_mask.
        kwargs = (
            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
            if self.use_flash_attn
            else {"key_padding_mask": key_padding_mask, **kwargs}
        )
        if not self.cross_attn and self.num_heads_kv == self.num_heads:
            # Plain multi-head self-attention: one fused projection, split into q, k, v.
            assert x_kv is None and mixer_subset is None
            if not self.return_residual:
                qkv = self.Wqkv(x)
            else:
                qkv, x = self.Wqkv(x)
            if self.dwconv:
                # padding=2 with kernel_size=3 adds two trailing positions; drop them.
                qkv = rearrange(
                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
            if self._use_standard_attention_path(inference_params):
                if inference_params is None:
                    context = self._run_inner_attention(self.inner_attn, qkv, **kwargs)
                else:
                    context = self._update_kvcache_attention(
                        qkv[:, :, 0], qkv[:, :, 1:], inference_params
                    )
            else:
                context = self._apply_rotary_update_kvcache_attention(
                    qkv[:, :, 0], qkv[:, :, 1:], inference_params
                )
        else:
            if self.cross_attn:
                if not self.return_residual:
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                    kv = self.Wkv(x_kv if x_kv is not None else x)
                else:
                    if x_kv is not None:
                        kv, x_kv = self.Wkv(x_kv)
                    else:
                        kv, x = self.Wkv(x)
                    q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
            else:
                # MQA/GQA self-attention: the fused projection holds q first, then k and v.
                assert self.num_heads_kv != self.num_heads
                if not self.return_residual:
                    qkv = self.Wqkv(x)
                else:
                    qkv, x = self.Wqkv(x)
                q = qkv[..., : self.num_heads * self.head_dim]
                kv = qkv[..., self.num_heads * self.head_dim :]
            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(
                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
                kv = rearrange(
                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
                ).contiguous()
            if self._use_standard_attention_path(inference_params):
                if inference_params is None:
                    context = self._run_inner_attention(self.inner_cross_attn, q, kv, **kwargs)
                else:
                    context = self._update_kvcache_attention(q, kv, inference_params)
            else:
                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
        # Merge the heads back into the hidden dimension before the output projection.
        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
        return out if not self.return_residual else (out, x)