| """Batch-splitting adapter for the transformer. |
| Wraps an ``X0Model`` (or ``LayerStreamingWrapper``) and splits batched inputs |
| into smaller chunks before forwarding, then concatenates the results. This |
| controls peak activation memory at the cost of more forward passes. |
| The adapter is transparent — it has the same ``forward`` signature as |
| ``X0Model`` and proxies attribute access to the wrapped model. |
| Example |
| ------- |
| >>> from ltx_core.batch_split import BatchSplitAdapter |
| >>> adapter = BatchSplitAdapter(model, max_batch_size=1) |
| >>> # Receives B=4, runs 4xB=1 internally, returns B=4 |
| >>> denoised_video, denoised_audio = adapter(video=v_b4, audio=a_b4, perturbations=ptb) |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| import torch |
| from torch import nn |
|
|
| from ltx_core.guidance.perturbations import BatchedPerturbationConfig |
| from ltx_core.model.transformer.modality import Modality |
|
|
|
|
def _split_perturbations(config: BatchedPerturbationConfig, sizes: list[int]) -> list[BatchedPerturbationConfig]:
    """Split a ``BatchedPerturbationConfig`` along the batch dimension.

    Entries of ``config.perturbations`` are consumed in order, ``sizes[i]`` at
    a time, and each group is wrapped in a new ``BatchedPerturbationConfig``.
    Entries left over after the last chunk are ignored.

    Args:
        config: Configuration whose ``perturbations`` are distributed.
        sizes: Number of entries to place in each chunk, in order.

    Returns:
        One ``BatchedPerturbationConfig`` per element of ``sizes``.

    Raises:
        ValueError: If ``config.perturbations`` holds fewer entries than
            ``sum(sizes)``.
    """
    entries = iter(config.perturbations)
    chunks = []
    for size in sizes:
        try:
            group = [next(entries) for _ in range(size)]
        except StopIteration:
            # A raw StopIteration escaping from here is easy to misread and can
            # be swallowed by iterator machinery further up the stack; report
            # the size mismatch explicitly instead.
            raise ValueError(
                f"perturbations has fewer than {sum(sizes)} entries; "
                f"cannot split into chunks of sizes {sizes}"
            ) from None
        chunks.append(BatchedPerturbationConfig(group))
    return chunks
|
|
|
|
| def _merge_tensors(tensors: list[torch.Tensor | None]) -> torch.Tensor | None: |
| """Concatenate tensors along batch dim, or return None if all are None.""" |
| non_none = [t for t in tensors if t is not None] |
| if not non_none: |
| return None |
| return torch.cat(non_none, dim=0) |
|
|
|
|
class BatchSplitAdapter(nn.Module):
    """Wraps a model and splits batched forward calls into smaller chunks.

    Has the same ``forward`` signature as ``X0Model``:
    ``(video, audio, perturbations) -> (denoised_video, denoised_audio)``.
    Attribute lookups the adapter cannot satisfy itself are forwarded to the
    wrapped model.

    Args:
        model: The model to wrap (``X0Model``, ``LayerStreamingWrapper``, etc.).
        max_batch_size: Maximum batch size per forward pass. Input batches
            larger than this are split into sequential chunks.

    Raises:
        ValueError: If ``max_batch_size`` is smaller than 1.
    """

    def __init__(self, model: nn.Module, max_batch_size: int) -> None:
        # Validate before initialising nn.Module state so that a bad argument
        # never leaves a half-built module behind.
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        super().__init__()
        self._model = model
        self._max_batch_size = max_batch_size

    def _get_chunk_sizes(self, batch_size: int) -> list[int]:
        """Return chunk sizes that sum to ``batch_size``.

        Every chunk has ``max_batch_size`` elements except possibly the last,
        which holds the remainder.
        """
        full, remainder = divmod(batch_size, self._max_batch_size)
        sizes = [self._max_batch_size] * full
        if remainder:
            sizes.append(remainder)
        return sizes

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Run the wrapped model, splitting the batch when it is too large.

        The batch size is read from ``latent.shape[0]`` of ``video``, or of
        ``audio`` when ``video`` is ``None``. Batches no larger than
        ``max_batch_size`` are passed to the wrapped model unchanged. Larger
        batches are split, run chunk by chunk in order, and the chunk outputs
        are concatenated along dimension 0.

        Args:
            video: Video modality, or ``None``.
            audio: Audio modality, or ``None``.
            perturbations: Perturbation config, split with the same chunk
                sizes as the modalities.

        Returns:
            ``(denoised_video, denoised_audio)``. In the split path an element
            is ``None`` when every chunk produced ``None`` for it.

        Raises:
            ValueError: If both ``video`` and ``audio`` are ``None``.
        """
        # Compare against None explicitly: relying on the truthiness of a
        # modality object would pick the wrong one if it ever evaluated falsy.
        reference = video if video is not None else audio
        if reference is None:
            raise ValueError("at least one of video or audio must be provided")
        batch_size = reference.latent.shape[0]

        if batch_size <= self._max_batch_size:
            return self._model(video=video, audio=audio, perturbations=perturbations)

        sizes = self._get_chunk_sizes(batch_size)
        n = len(sizes)

        # A missing modality is represented by one None placeholder per chunk
        # so the strict zip below still lines up.
        v_chunks = video.split(sizes) if video is not None else [None] * n
        a_chunks = audio.split(sizes) if audio is not None else [None] * n
        p_chunks = _split_perturbations(perturbations, sizes)

        chunk_results = [
            self._model(video=vc, audio=ac, perturbations=pc)
            for vc, ac, pc in zip(v_chunks, a_chunks, p_chunks, strict=True)
        ]

        # Transpose [(v0, a0), (v1, a1), ...] into (v0, v1, ...), (a0, a1, ...).
        results_v, results_a = zip(*chunk_results, strict=True)
        return _merge_tensors(list(results_v)), _merge_tensors(list(results_a))

    def __getattr__(self, name: str) -> Any:
        """Proxy attribute access to the wrapped model.

        Called only when normal lookup fails. ``nn.Module`` is asked first
        (parameters, buffers, submodules); anything it cannot resolve is
        looked up on the wrapped model.

        Raises:
            AttributeError: If neither the adapter nor the wrapped model has
                the attribute, or if the wrapped model is not set yet.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # ``self._model`` is itself resolved through this hook because
            # nn.Module keeps submodules outside the instance ``__dict__``.
            # If it is missing (instance created without ``__init__``, or a
            # partially restored copy), falling through would recurse until
            # RecursionError, so surface the AttributeError instead.
            if name == "_model":
                raise
            return getattr(self._model, name)
|
|