| import copy |
| import importlib.metadata |
| import json |
| import os |
| import warnings |
| from dataclasses import dataclass |
| from typing import Any, Dict, Iterable, List, Optional, Tuple, Union |
|
|
| import torch |
| from packaging import version |
|
|
| from transformers.utils import is_hqq_available, is_optimum_quanto_available, logging |
|
|
| from transformers.cache_utils import CacheConfig, QuantizedCacheConfig, QuantizedCache |
|
|
| if is_hqq_available(): |
| from hqq.core.quantize import Quantizer as HQQQuantizer |
|
|
| logger = logging.get_logger(__name__) |
|
|
@dataclass
class SQuatCacheConfig(QuantizedCacheConfig):
    """
    Configuration class for SQuat cache settings.

    Carries the SQuat-specific knobs on top of the generic quantized-cache
    options (nbits, backend, residual_length, ...) handled by the base class,
    and pins `cache_implementation` to `"squat"`.
    """

    def __init__(
        self,
        quant_group_size: Optional[int] = 64,
        squat_lambda: Optional[float] = 0.0001,
        subspace_dim: Optional[int] = 5,
        shared_svd: Optional[bool] = True,
        **kwargs,
    ):
        # Generic quantized-cache options go straight to the base class.
        super().__init__(**kwargs)
        self.cache_implementation = "squat"
        # SQuat-specific settings.
        for attr, value in (
            ("quant_group_size", quant_group_size),
            ("squat_lambda", squat_lambda),
            ("subspace_dim", subspace_dim),
            ("shared_svd", shared_svd),
        ):
            setattr(self, attr, value)
