Manmay Nakhashi
Drop video_connector + video_aggregate_embed BEFORE .to(device)
8018d88
"""Pipeline blocks — each block owns its model lifecycle.
Blocks build a model on each ``__call__``, use it, then free GPU memory
(blocks constructed with ``warm=True`` instead keep their models resident).
This eliminates manual ``del model; cleanup_memory()`` in pipelines and
removes the need for :class:`ModelLedger`.
"""
from __future__ import annotations
import logging
from collections.abc import Iterator
from contextlib import AbstractContextManager, contextmanager
from dataclasses import replace
from typing import Callable, TypeVar
import torch
from ltx_core.batch_split import BatchSplitAdapter
from ltx_core.components.diffusion_steps import EulerDiffusionStep
from ltx_core.components.noisers import Noiser
from ltx_core.components.patchifiers import AudioPatchifier, VideoLatentPatchifier
from ltx_core.components.protocols import DiffusionStepProtocol
from ltx_core.layer_streaming import LayerStreamingWrapper
from ltx_core.loader import SDOps
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
from ltx_core.loader.registry import DummyRegistry, Registry
from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
from ltx_core.model.audio_vae import (
AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
VOCODER_COMFY_KEYS_FILTER,
AudioDecoderConfigurator,
AudioEncoderConfigurator,
VocoderConfigurator,
)
from ltx_core.model.audio_vae import (
decode_audio as vae_decode_audio,
)
from ltx_core.model.transformer import (
LTXV_MODEL_COMFY_RENAMING_MAP,
LTXModelConfigurator,
X0Model,
)
from ltx_core.model.transformer.compiling import COMPILE_TRANSFORMER, modify_sd_ops_for_compilation
from ltx_core.model.upsampler import LatentUpsamplerConfigurator, upsample_video
from ltx_core.model.video_vae import (
VAE_DECODER_COMFY_KEYS_FILTER,
VAE_ENCODER_COMFY_KEYS_FILTER,
TilingConfig,
VideoDecoderConfigurator,
VideoEncoder,
VideoEncoderConfigurator,
)
from ltx_core.quantization import QuantizationPolicy
from ltx_core.text_encoders.gemma import (
EMBEDDINGS_PROCESSOR_KEY_OPS,
GEMMA_LLM_KEY_OPS,
GEMMA_MODEL_OPS,
EmbeddingsProcessorConfigurator,
GemmaTextEncoderConfigurator,
module_ops_from_gemma_root,
)
from ltx_core.text_encoders.gemma.embeddings_processor import EmbeddingsProcessorOutput
from ltx_core.tools import AudioLatentTools, LatentTools, VideoLatentTools
from ltx_core.types import Audio, AudioLatentShape, LatentState, VideoLatentShape, VideoPixelShape
from ltx_core.utils import find_matching_file
from ltx_pipelines.utils.gpu_model import gpu_model
from ltx_pipelines.utils.helpers import (
cleanup_memory,
create_noised_state,
generate_enhanced_prompt,
)
from ltx_pipelines.utils.samplers import euler_denoising_loop
from ltx_pipelines.utils.types import Denoiser, ModalitySpec
logger = logging.getLogger(__name__)
T = TypeVar("T")
_M = TypeVar("_M", bound=torch.nn.Module)
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
@contextmanager
def _streaming_model(
    model: _M,
    layers_attr: str,
    target_device: torch.device,
    prefetch_count: int,
) -> Iterator[_M]:
    """Yield *model* wrapped in a :class:`LayerStreamingWrapper`, then release it.

    On exit (normal or exceptional) the wrapper is torn down, moved to the
    ``meta`` device, ``cleanup_memory()`` is called, and the pinned host
    memory cache is flushed when the running torch build exposes the hook.

    Args:
        model: Module to wrap.
        layers_attr: Attribute path handed to the wrapper as ``layers_attr``.
        target_device: Device handed to the wrapper; also the device that is
            synchronized before the host cache is flushed.
        prefetch_count: Value handed to the wrapper as ``prefetch_count``.
    """
    streamer = LayerStreamingWrapper(
        model,
        layers_attr=layers_attr,
        target_device=target_device,
        prefetch_count=prefetch_count,
    )
    try:
        yield streamer  # type: ignore[misc]
    finally:
        streamer.teardown()
        streamer.to("meta")
        cleanup_memory()
        # Freed pinned pages stay cached by the CachingHostAllocator unless
        # the host cache is emptied explicitly.  Back-to-back streaming
        # models (e.g. text encoder, then transformer) would otherwise run
        # the host out of memory.
        torch.cuda.synchronize(device=target_device)
        try:
            # Private torch hook; absent on some builds, hence the lookup.
            flush_host_cache = getattr(torch._C, "_host_emptyCache", None)
            if flush_host_cache is not None:
                flush_host_cache()
        except Exception:
            logger.warning("Host empty cache cleanup failed; ignoring.", exc_info=True)
