| import torch |
|
|
| import types |
| from typing import Any, Dict, List, Optional, Tuple, Union |
| import transformers |
| from transformers import Cache, GenerationConfig |
| import torch.nn as nn |
| from transformers.modeling_utils import PreTrainedModel |
|
|
| from .functions_2_patch import _validate_model_kwargs, llama_atten_forward |
| from .monkey_patching_utils import monkey_patching |
| from .sep_cache_utils import SepCache |
|
|
|
|
# Names of `generate` arguments that the SepCache custom `generate` refuses.
# Each one is rejected if it appears as a non-None keyword argument, or if a
# passed `generation_config` sets it to a non-default value.
# Kept as a list (not a tuple) because its repr is embedded in the error message.
UNSUPPORTED_GENERATION_ARGS = [
    # Cache-selection / cache-format options: the cache is always a SepCache here.
    "cache_implementation",
    "cache_config",
    "return_legacy_cache",
    # Decoding strategies presumably incompatible with SepCache -- TODO confirm.
    "num_beams",
    "compile_config",
    "assistant_model",
]
|
|
|
|
def generate(
    model,
    init_cache_size: Union[int, List] = 4,
    sep_cache_size: Union[int, List] = 128,
    local_size: Union[int, List] = 256,
    cache_size: Union[int, List] = 512,
    SEP_ACCUMULATION: bool = True,
    USE_MAX_SEP_CACHE: bool = False,
    SEP_PADDING_IN_BATCH: bool = False,
    separator_token_ids: Optional[List[int]] = None,
    PADDING_ID: Optional[int] = None,
    past_tok_ids: Optional[List[torch.Tensor]] = None,
    key_cache: Optional[List[torch.Tensor]] = None,
    value_cache: Optional[List[torch.Tensor]] = None,
    PRINT_KV_RATIO_INSIDE: bool = False,
    print_KV_inside_per_steps: int = 1000,
    _seen_tokens: int = 0,
    _kept_kv_ratio: Optional[List[Tuple[int]]] = None,
    APPLY_PE_SHIFT: bool = False,
    APPLY_PES_INSIDE: bool = False,
    _shifted_position_ids: Optional[List[torch.Tensor]] = None,
    _rope_unsqueeze_dim: int = 1,
    _rope_seq_dim: int = 1,
    pe_scaling_factor: float = 1.0,
    pe_dim: int = 128,
    max_position_embeddings: int = 8192,
    base: int = 10000,
    k_seq_dim: int = 2,
    v_seq_dim: int = 2,
    layer_num: Optional[int] = None,
    model_type: str = "llama",
    device=None,
    monkey_patch_verbose: bool = False,
    **kwargs,
):
    """Run ``model.generate`` with a SepCache as the KV cache.

    SepCache is the cache described in the SepLLM paper (ICML 2025,
    https://arxiv.org/abs/2412.12094). During inference it only keeps the KVs of
    the initial tokens, the separator tokens and the most recent tokens.

    What this function does, in order:

    1. Patches the model's attention layers with ``llama_atten_forward`` via
       ``monkey_patching``. The patch is applied in place and persists after
       the call.
    2. Rejects arguments listed in ``UNSUPPORTED_GENERATION_ARGS`` and
       encoder-decoder models.
    3. Builds a ``SepCache`` from the parameters below, unless a ``SepCache``
       is passed as ``past_key_values``.
    4. Calls ``model.generate(**kwargs, past_key_values=cache)`` with
       ``use_cache=True``.

    Frequently used parameters (forwarded to ``SepCache``):

    Args:
        model: A decoder-only Hugging Face model.
        init_cache_size: Max number of KVs kept for initial tokens
            (``a`` in the paper).
        sep_cache_size: Max number of KVs kept for separator tokens
            (``s`` in the paper).
        local_size: Max number of KVs kept for local tokens, i.e. the sliding
            window (``w`` in the paper).
        cache_size: Max number of KVs kept in total (``c`` in the paper).
            For these four sizes, a list (of length equal to the number of
            layers) sets a value per layer; an int applies to all layers.
        USE_MAX_SEP_CACHE: If True, keep at most ``sep_cache_size`` separator
            KVs. Older separators' KVs are discarded so that only the most
            recent ``sep_cache_size`` remain.
        separator_token_ids: Token ids of the separator tokens for the model's
            tokenizer. With ``model_type='llama'`` (e.g. Llama-3 models) this
            and ``PADDING_ID`` may be left unset and are auto-filled by
            SepCache.
        PADDING_ID: Token id used for padding, e.g. the id of the tokenizer's
            ``"<|endoftext|>"`` token.
        layer_num: Ignored. The number of layers is always taken from the
            patched model (``len(model_layers)``). Kept for backward
            compatibility.
        monkey_patch_verbose: Passed as ``verbose`` to ``monkey_patching``.
        **kwargs: Forwarded to ``model.generate``. ``past_key_values`` and
            ``custom_generate`` are removed. ``use_cache`` is forced to True.
            ``sepllm_kwargs`` (holding the prompt ``input_ids``) is added.

    All other parameters (``SEP_ACCUMULATION``, ``SEP_PADDING_IN_BATCH``,
    ``past_tok_ids``, ``key_cache``, ``value_cache``, ``PRINT_KV_RATIO_INSIDE``,
    ``print_KV_inside_per_steps``, ``_seen_tokens``, ``_kept_kv_ratio``,
    ``APPLY_PE_SHIFT``, ``APPLY_PES_INSIDE``, ``_shifted_position_ids``,
    ``_rope_unsqueeze_dim``, ``_rope_seq_dim``, ``pe_scaling_factor``,
    ``pe_dim``, ``max_position_embeddings``, ``base``, ``k_seq_dim``,
    ``v_seq_dim``, ``model_type``, ``device``) are forwarded unchanged to
    ``SepCache`` when a new cache is built. They are ignored when a
    ``SepCache`` is supplied via ``past_key_values``.

    Important:
        With very large ``cache_size``/``local_size`` and
        ``USE_MAX_SEP_CACHE=False``, SepCache behaves like a regular cache.
        You must always ensure that
        ``init_cache_size + sep_cache_size + local_size + left_padding_offset < cache_size``.
        Here ``left_padding_offset`` is the largest number of left-padding
        tokens of any record in a runtime batch, and is only known at runtime.
        To be safe, leave a margin in ``a + s + w < c``.

    Returns:
        Whatever ``model.generate`` returns.

    Raises:
        ValueError: if an argument from ``UNSUPPORTED_GENERATION_ARGS`` is set,
            if the model is encoder-decoder, if no prompt ids are given as
            ``input_ids``/``inputs``, or if ``past_key_values`` is not a
            ``SepCache``.

    Example::

        >>> model = AutoModelForCausalLM.from_pretrained(
        ...     "meta-llama/Meta-Llama-3-8B-Instruct", device_map="cuda:0")
        >>> tokenizer = AutoTokenizer.from_pretrained(
        ...     "meta-llama/Meta-Llama-3-8B-Instruct")
        >>> inputs = tokenizer("My name is Llama 3", return_tensors="pt").to("cuda:0")
        >>> out = generate(model, **inputs, max_new_tokens=32,
        ...                init_cache_size=4, sep_cache_size=128,
        ...                local_size=256, cache_size=512)

    For detailed usage instructions, see https://github.com/HKUDS/SepLLM
    """
    # 0. Route attention through SepCache-aware forward; returns the patched layers.
    model_layers = monkey_patching(
        model, model_atten_forward=llama_atten_forward, verbose=monkey_patch_verbose
    )

    # 1.a Reject arguments this custom generate function cannot honour, whether
    #     passed directly or through a non-default `generation_config`.
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = False
        if generation_config is not None:
            # getattr default: some transformers versions lack some of these fields.
            custom_value = getattr(generation_config, arg, None)
            has_custom_gen_config_arg = not (
                getattr(default_model_generation_config, arg, None) == custom_value
                or getattr(default_global_generation_config, arg, None) == custom_value
            )
        kwargs_has_arg = kwargs.get(arg) is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    # 1.b Only decoder-only models are supported.
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # 1.c Drop `custom_generate` so the inner `model.generate` call does not
    #     dispatch back into this function.
    kwargs.pop("custom_generate", None)

    # 2. SepCache.update needs the prompt token ids. They may arrive as
    #    `input_ids`, or as `inputs` (e.g. a positional tensor forwarded by the
    #    custom_generate dispatch).
    input_ids = kwargs.get("input_ids")
    if input_ids is None:
        input_ids = kwargs.get("inputs")
    if input_ids is None:
        raise ValueError(
            "SepCache requires the prompt token ids: pass them to `generate` as `input_ids` (or `inputs`)."
        )
    kwargs["sepllm_kwargs"] = {"input_ids": input_ids}

    # 3. Build a SepCache unless one was supplied.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = SepCache(
            init_cache_size=init_cache_size,
            sep_cache_size=sep_cache_size,
            local_size=local_size,
            cache_size=cache_size,
            SEP_ACCUMULATION=SEP_ACCUMULATION,
            USE_MAX_SEP_CACHE=USE_MAX_SEP_CACHE,
            SEP_PADDING_IN_BATCH=SEP_PADDING_IN_BATCH,
            separator_token_ids=separator_token_ids,
            PADDING_ID=PADDING_ID,
            past_tok_ids=past_tok_ids,
            key_cache=key_cache,
            value_cache=value_cache,
            PRINT_KV_RATIO_INSIDE=PRINT_KV_RATIO_INSIDE,
            print_KV_inside_per_steps=print_KV_inside_per_steps,
            _seen_tokens=_seen_tokens,
            _kept_kv_ratio=_kept_kv_ratio,
            APPLY_PE_SHIFT=APPLY_PE_SHIFT,
            APPLY_PES_INSIDE=APPLY_PES_INSIDE,
            _shifted_position_ids=_shifted_position_ids,
            _rope_unsqueeze_dim=_rope_unsqueeze_dim,
            _rope_seq_dim=_rope_seq_dim,
            pe_scaling_factor=pe_scaling_factor,
            pe_dim=pe_dim,
            max_position_embeddings=max_position_embeddings,
            base=base,
            k_seq_dim=k_seq_dim,
            v_seq_dim=v_seq_dim,
            # `layer_num` argument is intentionally not used: the patched model is
            # the source of truth for the layer count.
            layer_num=len(model_layers),
            model_type=model_type,
            device=device,
        )
    elif not isinstance(past_key_values, SepCache):
        raise ValueError(f"`past_key_values` must be a `SepCache` instance, got a {type(past_key_values)} instance")

    # 4. Generate with the cache.
    kwargs["use_cache"] = True
    # `sepllm_kwargs` is not a parameter of the model's forward. Install the
    # project's `_validate_model_kwargs` (imported for this purpose) so the
    # extra kwarg is presumably accepted instead of rejected -- confirm against
    # functions_2_patch.
    model._validate_model_kwargs = types.MethodType(_validate_model_kwargs, model)
    return model.generate(**kwargs, past_key_values=past_key_values)
|
|