Spaces:
Running on Zero
Running on Zero
| """Batch-splitting adapter for the transformer. | |
| Wraps an ``X0Model`` (or ``LayerStreamingWrapper``) and splits batched inputs | |
| into smaller chunks before forwarding, then concatenates the results. This | |
| controls peak activation memory at the cost of more forward passes. | |
| The adapter is transparent — it has the same ``forward`` signature as | |
| ``X0Model`` and proxies attribute access to the wrapped model. | |
| Example | |
| ------- | |
| from ltx_core.batch_split import BatchSplitAdapter | |
| adapter = BatchSplitAdapter(model, max_batch_size=1) | |
| # Receives B=4, runs 4xB=1 internally, returns B=4 | |
| denoised_video, denoised_audio = adapter(video=v_b4, audio=a_b4, perturbations=ptb) | |
| """ | |
| from __future__ import annotations | |
| from typing import Any | |
| import torch | |
| from torch import nn | |
| from ltx_core.guidance.perturbations import BatchedPerturbationConfig | |
| from ltx_core.model.transformer.modality import Modality | |
| def _split_perturbations(config: BatchedPerturbationConfig, sizes: list[int]) -> list[BatchedPerturbationConfig]: | |
| """Split a ``BatchedPerturbationConfig`` along the batch dimension.""" | |
| it = iter(config.perturbations) | |
| return [BatchedPerturbationConfig([next(it) for _ in range(s)]) for s in sizes] | |
| def _merge_tensors(tensors: list[torch.Tensor | None]) -> torch.Tensor | None: | |
| """Concatenate tensors along batch dim, or return None if all are None.""" | |
| non_none = [t for t in tensors if t is not None] | |
| if not non_none: | |
| return None | |
| return torch.cat(non_none, dim=0) | |
class BatchSplitAdapter(nn.Module):
    """Wraps a model and splits batched forward calls into smaller chunks.

    Has the same ``forward`` signature as ``X0Model``:
    ``(video, audio, perturbations) -> (denoised_video, denoised_audio)``.

    Attributes that are not found on the adapter itself are looked up on the
    wrapped model.

    Args:
        model: The model to wrap (``X0Model``, ``LayerStreamingWrapper``, etc.).
        max_batch_size: Maximum batch size per forward pass. Input batches
            larger than this are split into sequential chunks.

    Raises:
        ValueError: If ``max_batch_size`` is less than 1.
    """

    def __init__(self, model: nn.Module, max_batch_size: int) -> None:
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        super().__init__()
        self._model = model
        self._max_batch_size = max_batch_size

    def _get_chunk_sizes(self, batch_size: int) -> list[int]:
        """Return chunk sizes summing to ``batch_size``.

        Every chunk has ``max_batch_size`` entries except possibly the last,
        which holds the remainder.
        """
        full, remainder = divmod(batch_size, self._max_batch_size)
        sizes = [self._max_batch_size] * full
        if remainder:
            sizes.append(remainder)
        return sizes

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Run the wrapped model, chunking the batch when it is too large.

        The batch size is read from ``latent.shape[0]`` of ``video`` when it is
        given, otherwise of ``audio``. Batches that already fit are forwarded
        unchanged in a single call.

        Args:
            video: Video modality, or ``None``.
            audio: Audio modality, or ``None``.
            perturbations: Per-sample perturbation config for the full batch.

        Returns:
            ``(denoised_video, denoised_audio)``. In the chunked path each
            element is the dim-0 concatenation of the per-chunk outputs, or
            ``None`` when every chunk returned ``None`` for it.

        Raises:
            ValueError: If both ``video`` and ``audio`` are ``None``.
        """
        # Explicit None checks: do not rely on the truthiness of a Modality.
        reference = video if video is not None else audio
        if reference is None:
            raise ValueError("at least one of video or audio must be provided")
        batch_size = reference.latent.shape[0]

        if batch_size <= self._max_batch_size:
            return self._model(video=video, audio=audio, perturbations=perturbations)

        sizes = self._get_chunk_sizes(batch_size)
        n = len(sizes)
        # NOTE(review): Modality.split(sizes) is assumed to yield one Modality
        # per chunk size - confirm against the Modality implementation.
        v_chunks = video.split(sizes) if video is not None else [None] * n
        a_chunks = audio.split(sizes) if audio is not None else [None] * n
        p_chunks = _split_perturbations(perturbations, sizes)

        # strict=True surfaces any disagreement in the number of chunks
        # produced for the different inputs.
        chunk_results = [
            self._model(video=vc, audio=ac, perturbations=pc)
            for vc, ac, pc in zip(v_chunks, a_chunks, p_chunks, strict=True)
        ]
        results_v, results_a = zip(*chunk_results, strict=True)
        return _merge_tensors(list(results_v)), _merge_tensors(list(results_a))

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
        """Proxy attribute access to the wrapped model.

        Only invoked when normal attribute lookup fails. ``nn.Module`` is asked
        first; anything it cannot resolve is fetched from the wrapped model.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If ``_model`` itself is missing (e.g. the instance has not been
            # initialised), falling back to ``self._model`` would re-enter this
            # method forever. Fail with AttributeError instead.
            if name == "_model":
                raise
            return getattr(self._model, name)