def _build_state(
    spec: ModalitySpec,
    tools: LatentTools,
    noiser: Noiser,
    dtype: torch.dtype,
    device: torch.device,
) -> LatentState:
    """Build the noised latent state for one modality.

    The state comes from ``create_noised_state`` using the conditionings,
    noise scale and optional initial latent carried by *spec*.  When
    ``spec.frozen`` is truthy, the returned state has its ``denoise_mask``
    replaced by an all-zero tensor of the same shape.
    """
    noised = create_noised_state(
        tools=tools,
        conditionings=spec.conditionings,
        noiser=noiser,
        dtype=dtype,
        device=device,
        noise_scale=spec.noise_scale,
        initial_latent=spec.initial_latent,
    )
    if not spec.frozen:
        return noised
    zero_mask = torch.zeros_like(noised.denoise_mask)
    return replace(noised, denoise_mask=zero_mask)
def _cleanup_iter(it: Iterator[torch.Tensor], model: torch.nn.Module) -> Iterator[torch.Tensor]:
    """Re-yield every item of *it* while *model* is held by ``gpu_model``.

    This is a generator: the ``gpu_model`` context is entered on the first
    ``next()`` and exited when *it* is exhausted or the generator is closed.
    """
    model_scope = gpu_model(model)
    with model_scope:
        # ``yield from`` keeps close()/throw() delegation to the inner iterator.
        yield from it
