from typing import Tuple

import numpy as np
import torch


class Cache:
    """
    Pre-allocated buffer of shape (num_samples, num_heads, max_tokens, head_dim).
    _size tracks how many token slots are currently filled.
    """

    def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        assert embed_dim % num_heads == 0
        self._n, self._cache, self._size = num_samples, None, None
        self._reset = lambda n: torch.empty(n, num_heads, max_tokens, embed_dim // num_heads, device=device)
        self.reset()

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        n, num_heads, _, head_dim = self._cache.shape
        return n, num_heads, self._size, head_dim

    def reset(self) -> None:
        self._cache = self._reset(self._n)
        self._size = 0

    def prune(self, mask: np.ndarray) -> None:
        # Keep only the samples selected by the boolean mask.
        assert mask.ndim == 1 and mask.shape[0] == self.shape[0]
        self._cache = self._cache[mask]
        self._n = self._cache.shape[0]

    def get(self) -> torch.Tensor:
        return self._cache[:, :, :self._size, :]

    def update(self, x: torch.Tensor) -> None:
        # Append x along the token dimension, bypassing autograd's in-place
        # version check (see AssignWithoutInplaceCheck below).
        assert (x.ndim == self._cache.ndim) and all([x.size(i) == self._cache.size(i) for i in (0, 1, 3)])
        assert self._size + x.size(2) <= self._cache.shape[2]
        self._cache = AssignWithoutInplaceCheck.apply(self._cache, x, 2, self._size, self._size + x.size(2))
        self._size += x.size(2)
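

# A minimal usage sketch of Cache (the shapes below are illustrative
# assumptions, not taken from the rest of the codebase):
#
#   cache = Cache(num_samples=2, num_heads=4, max_tokens=16, embed_dim=32, device=torch.device("cpu"))
#   cache.update(torch.rand(2, 4, 3, 8))  # append 3 tokens; head_dim = 32 // 4 = 8
#   assert cache.shape == (2, 4, 3, 8) and cache.get().shape == (2, 4, 3, 8)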


class KVCache:
    """Key and value caches for a single attention layer."""

    def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
        self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
        self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device)

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        return self._k_cache.shape

    def reset(self) -> None:
        self._k_cache.reset()
        self._v_cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        self._k_cache.prune(mask)
        self._v_cache.prune(mask)

    def get(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self._k_cache.get(), self._v_cache.get()

    def update(self, k: torch.Tensor, v: torch.Tensor) -> None:
        self._k_cache.update(k)
        self._v_cache.update(v)


class KeysValues:
    """Tuple of KVCache objects, one per transformer layer."""

    def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None:
        self._keys_values = tuple([KVCache(n, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers)])

    def __getitem__(self, key: int) -> KVCache:
        return self._keys_values[key]

    def __len__(self) -> int:
        return len(self._keys_values)

    @property
    def size(self) -> int:
        # Number of tokens currently stored (identical across layers).
        return self._keys_values[0].shape[2]

    def reset(self) -> None:
        for kv_cache in self._keys_values:
            kv_cache.reset()

    def prune(self, mask: np.ndarray) -> None:
        for kv_cache in self._keys_values:
            kv_cache.prune(mask)
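

# Sketch of incremental decoding with KeysValues (the layer count and head
# sizes below are illustrative assumptions):
#
#   kvs = KeysValues(n=2, num_heads=4, max_tokens=16, embed_dim=32, num_layers=2, device=torch.device("cpu"))
#   for layer in range(len(kvs)):
#       kvs[layer].update(torch.rand(2, 4, 1, 8), torch.rand(2, 4, 1, 8))  # one new token
#       k, v = kvs[layer].get()            # full history so far: (2, 4, kvs.size, 8)
#   kvs.prune(np.array([True, False]))     # keep only the first sample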


class AssignWithoutInplaceCheck(torch.autograd.Function):
    """
    Inspired by https://discuss.pytorch.org/t/disable-in-place-correctness-version-check-any-other-workaround/90738/4
    Warning: do not use it to overwrite a slice twice.
    """

    @staticmethod
    def get_slice(dim: int, start: int, stop: int) -> Tuple[slice, ...]:
        return tuple([slice(None)] * dim + [slice(start, stop)])

    @staticmethod
    def forward(ctx, input: torch.Tensor, value: torch.Tensor, dim: int, start: int, stop: int) -> torch.Tensor:
        ctx.dim = dim
        ctx.start = start
        ctx.stop = stop
        # Write through .data to bypass autograd's in-place version counter.
        input.data[AssignWithoutInplaceCheck.get_slice(dim, start, stop)] = value
        return input

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]:
        # Gradient w.r.t. input passes through unchanged; gradient w.r.t. value
        # is the assigned slice of grad_out; dim, start, stop get no gradient.
        return grad_out, grad_out[AssignWithoutInplaceCheck.get_slice(ctx.dim, ctx.start, ctx.stop)], None, None, None
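

# A minimal check of AssignWithoutInplaceCheck's autograd behaviour (the
# tensors below are illustrative, not from this codebase):
#
#   base = torch.zeros(1, 4, requires_grad=True)
#   buf = base.clone()                 # non-leaf buffer we can assign into
#   val = torch.ones(1, 2, requires_grad=True)
#   out = AssignWithoutInplaceCheck.apply(buf, val, 1, 1, 3)  # buf[:, 1:3] = val
#   out.sum().backward()
#   # val.grad is ones(1, 2): gradient reaches value only through the assigned
#   # slice, while grad_out passes through to the buffer unchanged.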