|
|
|
|
class SQuatCache(QuantizedCache):
    """
    Quantized Cache class that uses `SQuat` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.

    SQuat quantizes keys group-by-group along the head dimension and feeds the
    quantization error of each group back into the not-yet-quantized groups,
    steered by a low-rank subspace of the observed queries (see
    `_get_query_subspace` and `_generate_At_inv`). Values are quantized with the
    plain backend quantizer. The most recent tokens (up to `residual_length`)
    are kept in full precision and folded into the quantized cache once enough
    of them have accumulated.

    Parameters:
        cache_config (`SQuatCacheConfig`):
            A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.

    Example:

        ```python
        >>> # Run pip install quanto first if you don't have it yet
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SQuatCache, SQuatCacheConfig

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> cache_config = SQuatCacheConfig(nbits=4)
        >>> past_key_values = SQuatCache(cache_config=cache_config)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        SQuatCache()
        ```
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        super().__init__(cache_config)

        # Validate and set up the optimum-quanto backend primitives used by the
        # quanto-based subclass's `_quantize`.
        if is_optimum_quanto_available():
            optimum_quanto_version = version.parse(importlib.metadata.version("optimum-quanto"))
            if optimum_quanto_version <= version.parse("0.2.5"):
                raise ImportError(
                    f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
                )
            from optimum.quanto import MaxOptimizer, qint2, qint4

            if self.nbits not in [2, 4]:
                raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")

            if self.axis_key not in [0, -1]:
                raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")

            if self.axis_value not in [0, -1]:
                raise ValueError(
                    f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
                )

            self.qtype = qint4 if self.nbits == 4 else qint2
            self.optimizer = MaxOptimizer()

        # Per-layer error-feedback matrices produced by `_get_query_subspace`.
        self.auxiliary_matrices_A = []
        self.auxiliary_matrices_P = []
        # NOTE(review): the getattr fallbacks below (0.0005 / 20) differ from the
        # `SQuatCacheConfig` defaults (0.0001 / 5) — confirm which is intended.
        self.squat_lambda = getattr(cache_config, "squat_lambda", 0.0005)
        self.squat_q_group_size = getattr(cache_config, "quant_group_size", 64)
        self.squat_subspace_dim = getattr(cache_config, "subspace_dim", 20)
        self.squat_shared_svd = getattr(cache_config, "shared_svd", True)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Append new key/value states for `layer_idx` and return the full
        (dequantized prefix + full-precision residual + new tokens) keys and
        values to attend over.

        On the first (prefill) call for a layer, `cache_kwargs` must contain
        `"query_states"` and `"attention_mask"`; they are used to compute the
        per-layer query subspace.
        """
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) < layer_idx:
            raise ValueError("SQuatCache does not support model usage where layers are skipped. Use DynamicCache.")
        elif len(self.key_cache) == layer_idx:
            # First call (prefill) for this layer: derive the error-feedback
            # matrices once, then cache them per layer.
            if len(self.auxiliary_matrices_A) == layer_idx:
                Ainv_t, P_inv = self._get_query_subspace(key_states, cache_kwargs["query_states"], cache_kwargs["attention_mask"])
                self.auxiliary_matrices_A.append(Ainv_t)
                self.auxiliary_matrices_P.append(P_inv)
            # Fix: always read the matrices from the per-layer lists. The original
            # relied on the locals bound inside the branch above, which were
            # undefined (NameError) whenever the subspace for this layer had
            # already been computed on an earlier call.
            Ainv_t = self.auxiliary_matrices_A[layer_idx]
            P_inv = self.auxiliary_matrices_P[layer_idx]

            # Split the prompt into a quantized prefix (largest multiple of
            # `residual_length`) and a full-precision residual tail.
            if key_states.shape[-2] % self.residual_length != 0:
                if key_states.shape[-2] < self.residual_length:
                    key_states_quant = None
                    key_states_full = key_states
                    value_states_quant = None
                    value_states_full = value_states
                else:
                    key_states_quant = key_states[:, :, :-(key_states.shape[-2] % self.residual_length), :].contiguous()
                    key_states_full = key_states[:, :, -(key_states.shape[-2] % self.residual_length):, :].contiguous()
                    value_states_quant = value_states[:, :, :-(value_states.shape[-2] % self.residual_length), :].contiguous()
                    value_states_full = value_states[:, :, -(value_states.shape[-2] % self.residual_length):, :].contiguous()
            else:
                key_states_quant = key_states
                key_states_full = None
                value_states_quant = value_states
                value_states_full = None
            if key_states_quant is not None:
                self._quantized_key_cache.append(self.squat_quantize_key(key_states_quant, self.squat_q_group_size, Ainv_t, P_inv))
                self._quantized_value_cache.append(self._quantize(value_states_quant, axis=self.axis_value))
            else:
                # Empty placeholders keep the per-layer list indices aligned.
                self._quantized_key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
                self._quantized_value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
            if key_states_full is not None:
                self.key_cache.append(key_states_full)
                self.value_cache.append(value_states_full)
            else:
                self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
                self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))

            keys_to_return, values_to_return = key_states, value_states

        else:
            # Decode step: reassemble [dequantized prefix | residual | new tokens].
            if len(self._quantized_key_cache[layer_idx]) == 0:
                dequant_key = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
            else:
                dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
            if len(self._quantized_value_cache[layer_idx]) == 0:
                dequant_value = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
            else:
                dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
            keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
            values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]

            keys_to_return = torch.cat(keys_to_return, dim=-2)
            values_to_return = torch.cat(values_to_return, dim=-2)
            # Once the residual buffer would reach `residual_length`, fold it
            # into the quantized cache and reset the full-precision buffers.
            if (
                self.key_cache[layer_idx].dim() == 4
                and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
            ):
                keys_to_quantize = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                quantized_key = self.squat_quantize_key(
                    keys_to_quantize, self.squat_q_group_size, self.auxiliary_matrices_A[layer_idx],
                    self.auxiliary_matrices_P[layer_idx]
                )
                self._quantized_key_cache[layer_idx] = self._quantize(
                    torch.cat([dequant_key, self._dequantize(quantized_key)], dim=2), axis=self.axis_key
                )
                self._quantized_value_cache[layer_idx] = self._quantize(
                    values_to_return.contiguous(), axis=self.axis_value
                )
                self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
                self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return keys_to_return, values_to_return

    def _get_query_subspace(self, key_states, query_states, attention_mask=None):
        """
        Compute the per-layer error-feedback matrices from a low-rank SVD of
        the (valid) query states.

        Returns:
            Ainv_t: list of matrices produced by `_generate_At_inv`.
            P_inv: inverse of the last matrix of that list.
        """
        bsz = query_states.shape[0]
        kv_nh = key_states.shape[1]
        head_dim = query_states.shape[3]
        num_key_value_groups = query_states.shape[1] // key_states.shape[1]
        # The subspace rank cannot exceed the number of query rows per kv head.
        subspace_dim = min(self.squat_subspace_dim, num_key_value_groups*key_states.shape[2])

        if attention_mask is not None:
            # NOTE(review): assumes a 4D additive mask where 0 means "attend" —
            # confirm against the calling model's mask convention.
            if attention_mask.shape[2] == attention_mask.shape[3]-1:
                attention_mask = attention_mask[:,:,:,:attention_mask.shape[2]]

            last_row_mask = attention_mask[:, :, -1, :]

            valid_tokens = (last_row_mask == 0).squeeze(1)

            query_subspace = []
            for b in range(bsz):
                batch_valid = valid_tokens[b]

                batch_query = query_states[b]
                batch_valid_query = batch_query[:, batch_valid, :]

                # Collapse query heads onto their kv head, then take a
                # rank-`subspace_dim` truncated SVD (in float32 for stability).
                valid_query_states_matrix = batch_valid_query.reshape(kv_nh, -1, head_dim)
                U, S, Vh = torch.linalg.svd(valid_query_states_matrix.float(), full_matrices=False)
                S_subspace = torch.diag_embed(S[:, :subspace_dim]).to(valid_query_states_matrix.dtype)
                Vh_subspace = Vh[:, :subspace_dim, :].to(valid_query_states_matrix.dtype)
                batch_query_subspace = torch.matmul(S_subspace, Vh_subspace)

                query_subspace.append(batch_query_subspace)
                # With shared SVD, the first batch element's subspace is reused
                # for the whole batch.
                if self.squat_shared_svd:
                    break

            query_subspace = torch.stack(query_subspace)
        else:
            # No mask: use all query positions.
            query_states_matrix = query_states.reshape(bsz, kv_nh, -1, head_dim)
            U, S, Vh = torch.linalg.svd(query_states_matrix.float(), full_matrices=False)
            S_subspace = torch.diag_embed(S[:, :, :subspace_dim]).to(query_states_matrix.dtype)
            Vh_subspace = Vh[:, :, :subspace_dim, :].to(query_states_matrix.dtype)

            query_subspace = torch.matmul(S_subspace, Vh_subspace)

            if self.squat_shared_svd:
                query_subspace = query_subspace[0:1, ...]

        Ainv_t = self._generate_At_inv(self.squat_q_group_size, query_subspace.float(), lamb=self.squat_lambda)
        P_inv = torch.inverse(Ainv_t[-1])

        return Ainv_t, P_inv

    def _generate_At_inv(self, quant_group_size, my_Qhat, lamb=1, tol=1e-7):
        """
        Generate the list of T error-feedback matrices via backward Schur
        complements of `I + lamb * Qhat^T Qhat`.

        Parameters:
            quant_group_size (int): Group size along the head dimension.
            my_Qhat (torch.Tensor): Query subspace, shape (bs, kv_nh, subspace_dim, head_dim).
            lamb (float): Weight of the subspace regularization term.
            tol (float): Diagonal jitter added before each inversion.

        Returns:
            List[torch.Tensor]: T = ceil(head_dim / quant_group_size) matrices;
            entry t-1 holds the last `quant_group_size` columns of the t-th
            Schur complement (the final entry is the full matrix).
        """
        bs, kv_nh, subspace_dim, head_dim = my_Qhat.shape
        T = (head_dim+quant_group_size-1)//quant_group_size
        matrices = [None] * T
        device = my_Qhat.device
        I = torch.eye(head_dim, device=device)

        A_T = I.expand(bs, kv_nh, head_dim, head_dim) + lamb * torch.matmul(
            my_Qhat.transpose(-1, -2), my_Qhat
        )
        matrices[T - 1] = A_T

        # Work backwards, eliminating one group per step via a Schur complement.
        for t in range(T - 1, 0, -1):
            current_dim = t * quant_group_size

            M_t1 = A_T[:, :, :current_dim, :current_dim]
            N_t1 = A_T[:, :, current_dim : current_dim + quant_group_size, :current_dim]
            O_t1 = A_T[:, :, current_dim : current_dim + quant_group_size, current_dim : current_dim + quant_group_size]

            # `tol` jitter keeps the inversion well-conditioned.
            I_mat = torch.eye(quant_group_size, device=device)
            O_t1_inv = torch.inverse(O_t1 + tol * I_mat.expand(bs, kv_nh, quant_group_size, quant_group_size))
            A_t = M_t1 - torch.matmul(N_t1.transpose(-1, -2), torch.matmul(O_t1_inv, N_t1))
            matrices[t - 1] = A_t[:, :, :, -quant_group_size:]

            A_T = A_t
        return matrices

    def squat_quantize_key(self, key_states, quant_group_size, Ainv_t, P_inv):
        """
        Quantize `key_states` group-by-group along the head dimension with
        error feedback: the dequantization error of each group is propagated
        into all later (not yet quantized) groups through `Ainv_t` / `P_inv`.

        Returns the backend-quantized tensor of the error-compensated keys.
        """
        hidden_dim = key_states.shape[-1]
        T = (hidden_dim+quant_group_size-1)//quant_group_size
        key_states_dequant = []
        # Fix: work on a copy. The error feedback below writes into `group` in
        # place, and the original aliased `key_states` directly, silently
        # mutating the caller's tensor (e.g. the keys returned for attention
        # when the prefill length is an exact multiple of `residual_length`).
        group = key_states.clone()
        for i in range(T):
            key_states_quant_this_quant_group = self._quantize(
                group[:, :, :, i * quant_group_size : (i + 1) * quant_group_size].contiguous(),
                axis=self.axis_key
            )
            dequantized = self._dequantize(key_states_quant_this_quant_group)

            if i < T - 1:
                # Quantization error of the current group.
                d_vec = (
                    dequantized
                    - group[:, :, :, i * quant_group_size : (i + 1) * quant_group_size]
                ).float()
                H_t = Ainv_t[i]
                B_t = P_inv[
                    :, :, (i + 1) * quant_group_size :, : (i + 1) * quant_group_size
                ]
                # Spread the error into the remaining groups.
                update = torch.matmul(
                    torch.matmul(d_vec, H_t.transpose(-2, -1)), B_t.transpose(-2, -1)
                )
                group[:, :, :, (i + 1) * quant_group_size :] = (
                    group[:, :, :, (i + 1) * quant_group_size :] + update
                )

            key_states_dequant.append(dequantized)

        key_states_dequant = torch.cat(key_states_dequant, dim=3)
        key_states_quant = self._quantize(key_states_dequant, axis=self.axis_key)
        return key_states_quant