# ---------------------------------------------------------------------------
# DiffusionStage
# ---------------------------------------------------------------------------
class DiffusionStage:
    """Diffusion block that owns the transformer for the duration of a call.

    A transformer is built at the start of ``__call__`` and released when the
    denoising loop finishes, replacing the manual
    ``model_ledger.transformer()`` / ``del transformer`` pattern in pipelines.
    """

    def __init__(
        self,
        checkpoint_path: str,
        dtype: torch.dtype,
        device: torch.device,
        loras: tuple[LoraPathStrengthAndSDOps, ...] = (),
        quantization: QuantizationPolicy | None = None,
        registry: Registry | None = None,
        torch_compile: bool = False,
    ) -> None:
        """Store settings and prepare the transformer builder (nothing is built yet).

        Args:
            checkpoint_path: Path handed to the builder as ``model_path``.
            dtype: Dtype used for the latent states created in ``__call__``.
            device: Device for the latent states and the default build target.
            loras: LoRA descriptors forwarded to the builder.
            quantization: Optional policy whose module ops and state-dict ops
                are appended at build time.
            registry: Builder registry; a ``DummyRegistry`` is used when falsy.
            torch_compile: When true, compilation ops are added at build time.
        """
        self._dtype = dtype
        self._device = device
        self._quantization = quantization
        self._torch_compile = torch_compile
        builder_registry = registry or DummyRegistry()
        self._transformer_builder = Builder(
            model_path=checkpoint_path,
            model_class_configurator=LTXModelConfigurator,
            model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP,
            loras=tuple(loras),
            registry=builder_registry,
        )

    def _build_transformer(self, *, device: torch.device | None = None, **kwargs: object) -> X0Model:
        """Build the transformer, wrap it in :class:`X0Model`, and put it in eval mode.

        Args:
            device: Build target; falls back to the stage device.
            **kwargs: Forwarded verbatim to ``Builder.build``.
        """
        base = self._transformer_builder
        target = device or self._device
        sd_ops = base.model_sd_ops
        module_ops = base.module_ops
        loras = base.loras

        if self._torch_compile:
            # Compilation needs an extra module op plus state-dict ops adjusted
            # for the layer count, for the base weights and for every LoRA.
            module_ops = (*module_ops, COMPILE_TRANSFORMER)
            number_of_layers = base.model_config()["transformer"]["num_layers"]
            sd_ops = modify_sd_ops_for_compilation(sd_ops, number_of_layers)
            compiled_loras = []
            for lora in loras:
                lora_ops = lora.sd_ops if lora.sd_ops is not None else SDOps(name="identity")
                compiled_loras.append(
                    LoraPathStrengthAndSDOps(
                        lora.path,
                        lora.strength,
                        modify_sd_ops_for_compilation(lora_ops, number_of_layers),
                    )
                )
            loras = tuple(compiled_loras)

        quantization = self._quantization
        if quantization is not None:
            # Chain the quantization ops after whatever was collected so far.
            module_ops = (*module_ops, *quantization.module_ops)
            sd_ops = SDOps(
                name=f"sd_ops_chain_{sd_ops.name}+{quantization.sd_ops.name}",
                mapping=(*sd_ops.mapping, *quantization.sd_ops.mapping),
            )

        configured = base.with_module_ops(module_ops).with_sd_ops(sd_ops).with_loras(loras)
        model = X0Model(configured.build(device=target, **kwargs))
        return model.to(target).eval()

    def _transformer_ctx(
        self,
        streaming_prefetch_count: int | None,
        **kwargs: object,
    ) -> AbstractContextManager:
        """Return a context manager that yields the transformer and frees it on exit.

        With ``streaming_prefetch_count`` unset the transformer is built on the
        stage device and wrapped in ``gpu_model``.  Otherwise it is built on
        CPU and its ``velocity_model.transformer_blocks`` are streamed to the
        stage device.
        """
        if streaming_prefetch_count is None:
            return gpu_model(self._build_transformer(**kwargs))
        cpu_transformer = self._build_transformer(device=torch.device("cpu"), **kwargs)
        return _streaming_model(
            cpu_transformer,
            layers_attr="velocity_model.transformer_blocks",
            target_device=self._device,
            prefetch_count=streaming_prefetch_count,
        )

    def __call__(  # noqa: PLR0913
        self,
        denoiser: Denoiser,
        sigmas: torch.Tensor,
        noiser: Noiser,
        width: int,
        height: int,
        frames: int,
        fps: float,
        video: ModalitySpec | None = None,
        audio: ModalitySpec | None = None,
        stepper: DiffusionStepProtocol | None = None,
        loop: Callable[..., tuple[LatentState | None, LatentState | None]] | None = None,
        streaming_prefetch_count: int | None = None,
        max_batch_size: int = 1,
    ) -> tuple[LatentState | None, LatentState | None]:
        """Build the transformer, run the denoising loop, then free the transformer.

        Args:
            denoiser: Passed through to *loop*.
            sigmas: Passed through to *loop*.
            noiser: Used to create the initial noised states.
            width: Output width in pixels.
            height: Output height in pixels.
            frames: Number of output frames.
            fps: Frame rate.
            video: Video modality spec, or ``None`` to skip video.
            audio: Audio modality spec, or ``None`` to skip audio.
            stepper: Diffusion step; defaults to a new :class:`EulerDiffusionStep`.
            loop: Denoising loop.  It is invoked with the keyword arguments
                ``sigmas, video_state, audio_state, stepper, transformer,
                denoiser``.  Defaults to :func:`euler_denoising_loop`.
            streaming_prefetch_count: When set, the transformer is built on CPU
                and streamed through :class:`LayerStreamingWrapper`,
                prefetching this many layers ahead.
            max_batch_size: Maximum batch size handed to
                :class:`BatchSplitAdapter`, which wraps the transformer.

        Returns:
            ``(video_state | None, audio_state | None)``.  Each present state
            has had its conditioning cleared and has been unpatchified.

        Raises:
            ValueError: If both *video* and *audio* are ``None``.
        """
        if video is None and audio is None:
            raise ValueError("At least one of `video` or `audio` must be provided")

        denoise_loop = euler_denoising_loop if loop is None else loop
        step = EulerDiffusionStep() if stepper is None else stepper

        pixel_shape = VideoPixelShape(batch=1, frames=frames, height=height, width=width, fps=fps)

        # Video state is created before audio state; both draw from *noiser*.
        video_tools: LatentTools | None = None
        video_state: LatentState | None = None
        if video is not None:
            video_latent_shape = VideoLatentShape.from_pixel_shape(pixel_shape)
            video_tools = VideoLatentTools(VideoLatentPatchifier(patch_size=1), video_latent_shape, fps)
            video_state = _build_state(video, video_tools, noiser, self._dtype, self._device)

        audio_tools: LatentTools | None = None
        audio_state: LatentState | None = None
        if audio is not None:
            audio_latent_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
            audio_tools = AudioLatentTools(AudioPatchifier(patch_size=1), audio_latent_shape)
            audio_state = _build_state(audio, audio_tools, noiser, self._dtype, self._device)

        # NOTE(review): ``video_tools`` is forwarded through ``_build_transformer``
        # into ``Builder.build`` — confirm the builder accepts that keyword.
        with self._transformer_ctx(streaming_prefetch_count, video_tools=video_tools) as base_transformer:
            batched_transformer = BatchSplitAdapter(base_transformer, max_batch_size=max_batch_size)
            video_state, audio_state = denoise_loop(
                sigmas=sigmas,
                video_state=video_state,
                audio_state=audio_state,
                stepper=step,
                transformer=batched_transformer,
                denoiser=denoiser,
            )

        # Strip conditioning tokens and restore the unpatchified latent layout.
        if video_state is not None and video_tools is not None:
            video_state = video_tools.unpatchify(video_tools.clear_conditioning(video_state))
        if audio_state is not None and audio_tools is not None:
            audio_state = audio_tools.unpatchify(audio_tools.clear_conditioning(audio_state))
        return video_state, audio_state
