| """ |
| batching_utils.py |
| |
| Core definitions of (Distributed) Samplers for VLM finetuning; provides functionality for constructing and allocating |
| "split-modality" batches as described in the LLaVa paper; this makes sure that a given device/batch is either entirely |
| (vision, language) or (language-only) data, which leads to sizeable efficiency gains. |
| """ |
|
|
| import math |
| from typing import Iterator, List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| from torch.utils.data import Dataset, Sampler |
|
|
|
|
| |
| |
| |
| |
| |
| class SplitModalitySampler(Sampler): |
| def __init__( |
| self, |
| dataset: Dataset, |
| modality_lengths: List[Tuple[bool, int]], |
| global_batch_size: int, |
| num_replicas: Optional[int] = None, |
| rank: Optional[int] = None, |
| seed: int = 0, |
| drop_last: bool = False, |
| ) -> None: |
| super().__init__() |
| self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size() |
| self.rank = rank if rank is not None else dist.get_rank() |
| self.seed, self.epoch = seed, 0 |
|
|
| |
| self.dataset, self.modality_lengths, self.drop_last = dataset, modality_lengths, drop_last |
| self.global_batch_size = global_batch_size |
|
|
| |
| assert not self.drop_last, "SplitModalitySampler must set `drop_last = False`!" |
| self.total_size = math.ceil(len(self.dataset) / self.global_batch_size) * self.global_batch_size |
| self.num_samples = self.total_size // self.num_replicas |
|
|
| @staticmethod |
| def reindex_batch(batch_idxs: List[int], idx2lengths: List[int], n_buckets: int) -> List[List[int]]: |
| """Re-indexes a batch in a way that is conducive to DistributedSampler + grouping by seqlen per rank.""" |
| assert len(batch_idxs) % n_buckets == 0, "Batch length is not divisible by `num_replicas`!" |
|
|
| |
| n_examples_per_bucket = len(batch_idxs) // n_buckets |
| bucket_indices = [[] for _ in range(n_buckets)] |
| bucket_lengths = [0 for _ in range(n_buckets)] |
|
|
| |
| for idx in batch_idxs: |
| shortest_bucket_idx = bucket_lengths.index(min(bucket_lengths)) |
| bucket_indices[shortest_bucket_idx].append(idx) |
|
|
| |
| bucket_lengths[shortest_bucket_idx] += idx2lengths[idx] |
| if len(bucket_indices[shortest_bucket_idx]) == n_examples_per_bucket: |
| bucket_lengths[shortest_bucket_idx] = float("inf") |
|
|
| return bucket_indices |
|
|
| def get_modality_and_length_grouped_indices(self, generator: torch.Generator) -> List[int]: |
| """ |
| Returns a list of indices so that each slice of `global_batch_size` consecutive indices corresponds to elements |
| of the same modality with each sub-sequence of `per_replica_batch_size` (the batch size each unique device sees |
| during distributed training) is roughly grouped by sequence length (for training efficiency). |
| """ |
| multimodal_indices, multimodal_lengths = zip( |
| *[(idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if is_multimodal] |
| ) |
|
|
| |
| unimodal_split = [ |
| (idx, length) for idx, (is_multimodal, length) in enumerate(self.modality_lengths) if not is_multimodal |
| ] |
| if len(unimodal_split) == 0: |
| unimodal_indices, unimodal_lengths = [], [] |
| else: |
| unimodal_indices, unimodal_lengths = zip(*unimodal_split) |
|
|
| |
| mm_shuffled_idxs = torch.randperm(len(multimodal_indices), generator=generator) |
| uni_shuffled_idxs = torch.randperm(len(unimodal_indices), generator=generator) |
|
|
| |
| g_bsz = self.global_batch_size |
|
|
| |
| mm_batch_idxs = [mm_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(mm_shuffled_idxs), g_bsz)] |
| uni_batch_idxs = [uni_shuffled_idxs[i : i + g_bsz].tolist() for i in range(0, len(uni_shuffled_idxs), g_bsz)] |
|
|
| |
| if len(mm_batch_idxs[-1]) < g_bsz: |
| n_missing = g_bsz - len(mm_batch_idxs[-1]) |
| mm_batch_idxs[-1].extend(mm_batch_idxs[0][:n_missing]) |
|
|
| if len(uni_batch_idxs) > 0 and len(uni_batch_idxs[-1]) < g_bsz: |
| n_missing = g_bsz - len(uni_batch_idxs[-1]) |
| uni_batch_idxs[-1].extend(uni_batch_idxs[0][:n_missing]) |
|
|
| |
| mm_sorted_batch_idxs = [sorted(b, key=lambda i: multimodal_lengths[i], reverse=True) for b in mm_batch_idxs] |
| uni_sorted_batch_idxs = [sorted(b, key=lambda i: unimodal_lengths[i], reverse=True) for b in uni_batch_idxs] |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| mm_length_bucketed_idxs = [ |
| self.reindex_batch(batch, multimodal_lengths, self.num_replicas) for batch in mm_sorted_batch_idxs |
| ] |
| uni_length_bucketed_idxs = [ |
| self.reindex_batch(batch, unimodal_lengths, self.num_replicas) for batch in uni_sorted_batch_idxs |
| ] |
|
|
| |
| |
| mm_output_idxs = [idx for batch in mm_length_bucketed_idxs for bucket in batch for idx in bucket] |
| mm_reindexed = [multimodal_indices[idx] for idx in mm_output_idxs] |
| mm_batches = [mm_reindexed[i : i + g_bsz] for i in range(0, len(mm_reindexed), g_bsz)] |
|
|
| uni_output_idxs = [idx for batch in uni_length_bucketed_idxs for bucket in batch for idx in bucket] |
| uni_reindexed = [unimodal_indices[idx] for idx in uni_output_idxs] |
| uni_batches = [uni_reindexed[i : i + g_bsz] for i in range(0, len(uni_reindexed), g_bsz)] |
|
|
| |
| merged_batches = mm_batches + uni_batches |
| merge_idxs = torch.randperm(len(merged_batches), generator=generator) |
| all_batches = [merged_batches[idx] for idx in merge_idxs] |
|
|
| |
| all_lengths = [length + ((_n_patches := 24 * 24) if is_mm else 0) for is_mm, length in self.modality_lengths] |
| all_batches_max_lengths = [] |
| for batch in all_batches: |
| all_batches_max_lengths.append(max([all_lengths[idx] for idx in batch])) |
|
|
| |
| longest_batch_idx = np.argmax(all_batches_max_lengths) |
| all_batches[0], all_batches[longest_batch_idx] = all_batches[longest_batch_idx], all_batches[0] |
|
|
| |
| indices = [idx for batch in all_batches for idx in batch] |
| return indices |
|
|
| def __iter__(self) -> Iterator: |
| """Deterministically shuffle, then split indices by modality and length.""" |
| g = torch.Generator() |
| g.manual_seed(self.seed + self.epoch) |
| indices = self.get_modality_and_length_grouped_indices(g) |
| assert len(set(indices)) == len(self.modality_lengths) == len(self.dataset), "Oops!" |
| assert (len(indices) % self.global_batch_size == 0) and (len(indices) % self.num_replicas) == 0, "Oops" |
|
|
| |
| |
| per_replica_batch_size = self.global_batch_size // self.num_replicas |
|
|
| |
| |
| indices_t = torch.as_tensor(indices) |
| per_replica_batch_indices_t = indices_t.reshape(-1, per_replica_batch_size) |
| replica_indices_t = per_replica_batch_indices_t[self.rank :: self.num_replicas] |
|
|
| replica_indices = replica_indices_t.flatten().tolist() |
| return iter(replica_indices) |
|
|
| def __len__(self) -> int: |
| return self.num_samples |
|
|
| def set_epoch(self, epoch: int) -> None: |
| """To be called *between* epochs, prior to DataLoader instantiation; ensures random order across epochs.""" |
| self.epoch = epoch |
|
|