|
|
|
|
class QuantoSQuatCache(SQuatCache):
    """
    SQuat cache that uses `optimum-quanto` for the low-level quantize /
    dequantize primitives.

    Parameters:
        cache_config (`SQuatCacheConfig`):
            Configuration shared with `SQuatCache` (nbits, axes, group size, ...).
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        # `SQuatCache.__init__` already validates the optimum-quanto version,
        # the nbits/axis settings, and initializes `self.qtype` and
        # `self.optimizer` — the original duplicated all of those checks here.
        super().__init__(cache_config)

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis` with quanto; returns the quantized tensor."""
        if is_optimum_quanto_available():
            from optimum.quanto import quantize_weight

            scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
            qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
            return qtensor

    def _dequantize(self, qtensor):
        """Restore the full-precision tensor from a quanto-quantized one."""
        return qtensor.dequantize()
|
|
|
|
class HQQSQuatCache(SQuatCache):
    """
    SQuat cache that uses the `HQQ` quantizer for the low-level quantize /
    dequantize primitives.
    """

    def __init__(self, cache_config: CacheConfig) -> None:
        super().__init__(cache_config)

        # HQQ supports a wider bit range than quanto, and only axes 0/1.
        if self.nbits not in (1, 2, 3, 4, 8):
            raise ValueError(
                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
            )
        if self.axis_key not in (0, 1):
            raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
        if self.axis_value not in (0, 1):
            raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")

        self.quantizer = HQQQuantizer

    def _quantize(self, tensor, axis):
        """Quantize with HQQ; returns `(qtensor, meta)` where `meta` carries scale/zero."""
        quantized, meta = self.quantizer.quantize(
            tensor,
            axis=axis,
            device=self.device,
            compute_dtype=self.compute_dtype,
            nbits=self.nbits,
            group_size=self.q_group_size,
        )
        meta["compute_dtype"] = self.compute_dtype
        # Move the packed tensor onto the target device, then keep the
        # scale/zero metadata on the same device as the data.
        self.quantizer.cuda(quantized, meta=meta, device=self.device)
        target_device = quantized.device
        meta["scale"] = meta["scale"].to(target_device)
        meta["zero"] = meta["zero"].to(target_device)
        return quantized, meta

    def _dequantize(self, qtensor):
        """Inverse of `_quantize`: rebuild the full-precision tensor from `(qtensor, meta)`."""
        packed, meta = qtensor
        return self.quantizer.dequantize(packed, meta)