# ---------------------------------------------------------------------------
# PromptEncoder
# ---------------------------------------------------------------------------
class PromptEncoder:
    """Owns text encoder + embeddings processor lifecycle.

    Cold mode (default): every call loads Gemma, encodes the prompts, frees
    Gemma, then loads the embeddings processor to produce the final outputs.

    Warm mode (``warm=True``): both models are built once in ``__init__`` and
    kept on the device, so repeated calls skip the load/free overhead.

    Audio-only mode (``audio_only=True``): the embeddings processor is
    stripped of its video connector and video aggregate embedding before it
    is moved to the device, and ``create_embeddings`` is replaced by a variant
    that only runs the audio connector.  This applies to both warm and cold
    mode.
    """

    class _DummyVideoEmbed(torch.nn.Module):
        """Parameter-free stand-in for ``video_aggregate_embed``.

        Returns zeros of shape ``(x.shape[0], x.shape[1], out_features)`` on
        the device and dtype of the input.
        """

        def __init__(self, out_f: int) -> None:
            super().__init__()
            self.out_features = out_f

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.zeros(x.shape[0], x.shape[1], self.out_features, device=x.device, dtype=x.dtype)

    def __init__(
        self,
        checkpoint_path: str,
        gemma_root: str,
        dtype: torch.dtype,
        device: torch.device,
        registry: Registry | None = None,
        warm: bool = False,
        use_bnb_4bit: bool = False,
        audio_only: bool = False,
    ) -> None:
        """Prepare the builders and, in warm mode, load both models immediately.

        Args:
            checkpoint_path: Checkpoint used for the embeddings processor.
            gemma_root: Directory containing the Gemma weights, tokenizer and
                ``config.json``.
            dtype: Dtype for the built models.
            device: Device the models run on.
            registry: Builder registry; a ``DummyRegistry`` is used when falsy.
            warm: Build the models now and keep them resident.
            use_bnb_4bit: Load Gemma through bitsandbytes 4-bit quantization.
                Only consulted when ``warm`` is true.
            audio_only: Drop the video components of the embeddings processor.
        """
        self._dtype = dtype
        self._device = device
        self._warm = warm
        self._use_bnb_4bit = use_bnb_4bit
        self._audio_only = audio_only
        module_ops = module_ops_from_gemma_root(gemma_root)
        model_folder = find_matching_file(gemma_root, "model*.safetensors").parent
        weight_paths = [str(p) for p in model_folder.rglob("*.safetensors")]
        self._gemma_root = gemma_root
        self._text_encoder_builder = Builder(
            model_path=tuple(weight_paths),
            model_class_configurator=GemmaTextEncoderConfigurator,
            model_sd_ops=GEMMA_LLM_KEY_OPS,
            module_ops=(GEMMA_MODEL_OPS, *module_ops),
            registry=registry or DummyRegistry(),
        )
        self._embeddings_processor_builder = Builder(
            model_path=checkpoint_path,
            model_class_configurator=EmbeddingsProcessorConfigurator,
            model_sd_ops=EMBEDDINGS_PROCESSOR_KEY_OPS,
            registry=registry or DummyRegistry(),
        )

        # Warm mode: build once and keep the models on the device.
        self._warm_text_encoder = None
        self._warm_embeddings_processor = None
        if warm:
            if use_bnb_4bit:
                self._warm_text_encoder = self._load_bnb_4bit_encoder(gemma_root)
            else:
                self._warm_text_encoder = self._text_encoder_builder.build(
                    device=self._device, dtype=self._dtype
                ).eval()
            self._warm_embeddings_processor = self._build_embeddings_processor()

    @staticmethod
    def _resident_bytes(module: torch.nn.Module) -> int:
        """Return the byte size of *module*'s non-meta parameters (best effort).

        The number is only used for a log message, so a failure to measure is
        logged at debug level and reported as ``0`` instead of aborting.
        """
        try:
            return sum(p.numel() * p.element_size() for p in module.parameters() if not p.is_meta)
        except (AttributeError, RuntimeError, TypeError):
            logger.debug("Could not measure parameter size of %s", type(module).__name__, exc_info=True)
            return 0

    @staticmethod
    def _strip_video_components(processor: torch.nn.Module) -> int:
        """Remove the video-only submodules of *processor* in place.

        ``video_connector`` is replaced by ``None``.  The feature extractor's
        ``video_aggregate_embed`` (if present) is replaced by a
        :class:`_DummyVideoEmbed` with the same ``out_features``.

        Returns:
            Bytes held by the removed non-meta parameters.
        """
        freed = 0

        connector = processor.video_connector
        if connector is not None:
            freed += PromptEncoder._resident_bytes(connector)
            # Unregister the submodule first so the name becomes a plain attribute.
            del processor.video_connector
            processor.video_connector = None

        extractor = processor.feature_extractor
        video_embed = getattr(extractor, "video_aggregate_embed", None)
        if video_embed is not None:
            freed += PromptEncoder._resident_bytes(video_embed)
            out_features = video_embed.out_features
            del extractor.video_aggregate_embed
            extractor.video_aggregate_embed = PromptEncoder._DummyVideoEmbed(out_features)

        return freed

    @staticmethod
    def _patch_audio_only_create(processor: torch.nn.Module) -> None:
        """Replace ``processor.create_embeddings`` with an audio-only variant.

        The replacement returns ``(video_features, audio_encoded, binary_mask)``:
        *video_features* is passed through untouched, *audio_encoded* is the
        first output of the audio connector (``None`` when there is no audio
        connector), and *binary_mask* is derived from the additive mask.
        """

        def _audio_only_create(video_features, audio_features, additive_attention_mask):
            # Reduce the additive mask to two dimensions by repeatedly taking
            # index 0 of dim 1.
            # NOTE(review): assumes a [B, 1, seq, seq] or [B, 1, 1, seq] mask
            # where 0 means "attend" and a large negative value means
            # "masked" — confirm against the caller.
            mask = additive_attention_mask
            while mask.dim() > 2:
                mask = mask[:, 0]
            binary_mask = (mask >= -1.0).to(torch.int64)

            audio_encoded = None
            if processor.audio_connector is not None:
                audio_encoded, _ = processor.audio_connector(audio_features, additive_attention_mask)
            return video_features, audio_encoded, binary_mask

        processor.create_embeddings = _audio_only_create

    def _build_embeddings_processor(self) -> torch.nn.Module:
        """Build the embeddings processor, move it to the device, and set eval mode.

        In audio-only mode the video components are stripped before the move
        and ``create_embeddings`` is patched afterwards.
        """
        import gc

        processor = self._embeddings_processor_builder.build(device=self._device, dtype=self._dtype)
        if not self._audio_only:
            return processor.to(self._device).eval()

        # Strip BEFORE ``.to(device)``.  With a checkpoint that no longer
        # carries ``text_embedding_projection.video_aggregate_embed.*``, those
        # tensors stay on the meta device and the move fails with
        # "cannot copy out of meta".  Stripping also avoids holding the video
        # weights in VRAM.
        freed = self._strip_video_components(processor)
        processor = processor.to(self._device).eval()
        self._patch_audio_only_create(processor)

        # Collect unreachable tensors first so the cache flush can release them.
        gc.collect()
        torch.cuda.empty_cache()
        logger.info("Audio-only mode: freed video components, saved %.1fGB VRAM", freed / 1e9)
        return processor

    def _load_bnb_4bit_encoder(self, gemma_root: str):
        """Load Gemma with bitsandbytes 4-bit quantization for reduced VRAM.

        If ``config.json`` under *gemma_root* already contains a
        ``quantization_config`` entry, the checkpoint is treated as
        pre-quantized: no explicit ``BitsAndBytesConfig`` is passed and
        transformers uses the checkpoint's own metadata.  Passing a config on
        top of a pre-quantized checkpoint causes shape mismatches, because
        transformers then tries to quantize already-packed 4-bit weights.

        Returns:
            A ``GemmaTextEncoder`` wrapping the loaded Hugging Face model.
        """
        import json
        import os

        from transformers import BitsAndBytesConfig, Gemma3ForConditionalGeneration

        from ltx_core.text_encoders.gemma.encoders.base_encoder import GemmaTextEncoder
        from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer

        # Inspect config.json for an existing quantization_config.
        prequantized = False
        cfg_path = os.path.join(gemma_root, "config.json")
        try:
            with open(cfg_path, encoding="utf-8") as f:
                prequantized = "quantization_config" in json.load(f)
        except FileNotFoundError:
            pass  # No config file: fall back to runtime quantization.
        except (OSError, TypeError, ValueError):
            logger.warning(
                "Could not read %s; assuming the checkpoint is not pre-quantized.", cfg_path, exc_info=True
            )

        from_kwargs = {
            "device_map": str(self._device),
            "torch_dtype": self._dtype,
        }
        if prequantized:
            logger.info("Loading pre-quantized Gemma (bnb-4bit) — using checkpoint's own quantization_config")
        else:
            logger.info("Loading Gemma with runtime bitsandbytes 4-bit quantization...")
            from_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=self._dtype,
            )

        hf_model = Gemma3ForConditionalGeneration.from_pretrained(gemma_root, **from_kwargs)
        tokenizer = LTXVGemmaTokenizer(str(find_matching_file(gemma_root, "tokenizer.model").parent), 1024)
        encoder = GemmaTextEncoder(model=hf_model, tokenizer=tokenizer, dtype=self._dtype)

        mem_gb = torch.cuda.memory_allocated(self._device) / 1e9
        logger.info("Gemma 4-bit loaded: %.1fGB VRAM", mem_gb)
        return encoder

    def _text_encoder_ctx(
        self,
        streaming_prefetch_count: int | None,
    ) -> AbstractContextManager:
        """Return a context manager that yields a fresh text encoder and frees it on exit.

        When *streaming_prefetch_count* is set, the encoder is built on CPU and
        its ``model.model.language_model.layers`` are streamed to the device.
        """
        if streaming_prefetch_count is not None:
            return _streaming_model(
                self._text_encoder_builder.build(device=torch.device("cpu"), dtype=self._dtype).eval(),
                layers_attr="model.model.language_model.layers",
                target_device=self._device,
                prefetch_count=streaming_prefetch_count,
            )
        return gpu_model(self._text_encoder_builder.build(device=self._device, dtype=self._dtype).eval())

    @contextmanager
    def _noop_ctx(self, model):
        """Context manager that yields model without freeing it."""
        yield model

    @staticmethod
    def _with_enhanced_first(
        text_encoder,
        prompts: list[str],
        enhance_prompt_image: str | None,
        enhance_prompt_seed: int,
    ) -> list[str]:
        """Return a copy of *prompts* whose first entry has been enhanced."""
        enhanced = list(prompts)
        enhanced[0] = generate_enhanced_prompt(
            text_encoder, enhanced[0], enhance_prompt_image, seed=enhance_prompt_seed
        )
        return enhanced

    def __call__(
        self,
        prompts: list[str],
        *,
        enhance_first_prompt: bool = False,
        enhance_prompt_image: str | None = None,
        enhance_prompt_seed: int = 42,
        streaming_prefetch_count: int | None = None,
    ) -> list[EmbeddingsProcessorOutput]:
        """Encode *prompts* through Gemma → embeddings processor.

        Args:
            prompts: Prompts to encode; the caller's list is never mutated.
            enhance_first_prompt: Rewrite ``prompts[0]`` with
                ``generate_enhanced_prompt`` before encoding.  Honoured in
                both warm and cold mode.
            enhance_prompt_image: Optional image passed to the enhancer.
            enhance_prompt_seed: Seed passed to the enhancer.
            streaming_prefetch_count: Cold mode only — stream the text encoder
                from CPU, prefetching this many layers ahead.  Ignored in warm
                mode, where the models are already resident.

        Returns:
            One ``process_hidden_states`` result per prompt, in input order.
        """
        if self._warm and self._warm_text_encoder is not None:
            # Warm path: reuse the resident models, no load/free overhead.
            text_encoder = self._warm_text_encoder
            embeddings_processor = self._warm_embeddings_processor
            if enhance_first_prompt:
                prompts = self._with_enhanced_first(
                    text_encoder, prompts, enhance_prompt_image, enhance_prompt_seed
                )
            raw_outputs = [text_encoder.encode(p) for p in prompts]
            return [embeddings_processor.process_hidden_states(hs, mask) for hs, mask in raw_outputs]

        # Cold path: load-use-free for each model in turn.
        with self._text_encoder_ctx(streaming_prefetch_count) as text_encoder:
            if enhance_first_prompt:
                prompts = self._with_enhanced_first(
                    text_encoder, prompts, enhance_prompt_image, enhance_prompt_seed
                )
            raw_outputs = [text_encoder.encode(p) for p in prompts]
        with gpu_model(self._build_embeddings_processor()) as embeddings_processor:
            return [embeddings_processor.process_hidden_states(hs, mask) for hs, mask in raw_outputs]
