| from collections import namedtuple |
| from dataclasses import dataclass |
| import torch |
| from typing import Tuple, Optional |
|
|
|
|
@dataclass
class LongLlamaMemConfig:
    """
    Settings controlling how LongLlama stores and queries its memory caches.

    Attributes:
        positionals (`bool`):
            If True, positional embeddings are applied in the memory layer.
        cache_dtype (`torch.dtype`):
            Storage dtype used for cached keys and values.
        attention_grouping (`Tuple[int, int]`, *optional*):
            Trades speed for memory by running memory-layer attention
            sequentially. With `(4, 128)` the memory layers process at most
            4 heads and 128 queries per head at a time — i.e. at most
            512 queries at once.
    """

    positionals: bool = True
    cache_dtype: torch.dtype = torch.bfloat16
    attention_grouping: Optional[Tuple[int, int]] = None
|
|
|
|
@dataclass
class LongLlamaMemCache:
    """
    Container for LongLlama's memory cache tensors.

    Attributes:
        keys (`torch.FloatTensor` of shape `(batch_size, num_heads, mem_length, embed_size_per_head)`):
            Cached attention keys.
        values (`torch.FloatTensor` of shape `(batch_size, num_heads, mem_length, embed_size_per_head)`):
            Cached attention values.
        masks (`torch.FloatTensor` of shape `(batch_size, 1, mem_length, 1)`):
            Mask tensor used to hide parts of the memory from attention.
    """

    keys: torch.FloatTensor
    values: torch.FloatTensor
    masks: torch.FloatTensor
|
|
|
|
def mem_apply_update(
    prev_mem_cache: LongLlamaMemCache, new_mem_content: LongLlamaMemCache, mem_config: LongLlamaMemConfig
) -> LongLlamaMemCache:
    """
    Append new memory content to an existing memory cache.

    Args:
        prev_mem_cache (`LongLlamaMemCache`):
            The memory cache accumulated so far.
        new_mem_content (`LongLlamaMemCache`):
            Freshly computed keys/values/masks to append along the mem_length
            dimension (dim -2).
        mem_config (`LongLlamaMemConfig`):
            Memory configuration. Not used here; kept for interface
            compatibility with callers.

    Returns:
        A new `LongLlamaMemCache` whose keys, values and masks are the
        concatenation of the previous and new content along dim -2.

    Raises:
        ValueError: If any tensor is not 4-dimensional, or if the new
            content's keys, values and masks disagree on mem_length.
    """

    def update_one(prev: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
        # Cache tensors are 4-D: (batch, heads, mem_length, head_dim) for
        # keys/values and (batch, 1, mem_length, 1) for masks.
        if prev.dim() != 4 or new.dim() != 4:
            raise ValueError(f"Memory cache content should be consistent in shape got {prev.shape} {new.shape}")
        # Grow the cache along the mem_length dimension.
        return torch.cat([prev, new], dim=-2)

    insert_size = new_mem_content.keys.shape[-2]

    # keys, values and masks must all describe the same number of new memory slots.
    if new_mem_content.values.shape[-2] != insert_size or new_mem_content.masks.shape[-2] != insert_size:
        raise ValueError("Inconsistent mem_length in new_mem_content")

    return LongLlamaMemCache(
        keys=update_one(prev_mem_cache.keys, new_mem_content.keys),
        values=update_one(prev_mem_cache.values, new_mem_content.values),
        masks=update_one(prev_mem_cache.masks, new_mem_content.masks),
    )
|
|