Spaces:
Running on Zero
Running on Zero
File size: 3,783 Bytes
08c5e28 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 | """Batch-splitting adapter for the transformer.
Wraps an ``X0Model`` (or ``LayerStreamingWrapper``) and splits batched inputs
into smaller chunks before forwarding, then concatenates the results. This
controls peak activation memory at the cost of more forward passes.
The adapter is transparent — it has the same ``forward`` signature as
``X0Model`` and proxies attribute access to the wrapped model.
Example
-------
>>> from ltx_core.batch_split import BatchSplitAdapter
>>> adapter = BatchSplitAdapter(model, max_batch_size=1)
>>> # Receives B=4, runs 4xB=1 internally, returns B=4
>>> denoised_video, denoised_audio = adapter(video=v_b4, audio=a_b4, perturbations=ptb)
"""
from __future__ import annotations
from typing import Any
import torch
from torch import nn
from ltx_core.guidance.perturbations import BatchedPerturbationConfig
from ltx_core.model.transformer.modality import Modality
def _split_perturbations(config: BatchedPerturbationConfig, sizes: list[int]) -> list[BatchedPerturbationConfig]:
    """Split a ``BatchedPerturbationConfig`` along the batch dimension.

    Entries of ``config.perturbations`` are consumed in order: chunk ``i``
    wraps the next ``sizes[i]`` entries in a new ``BatchedPerturbationConfig``.
    Entries left over after the last chunk are ignored.

    Args:
        config: The batched configuration whose ``perturbations`` are split.
        sizes: Number of entries to place in each chunk, in order.

    Returns:
        One ``BatchedPerturbationConfig`` per element of ``sizes``.

    Raises:
        ValueError: If ``config.perturbations`` has fewer entries than the
            chunk sizes require.
    """
    remaining = iter(config.perturbations)
    # Sentinel lets us detect exhaustion explicitly; a bare ``next()`` would
    # leak ``StopIteration`` to the caller, which can silently terminate an
    # enclosing iteration instead of reporting the size mismatch.
    exhausted = object()
    chunks = []
    for size in sizes:
        entries = []
        for _ in range(size):
            entry = next(remaining, exhausted)
            if entry is exhausted:
                raise ValueError(
                    f"perturbations has fewer entries than required by chunk sizes {sizes} "
                    f"(need {sum(sizes)})"
                )
            entries.append(entry)
        chunks.append(BatchedPerturbationConfig(entries))
    return chunks
def _merge_tensors(tensors: list[torch.Tensor | None]) -> torch.Tensor | None:
    """Join the non-``None`` tensors along dim 0.

    ``None`` entries are skipped. If nothing remains (the list is empty or
    holds only ``None``), ``None`` is returned instead of a tensor.
    """
    present = list(filter(lambda item: item is not None, tensors))
    return torch.cat(present, dim=0) if present else None
class BatchSplitAdapter(nn.Module):
    """Wraps a model and splits batched forward calls into smaller chunks.

    Has the same ``forward`` signature as ``X0Model``:
    ``(video, audio, perturbations) -> (denoised_video, denoised_audio)``.

    Attribute lookups the adapter cannot satisfy itself are forwarded to the
    wrapped model.

    Args:
        model: The model to wrap (``X0Model``, ``LayerStreamingWrapper``, etc.).
        max_batch_size: Maximum batch size per forward pass. Input batches
            larger than this are split into sequential chunks.

    Raises:
        ValueError: If ``max_batch_size`` is smaller than 1.
    """

    def __init__(self, model: nn.Module, max_batch_size: int) -> None:
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        super().__init__()
        self._model = model
        self._max_batch_size = max_batch_size

    def _get_chunk_sizes(self, batch_size: int) -> list[int]:
        """Return chunk sizes that sum to ``batch_size``.

        Every chunk has ``max_batch_size`` elements except possibly the last,
        which holds the remainder.
        """
        full, remainder = divmod(batch_size, self._max_batch_size)
        sizes = [self._max_batch_size] * full
        if remainder:
            sizes.append(remainder)
        return sizes

    def forward(
        self,
        video: Modality | None,
        audio: Modality | None,
        perturbations: BatchedPerturbationConfig,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
        """Run the wrapped model, chunking the batch when it is too large.

        Args:
            video: Video modality, or ``None``.
            audio: Audio modality, or ``None``.
            perturbations: Per-sample perturbation configuration.

        Returns:
            ``(denoised_video, denoised_audio)``. When the batch was split,
            each element is the dim-0 concatenation of the per-chunk outputs,
            or ``None`` if every chunk produced ``None`` for it.

        Raises:
            ValueError: If both ``video`` and ``audio`` are ``None``.
        """
        # Select by None-ness, not truthiness: a modality object must never be
        # skipped just because it happens to evaluate as falsy.
        reference = video if video is not None else audio
        if reference is None:
            raise ValueError("at least one of video or audio must be provided")
        batch_size = reference.latent.shape[0]

        # Small enough already: forward untouched.
        if batch_size <= self._max_batch_size:
            return self._model(video=video, audio=audio, perturbations=perturbations)

        sizes = self._get_chunk_sizes(batch_size)
        n = len(sizes)
        # A missing modality is represented by one ``None`` per chunk so the
        # three sequences can be zipped in lockstep.
        v_chunks = video.split(sizes) if video is not None else [None] * n
        a_chunks = audio.split(sizes) if audio is not None else [None] * n
        p_chunks = _split_perturbations(perturbations, sizes)

        chunk_results = [
            self._model(video=vc, audio=ac, perturbations=pc)
            for vc, ac, pc in zip(v_chunks, a_chunks, p_chunks, strict=True)
        ]
        # Transpose [(v0, a0), (v1, a1), ...] into (v0, v1, ...), (a0, a1, ...).
        results_v, results_a = zip(*chunk_results, strict=True)
        return _merge_tensors(list(results_v)), _merge_tensors(list(results_a))

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
        """Proxy attribute access to the wrapped model."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If the wrapped model itself is not registered yet (e.g. the
            # instance was created without ``__init__``), falling back to
            # ``self._model`` would re-enter this method forever.
            if name == "_model":
                raise
            return getattr(self._model, name)
|