# ---------------------------------------------------------------------------
# ImageConditioner
# ---------------------------------------------------------------------------
class ImageConditioner:
    """Block that lends a freshly built video encoder to a callback.

    The encoder exists only for the duration of one ``__call__``.
    """

    def __init__(
        self,
        checkpoint_path: str,
        dtype: torch.dtype,
        device: torch.device,
        registry: Registry | None = None,
    ) -> None:
        """Store settings and prepare the encoder builder (nothing is built yet)."""
        self._dtype = dtype
        self._device = device
        builder_registry = registry or DummyRegistry()
        self._encoder_builder = Builder(
            model_path=checkpoint_path,
            model_class_configurator=VideoEncoderConfigurator,
            model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
            registry=builder_registry,
        )

    def _build_encoder(self) -> VideoEncoder:
        """Build the video encoder on the block's device, in eval mode."""
        encoder = self._encoder_builder.build(device=self._device, dtype=self._dtype)
        encoder = encoder.to(self._device)
        return encoder.eval()

    def __call__(self, fn: Callable[[VideoEncoder], T]) -> T:
        """Build the encoder, return ``fn(encoder)``, and release the encoder."""
        encoder_scope = gpu_model(self._build_encoder())
        with encoder_scope as encoder:
            return fn(encoder)
