import torch
from typing import Any, Dict, List, Optional, Tuple

from transformers import Cache, GenerationConfig

UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",
    "cache_config",
    "return_legacy_cache",
    "num_beams",
    "compile_config",
    "assistant_model",
]


class SinkCache(Cache):
| """ |
| A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to |
| generate beyond the length of its context window, without losing fluency in the conversation. As it discards past |
| tokens, the model will lose the ability to generate tokens that depend on the context that was discarded. |
| |
| It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is |
| `[batch_size, num_heads, seq_len, head_dim]`. |
| |
| This class was copied from transformers 4.52.0, with minor modifications. |
| |
| Parameters: |
| window_length (`int`): |
| The length of the context window. |
| num_sink_tokens (`int`): |
| The number of sink tokens. See the original paper for more information. |
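
    Example (an illustrative sketch, not part of the original docstring; the checkpoint name is a placeholder and
    any decoder-only model with rotary position embeddings should behave the same way):

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer

        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
        >>> inputs = tokenizer("An attention sink is", return_tensors="pt")

        >>> # Keep 4 sink tokens plus a rolling window of 256 tokens
        >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
        >>> outputs = model.generate(**inputs, past_key_values=past_key_values, use_cache=True, max_new_tokens=16)
        >>> text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ```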
| """ |
|
|
| def __init__(self, window_length: int, num_sink_tokens: int) -> None: |
| super().__init__() |
| self.key_cache: List[torch.Tensor] = [] |
| self.value_cache: List[torch.Tensor] = [] |
| self.window_length = window_length |
| self.num_sink_tokens = num_sink_tokens |
| self.cos_sin_rerotation_cache = {} |
| self._cos_cache = None |
| self._sin_cache = None |
|
|
| @staticmethod |
| def _rotate_half(x): |
| x1 = x[..., : x.shape[-1] // 2] |
| x2 = x[..., x.shape[-1] // 2 :] |
| return torch.cat((-x2, x1), dim=-1) |
|
|
| def _apply_key_rotary_pos_emb( |
| self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor |
| ) -> torch.Tensor: |
| rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) |
| return rotated_key_states |
|

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin needed to re-rotate the kept keys from their original positions to their new,
            # shifted positions, via the angle-difference identities
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_rerotation_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_cache_shape(self) -> Optional[int]:
        """Returns the maximum sequence length of the cache object. For `SinkCache`, this is the window length."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
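
        Example (an illustrative sketch with dummy tensors, not part of the original docstring; without `sin`/`cos`
        in `cache_kwargs`, the update is a plain append, as for non-RoPE models):

            ```python
            >>> cache = SinkCache(window_length=8, num_sink_tokens=2)
            >>> key = torch.zeros(1, 4, 3, 16)  # [batch_size, num_heads, seq_len, head_dim]
            >>> value = torch.zeros(1, 4, 3, 16)
            >>> new_key, new_value = cache.update(key, value, layer_idx=0)
            >>> cache.get_seq_length()
            3
            ```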
| """ |
| |
| |
| if cache_kwargs is None: |
| cache_kwargs = {} |
| sin = cache_kwargs.get("sin") |
| cos = cache_kwargs.get("cos") |
| partial_rotation_size = cache_kwargs.get("partial_rotation_size") |
| using_rope = cos is not None and sin is not None |
|
|
| |
| if using_rope and layer_idx == 0: |
| |
| |
| if cos.dim() == 2: |
| self._cos_cache = cos |
| self._sin_cache = sin |
| else: |
| if self._cos_cache is None: |
| self._cos_cache = cos[0, ...] |
| self._sin_cache = sin[0, ...] |
| elif self._cos_cache.shape[0] < self.window_length: |
| self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0) |
| self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0) |
|
|
| |
| if len(self.key_cache) <= layer_idx: |
| |
| self.key_cache.append(key_states) |
| self.value_cache.append(value_states) |
|
|
| elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length: |
| |
| self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) |
| self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) |
|
|
| else: |
| |
| keys_to_keep = self.key_cache[layer_idx][ |
| :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] : |
| ] |
|
|
| |
| if using_rope: |
| rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( |
| key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length] |
| ) |
| if partial_rotation_size is not None: |
| keys_to_keep, keys_pass = ( |
| keys_to_keep[..., :partial_rotation_size], |
| keys_to_keep[..., partial_rotation_size:], |
| ) |
| keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) |
| if partial_rotation_size is not None: |
| keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) |
|
|
| |
| sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] |
| self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2) |
|
|
| sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] |
| values_to_keep = self.value_cache[layer_idx][ |
| :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] : |
| ] |
| self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2) |
|
|
| return self.key_cache[layer_idx], self.value_cache[layer_idx] |


def generate(model, window_length=256, num_sink_tokens=4, **kwargs):
    """Custom generate function for `SinkCache`.

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        window_length (`int`, *optional*, defaults to 256):
            The length of the context window.
        num_sink_tokens (`int`, *optional*, defaults to 4):
            The number of sink tokens. See the original paper for more information.
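
    Example (an illustrative sketch, not part of the original docstring; it assumes this function is available in the
    current scope as `generate`, and that `model` and `tokenizer` are an already-loaded decoder-only checkpoint and
    its tokenizer):

        ```python
        >>> inputs = tokenizer("An attention sink is", return_tensors="pt")
        >>> outputs = generate(model, window_length=256, num_sink_tokens=4, **inputs, max_new_tokens=32)
        >>> text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        ```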
| """ |
| |
| |
| generation_config = kwargs.get("generation_config") |
| default_global_generation_config = GenerationConfig() |
| default_model_generation_config = model.generation_config |
| for arg in UNSUPPORTED_GENERATION_ARGS: |
| has_custom_gen_config_arg = ( |
| generation_config is not None |
| |
| and not ( |
| getattr(default_model_generation_config, arg) == getattr(generation_config, arg) |
| or getattr(default_global_generation_config, arg) == getattr(generation_config, arg) |
| ) |
| ) |
| kwargs_has_arg = arg in kwargs and kwargs[arg] is not None |
| if kwargs_has_arg or has_custom_gen_config_arg: |
| raise ValueError( |
| f"`{arg}` is set, but it's not supported in this custom generate function. List of " |
| f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}" |
| ) |
|
|
| |
| if model.config.is_encoder_decoder: |
| raise ValueError("This custom generate function only works with decoder-only models") |
|
|
| |
| |
| kwargs.pop("custom_generate", None) |
|
|
| |
| |
| past_key_values = kwargs.pop("past_key_values", None) |
| if past_key_values is None: |
| past_key_values = SinkCache(window_length=window_length, num_sink_tokens=num_sink_tokens) |
| elif not isinstance(past_key_values, SinkCache): |
| raise ValueError(f"`past_key_values` must be a `SinkCache` instance, got a {type(past_key_values)} instance") |
|
|
| |
| generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True) |
| return generation_outputs |
|
|