| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from itertools import chain |
| from typing import Dict, List, Tuple |
| import einops |
| import torch |
|
|
|
|
def rearrange(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
    pattern: str,
    **kwargs: Dict[str, int],
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Apply one einops ``rearrange`` pattern to every sample of a flattened batch.

    The flattened pair ``(hid, hid_shape)`` is split into per-sample tensors
    with ``unflatten``, each sample is rearranged independently using the
    same ``pattern`` and keyword arguments, and the results are packed back
    together with ``flatten``.

    Args:
        hid: Flattened samples, as produced by ``flatten``.
        hid_shape: Per-sample shape rows, as produced by ``flatten``.
        pattern: einops pattern forwarded to ``einops.rearrange``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``einops.rearrange`` for every sample.

    Returns:
        The ``(hid, hid_shape)`` pair returned by ``flatten`` for the
        rearranged samples.
    """
    rearranged = []
    for sample in unflatten(hid, hid_shape):
        rearranged.append(einops.rearrange(sample, pattern, **kwargs))
    return flatten(rearranged)
|
|
|
|
def repeat(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
    pattern: str,
    **kwargs: Dict[str, torch.LongTensor],
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Apply an einops ``repeat`` pattern to every sample with per-sample arguments.

    Unlike a plain ``einops.repeat`` call, each keyword argument is indexed
    by sample position: sample ``i`` receives ``value[i].item()`` for every
    keyword, so different samples can use different sizes.

    Args:
        hid: Flattened samples, as produced by ``flatten``.
        hid_shape: Per-sample shape rows, as produced by ``flatten``.
        pattern: einops pattern forwarded to ``einops.repeat``.
        **kwargs: For each keyword, an indexable value holding one entry per
            sample; entry ``i`` is converted with ``.item()`` before being
            passed to ``einops.repeat``.

    Returns:
        The ``(hid, hid_shape)`` pair returned by ``flatten`` for the
        repeated samples.
    """
    samples = unflatten(hid, hid_shape)

    # Resolve every sample's keyword arguments up front, before any repeat
    # is executed, so indexing problems surface first.
    per_sample_axes = []
    for position in range(len(samples)):
        per_sample_axes.append(
            {name: values[position].item() for name, values in kwargs.items()}
        )

    repeated = [
        einops.repeat(sample, pattern, **axes)
        for sample, axes in zip(samples, per_sample_axes)
    ]
    return flatten(repeated)
|
|
|
|
def pack(
    samples: List[torch.Tensor],
) -> Tuple[
    List[torch.Tensor],
    List[List[int]],
]:
    """Group same-shaped samples into stacked batches.

    Samples are bucketed by their exact ``.shape``; every bucket is stacked
    into a single tensor with a new leading dimension. Buckets appear in the
    order in which their shape is first encountered.

    Args:
        samples: Tensors to group.

    Returns:
        A pair ``(batches, indices)`` where ``batches[k]`` is the stacked
        tensor for the k-th distinct shape and ``indices[k]`` lists the
        positions in ``samples`` of the tensors stacked into it.
    """
    grouped = {}
    positions = {}
    for position, sample in enumerate(samples):
        key = sample.shape
        grouped.setdefault(key, []).append(sample)
        positions.setdefault(key, []).append(position)

    stacked = [torch.stack(group) for group in grouped.values()]
    return stacked, list(positions.values())
|
|
|
|
def unpack(
    batches: List[torch.Tensor],
    indices: List[List[int]],
) -> List[torch.Tensor]:
    """Scatter stacked batches back into a flat list of samples.

    This reverses the grouping done by ``pack``: each batch is unbound along
    its leading dimension and every resulting tensor is written to the
    output position given by the matching entry of ``indices``.

    Args:
        batches: Stacked tensors, one per group.
        indices: For each batch, the output positions of its samples.

    Returns:
        A list of length ``max(index) + 1`` holding the samples at their
        recorded positions. Positions that no index refers to stay ``None``.
        Empty input yields an empty list.
    """
    # `default=-1` makes empty input produce an empty list instead of
    # raising ValueError from max() on an empty sequence.
    size = max(chain.from_iterable(indices), default=-1) + 1
    samples = [None] * size
    for batch, index in zip(batches, indices):
        for sample, i in zip(batch.unbind(), index):
            samples[i] = sample
    return samples
|
|
|
|
| |
def flatten(
    hid: List[torch.FloatTensor],
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Concatenate a list of samples into one tensor plus a shape table.

    For every sample, all dimensions except the last are collapsed into a
    single leading dimension; the collapsed samples are then concatenated
    along that dimension. The collapsed (leading) dimensions of each sample
    are recorded so the operation can be undone by ``unflatten``.

    Args:
        hid: Non-empty list of tensors.

    Returns:
        A pair ``(hid, shape)``: the concatenated tensor and a tensor whose
        i-th row is ``hid[i].shape[:-1]``, created on the device of the
        first sample.
    """
    assert len(hid) > 0
    device = hid[0].device

    # Build the shape table first, then the concatenated data.
    leading_dims = [torch.tensor(sample.shape[:-1], device=device) for sample in hid]
    shape = torch.stack(leading_dims)

    rows = [sample.flatten(0, -2) for sample in hid]
    return torch.cat(rows), shape
|
|
|
|
def unflatten(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
) -> List[torch.Tensor]:
    """Split a concatenated tensor back into per-sample tensors.

    The product of each row of ``hid_shape`` gives the number of leading
    entries of ``hid`` that belong to that sample. ``hid`` is split into
    chunks of those lengths, and the leading dimension of each chunk is
    expanded back to the dimensions stored in the corresponding row.

    Args:
        hid: Concatenated samples.
        hid_shape: One row of leading dimensions per sample.

    Returns:
        A list with one tensor per row of ``hid_shape``.
    """
    lengths = hid_shape.prod(-1).tolist()
    chunks = torch.split(hid, lengths)
    return [chunk.unflatten(0, dims) for chunk, dims in zip(chunks, hid_shape.tolist())]
|
|