# ---------------------------------------------------------------------------
# VideoUpsampler
# ---------------------------------------------------------------------------
class VideoUpsampler:
    """Block that owns the video encoder and spatial upsampler for one call."""

    def __init__(
        self,
        checkpoint_path: str,
        upsampler_path: str,
        dtype: torch.dtype,
        device: torch.device,
        registry: Registry | None = None,
    ) -> None:
        """Store settings and prepare both builders (nothing is built yet).

        Args:
            checkpoint_path: Checkpoint used for the video encoder.
            upsampler_path: Checkpoint used for the latent upsampler.
            dtype: Dtype for the built models.
            device: Device the models run on.
            registry: Builder registry; a ``DummyRegistry`` is used when falsy.
        """
        self._dtype = dtype
        self._device = device
        self._encoder_builder = Builder(
            model_path=checkpoint_path,
            model_class_configurator=VideoEncoderConfigurator,
            model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER,
            registry=registry or DummyRegistry(),
        )
        self._upsampler_builder = Builder(
            model_path=upsampler_path,
            model_class_configurator=LatentUpsamplerConfigurator,
            registry=registry or DummyRegistry(),
        )

    def _load(self, builder: Builder) -> torch.nn.Module:
        """Build a model from *builder* on the block's device, in eval mode."""
        model = builder.build(device=self._device, dtype=self._dtype)
        return model.to(self._device).eval()

    def __call__(self, latent: torch.Tensor) -> torch.Tensor:
        """Upsample *latent* with the encoder and upsampler, then free both."""
        # The upsampler is built only after the encoder's context is entered.
        with gpu_model(self._load(self._encoder_builder)) as encoder:
            with gpu_model(self._load(self._upsampler_builder)) as upsampler:
                return upsample_video(latent=latent, video_encoder=encoder, upsampler=upsampler)