|
|
|
|
# Maps the `backend` field of the cache config to the concrete cache class.
SQUAT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoSQuatCache, "HQQ": HQQSQuatCache}
|
|
def generate(model, generation_config=None, backend="quanto", nbits=2, quant_group_size=64, residual_length=32, squat_lambda=0.001, subspace_dim=20, shared_svd=True, **kwargs):
    """Custom generate function for SQuat quantized KV caches.

    Builds a `SQuatCacheConfig`, instantiates the backend-specific cache class,
    and forwards everything to `model.generate`.

    Args:
        model (`PreTrainedModel`):
            The model to generate from. Must be decoder-only.
        generation_config: Optional generation config forwarded to `model.generate`.
        backend (`str`): Quantization backend, `"quanto"` or `"HQQ"`.
        nbits (`int`): Number of bits for quantization.
        quant_group_size (`int`): Group size along the head dimension for key quantization.
        residual_length (`int`): Number of most recent tokens kept in full precision.
        squat_lambda (`float`): Weight of the query-subspace regularization term.
        subspace_dim (`int`): Rank of the query subspace.
        shared_svd (`bool`): Share one SVD across the batch.
        **kwargs: Forwarded to `model.generate` (may include `past_key_values`).

    Returns:
        The output of `model.generate`.
    """
    cache_config = SQuatCacheConfig(
        backend=backend,
        nbits=nbits,
        quant_group_size=quant_group_size,
        residual_length=residual_length,
        squat_lambda=squat_lambda,
        subspace_dim=subspace_dim,
        shared_svd=shared_svd,
    )
    cache_class = SQUAT_BACKEND_CLASSES_MAPPING[cache_config.backend]

    # Fail early with an actionable message when the backend package is missing.
    if cache_config.backend == "quanto" and not is_optimum_quanto_available():
        raise ImportError(
            "You need to install optimum-quanto in order to use KV cache quantization with optimum-quanto backend. "
            "Please install it via with `pip install optimum-quanto`"
        )
    elif cache_config.backend == "HQQ" and not is_hqq_available():
        raise ImportError(
            "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
            "Please install it via with `pip install hqq`"
        )

    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # Drop the dispatch marker so it is not forwarded to `model.generate`.
    kwargs.pop("custom_generate", None)

    # Reuse a caller-supplied cache if present, otherwise build one.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = cache_class(cache_config=cache_config)

    # Fix: forward `generation_config` — the original accepted the parameter
    # but silently dropped it.
    generation_outputs = model.generate(
        **kwargs, generation_config=generation_config, past_key_values=past_key_values, use_cache=True
    )
    return generation_outputs
|
|