# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# coding: utf-8

from itertools import chain
from typing import Dict, List, Tuple

import einops
import torch


def rearrange(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: int,
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Apply ``einops.rearrange`` to every sample of a flattened batch.

    ``hid`` is split into per-sample tensors with ``unflatten``, each one is
    rearranged with ``pattern``, and the results are packed again with
    ``flatten``.

    Args:
        hid: Flattened samples, concatenated along the first dimension.
        hid_shape: One row per sample holding that sample's leading shape.
        pattern: einops pattern applied to each unflattened sample.
        **kwargs: Axis sizes forwarded unchanged to ``einops.rearrange``.

    Returns:
        The ``(hid, hid_shape)`` pair produced by ``flatten`` for the
        rearranged samples.
    """
    samples = unflatten(hid, hid_shape)
    return flatten([einops.rearrange(h, pattern, **kwargs) for h in samples])


def repeat(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: torch.LongTensor,  # (b)
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Apply ``einops.repeat`` to every sample of a flattened batch.

    Unlike ``rearrange``, each keyword argument is indexed by sample: sample
    ``i`` receives ``value[i].item()`` for every keyword.

    Args:
        hid: Flattened samples, concatenated along the first dimension.
        hid_shape: One row per sample holding that sample's leading shape.
        pattern: einops pattern applied to each unflattened sample.
        **kwargs: Per-sample axis sizes; each value is indexed by the sample
            position and converted to a Python scalar with ``.item()``.

    Returns:
        The ``(hid, hid_shape)`` pair produced by ``flatten`` for the
        repeated samples.
    """
    samples = unflatten(hid, hid_shape)
    repeated = [
        # Build this sample's scalar keyword arguments on the fly.
        einops.repeat(h, pattern, **{k: v[i].item() for k, v in kwargs.items()})
        for i, h in enumerate(samples)
    ]
    return flatten(repeated)


def pack(
    samples: List[torch.Tensor],  # List of (h w c).
) -> Tuple[
    List[torch.Tensor],  # groups [(b1 h1 w1 c1), (b2 h2 w2 c2)]
    List[List[int]],  # reversal indices.
]:
    """Group same-shaped samples and stack each group into one tensor.

    Args:
        samples: Tensors to group; tensors with equal ``.shape`` share a group.

    Returns:
        A pair ``(batches, indices)``. ``batches[g]`` is the ``torch.stack``
        of all samples in group ``g`` and ``indices[g]`` lists the positions
        those samples had in ``samples``. Groups are ordered by the first
        occurrence of their shape. Both lists are empty for empty input.
    """
    batches = {}
    indices = {}
    for i, sample in enumerate(samples):
        shape = sample.shape
        batches.setdefault(shape, []).append(sample)
        indices.setdefault(shape, []).append(i)

    batches = [torch.stack(group) for group in batches.values()]
    indices = list(indices.values())
    return batches, indices


def unpack(
    batches: List[torch.Tensor],
    indices: List[List[int]],
) -> List[torch.Tensor]:
    """Undo ``pack``: scatter stacked groups back into a flat sample list.

    Args:
        batches: Stacked groups; each is unbound along its first dimension.
        indices: For every group, the target position of each of its samples.

    Returns:
        A list whose entry ``i`` is the sample assigned to position ``i``.
        Positions that no index refers to are left as ``None``. An empty
        list is returned when ``indices`` contains no positions.
    """
    positions = list(chain.from_iterable(indices))
    if not positions:
        # max() of an empty sequence raises ValueError; packing an empty list
        # yields empty groups, so the round trip must give an empty list back.
        return []

    samples = [None] * (max(positions) + 1)
    for batch, index in zip(batches, indices):
        for sample, i in zip(batch.unbind(), index):
            samples[i] = sample
    return samples


# Helper functions that must be kept because rearrange and repeat depend on them.
def flatten(
    hid: List[torch.FloatTensor],  # List of (*** c)
) -> Tuple[
    torch.FloatTensor,  # (L c)
    torch.LongTensor,  # (b n)
]:
    """Concatenate samples into one tensor and record their leading shapes.

    Every sample has all dimensions except the last merged into one, and the
    results are concatenated along the first dimension.

    Args:
        hid: Non-empty list of tensors. Their leading shapes are stacked into
            one tensor, so all samples must have the same number of dimensions.

    Returns:
        A pair ``(hid, shape)``: the concatenated tensor, and a tensor with one
        row per sample holding ``sample.shape[:-1]``, created on the device of
        the first sample.

    Raises:
        AssertionError: If ``hid`` is empty.
    """
    assert len(hid) > 0
    device = hid[0].device
    shape = torch.stack([torch.tensor(x.shape[:-1], device=device) for x in hid])
    hid = torch.cat([x.flatten(0, -2) for x in hid])
    return hid, shape


def unflatten(
    hid: torch.FloatTensor,  # (L c) or (L ... c)
    hid_shape: torch.LongTensor,  # (b n)
) -> List[torch.Tensor]:  # List of (*** c) or (*** ... c)
    """Split a flattened tensor back into per-sample tensors.

    Args:
        hid: Samples concatenated along the first dimension.
        hid_shape: One row per sample; the product of a row is the number of
            first-dimension entries that sample occupies in ``hid``.

    Returns:
        One tensor per row of ``hid_shape``, with its first dimension
        expanded to the shape stored in that row.
    """
    # Each sample's length along dim 0 is the product of its leading shape.
    hid_len = hid_shape.prod(-1)
    chunks = hid.split(hid_len.tolist())
    return [x.unflatten(0, s.tolist()) for x, s in zip(chunks, hid_shape)]