# ---------------------------------------------------------------------------
# VideoDecoder
# ---------------------------------------------------------------------------
class VideoDecoder:
    """Block that owns the video decoder for one decode.

    ``__call__`` returns an iterator of decoded chunks.  The decoder is held
    by ``gpu_model`` while that iterator is being consumed.
    """

    def __init__(
        self,
        checkpoint_path: str,
        dtype: torch.dtype,
        device: torch.device,
        registry: Registry | None = None,
    ) -> None:
        """Store settings and prepare the decoder builder (nothing is built yet)."""
        self._dtype = dtype
        self._device = device
        builder_registry = registry or DummyRegistry()
        self._decoder_builder = Builder(
            model_path=checkpoint_path,
            model_class_configurator=VideoDecoderConfigurator,
            model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER,
            registry=builder_registry,
        )

    def __call__(
        self,
        latent: torch.Tensor,
        tiling_config: TilingConfig | None = None,
        generator: torch.Generator | None = None,
    ) -> Iterator[torch.Tensor]:
        """Decode *latent* into an iterator of pixel-space video chunks.

        The decoder is built immediately.  It is released by ``_cleanup_iter``
        once the returned iterator is exhausted or closed.
        """
        built = self._decoder_builder.build(device=self._device, dtype=self._dtype)
        decoder = built.to(self._device).eval()
        chunks = decoder.decode_video(latent, tiling_config, generator)
        return _cleanup_iter(chunks, decoder)
# ---------------------------------------------------------------------------
# AudioDecoder
# ---------------------------------------------------------------------------
class AudioDecoder:
    """Block that owns the audio VAE decoder and the vocoder.

    By default both models are built and freed on every call.  With
    ``warm=True`` they are built once in ``__init__`` and reused.
    """

    def __init__(
        self,
        checkpoint_path: str,
        dtype: torch.dtype,
        device: torch.device,
        registry: Registry | None = None,
        warm: bool = False,
    ) -> None:
        """Prepare both builders and, in warm mode, load both models immediately."""
        self._dtype = dtype
        self._device = device
        self._warm = warm
        self._decoder_builder = Builder(
            model_path=checkpoint_path,
            model_class_configurator=AudioDecoderConfigurator,
            model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
            registry=registry or DummyRegistry(),
        )
        self._vocoder_builder = Builder(
            model_path=checkpoint_path,
            model_class_configurator=VocoderConfigurator,
            model_sd_ops=VOCODER_COMFY_KEYS_FILTER,
            registry=registry or DummyRegistry(),
        )
        # Resident models; they stay ``None`` unless warm mode is requested.
        self._warm_decoder = None
        self._warm_vocoder = None
        if warm:
            self._warm_decoder = self._load(self._decoder_builder)
            self._warm_vocoder = self._load(self._vocoder_builder)

    def _load(self, builder: Builder) -> torch.nn.Module:
        """Build a model from *builder* on the block's device, in eval mode."""
        model = builder.build(device=self._device, dtype=self._dtype)
        return model.to(self._device).eval()

    def __call__(self, latent: torch.Tensor) -> Audio:
        """Decode audio *latent* through the VAE decoder and the vocoder.

        In warm mode the resident models are used.  Otherwise both models are
        built for this call and freed afterwards.
        """
        use_resident = self._warm and self._warm_decoder is not None
        if use_resident:
            return vae_decode_audio(latent, self._warm_decoder, self._warm_vocoder)

        # The vocoder is built only after the decoder's context is entered.
        with gpu_model(self._load(self._decoder_builder)) as decoder:
            with gpu_model(self._load(self._vocoder_builder)) as vocoder:
                return vae_decode_audio(latent, decoder, vocoder)
# ---------------------------------------------------------------------------
# AudioConditioner
# ---------------------------------------------------------------------------
class AudioConditioner:
    """Block that lends an audio encoder to a callback.

    By default the encoder is built and freed on every call.  With
    ``warm=True`` it is built once in ``__init__`` and reused.
    """

    def __init__(
        self,
        checkpoint_path: str,
        dtype: torch.dtype,
        device: torch.device,
        registry: Registry | None = None,
        warm: bool = False,
    ) -> None:
        """Prepare the encoder builder and, in warm mode, load the encoder immediately."""
        self._dtype = dtype
        self._device = device
        self._warm = warm
        builder_registry = registry or DummyRegistry()
        self._encoder_builder = Builder(
            model_path=checkpoint_path,
            model_class_configurator=AudioEncoderConfigurator,
            model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER,
            registry=builder_registry,
        )
        # Resident encoder; it stays ``None`` unless warm mode is requested.
        self._warm_encoder = None
        if warm:
            self._warm_encoder = self._load_encoder()

    def _load_encoder(self) -> torch.nn.Module:
        """Build the audio encoder on the block's device, in eval mode."""
        encoder = self._encoder_builder.build(device=self._device, dtype=self._dtype)
        return encoder.to(self._device).eval()

    def __call__(self, fn: Callable[[torch.nn.Module], T]) -> T:
        """Return ``fn(encoder)``, using the resident encoder or a temporary one."""
        resident = self._warm_encoder
        if self._warm and resident is not None:
            return fn(resident)
        with gpu_model(self._load_encoder()) as encoder:
            return fn(encoder)