Spaces:
Running on Zero
Running on Zero
| """Pipeline blocks — each block owns its model lifecycle. | |
| Blocks build a model on each ``__call__``, use it, then free GPU memory. | |
| This eliminates manual ``del model; cleanup_memory()`` in pipelines and | |
| removes the need for :class:`ModelLedger`. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from collections.abc import Iterator | |
| from contextlib import AbstractContextManager, contextmanager | |
| from dataclasses import replace | |
| from typing import Callable, TypeVar | |
| import torch | |
| from ltx_core.batch_split import BatchSplitAdapter | |
| from ltx_core.components.diffusion_steps import EulerDiffusionStep | |
| from ltx_core.components.noisers import Noiser | |
| from ltx_core.components.patchifiers import AudioPatchifier, VideoLatentPatchifier | |
| from ltx_core.components.protocols import DiffusionStepProtocol | |
| from ltx_core.layer_streaming import LayerStreamingWrapper | |
| from ltx_core.loader import SDOps | |
| from ltx_core.loader.primitives import LoraPathStrengthAndSDOps | |
| from ltx_core.loader.registry import DummyRegistry, Registry | |
| from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder | |
| from ltx_core.model.audio_vae import ( | |
| AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, | |
| AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, | |
| VOCODER_COMFY_KEYS_FILTER, | |
| AudioDecoderConfigurator, | |
| AudioEncoderConfigurator, | |
| VocoderConfigurator, | |
| ) | |
| from ltx_core.model.audio_vae import ( | |
| decode_audio as vae_decode_audio, | |
| ) | |
| from ltx_core.model.transformer import ( | |
| LTXV_MODEL_COMFY_RENAMING_MAP, | |
| LTXModelConfigurator, | |
| X0Model, | |
| ) | |
| from ltx_core.model.transformer.compiling import COMPILE_TRANSFORMER, modify_sd_ops_for_compilation | |
| from ltx_core.model.upsampler import LatentUpsamplerConfigurator, upsample_video | |
| from ltx_core.model.video_vae import ( | |
| VAE_DECODER_COMFY_KEYS_FILTER, | |
| VAE_ENCODER_COMFY_KEYS_FILTER, | |
| TilingConfig, | |
| VideoDecoderConfigurator, | |
| VideoEncoder, | |
| VideoEncoderConfigurator, | |
| ) | |
| from ltx_core.quantization import QuantizationPolicy | |
| from ltx_core.text_encoders.gemma import ( | |
| EMBEDDINGS_PROCESSOR_KEY_OPS, | |
| GEMMA_LLM_KEY_OPS, | |
| GEMMA_MODEL_OPS, | |
| EmbeddingsProcessorConfigurator, | |
| GemmaTextEncoderConfigurator, | |
| module_ops_from_gemma_root, | |
| ) | |
| from ltx_core.text_encoders.gemma.embeddings_processor import EmbeddingsProcessorOutput | |
| from ltx_core.tools import AudioLatentTools, LatentTools, VideoLatentTools | |
| from ltx_core.types import Audio, AudioLatentShape, LatentState, VideoLatentShape, VideoPixelShape | |
| from ltx_core.utils import find_matching_file | |
| from ltx_pipelines.utils.gpu_model import gpu_model | |
| from ltx_pipelines.utils.helpers import ( | |
| cleanup_memory, | |
| create_noised_state, | |
| generate_enhanced_prompt, | |
| ) | |
| from ltx_pipelines.utils.samplers import euler_denoising_loop | |
| from ltx_pipelines.utils.types import Denoiser, ModalitySpec | |
| logger = logging.getLogger(__name__) | |
| T = TypeVar("T") | |
| _M = TypeVar("_M", bound=torch.nn.Module) | |
| # --------------------------------------------------------------------------- | |
| # Internal helpers | |
| # --------------------------------------------------------------------------- | |
| def _streaming_model( | |
| model: _M, | |
| layers_attr: str, | |
| target_device: torch.device, | |
| prefetch_count: int, | |
| ) -> Iterator[_M]: | |
| """Wrap *model* with :class:`LayerStreamingWrapper`, yield it, then tear down.""" | |
| wrapped = LayerStreamingWrapper( | |
| model, | |
| layers_attr=layers_attr, | |
| target_device=target_device, | |
| prefetch_count=prefetch_count, | |
| ) | |
| try: | |
| yield wrapped # type: ignore[misc] | |
| finally: | |
| wrapped.teardown() | |
| wrapped.to("meta") | |
| cleanup_memory() | |
| # Flush the host (pinned) memory cache so that freed pinned pages are | |
| # returned to the OS. Without this, sequential streaming models | |
| # (e.g. text encoder then transformer) exhaust host memory because the | |
| # CachingHostAllocator keeps freed blocks cached indefinitely. | |
| torch.cuda.synchronize(device=target_device) | |
| try: | |
| if hasattr(torch._C, "_host_emptyCache"): | |
| torch._C._host_emptyCache() | |
| except Exception: | |
| logger.warning("Host empty cache cleanup failed; ignoring.", exc_info=True) | |
| def _build_state( | |
| spec: ModalitySpec, | |
| tools: LatentTools, | |
| noiser: Noiser, | |
| dtype: torch.dtype, | |
| device: torch.device, | |
| ) -> LatentState: | |
| """Create a noised latent state from a modality spec and tools.""" | |
| state = create_noised_state( | |
| tools=tools, | |
| conditionings=spec.conditionings, | |
| noiser=noiser, | |
| dtype=dtype, | |
| device=device, | |
| noise_scale=spec.noise_scale, | |
| initial_latent=spec.initial_latent, | |
| ) | |
| if spec.frozen: | |
| state = replace(state, denoise_mask=torch.zeros_like(state.denoise_mask)) | |
| return state | |
| def _cleanup_iter(it: Iterator[torch.Tensor], model: torch.nn.Module) -> Iterator[torch.Tensor]: | |
| """Wrap an iterator to clean up *model* memory once it is exhausted or abandoned.""" | |
| with gpu_model(model): | |
| yield from it | |
| # --------------------------------------------------------------------------- | |
| # DiffusionStage | |
| # --------------------------------------------------------------------------- | |
| class DiffusionStage: | |
| """Owns transformer lifecycle. Builds on each call, frees on exit. | |
| Replaces the manual ``model_ledger.transformer()`` / ``del transformer`` | |
| pattern in every pipeline. | |
| """ | |
| def __init__( | |
| self, | |
| checkpoint_path: str, | |
| dtype: torch.dtype, | |
| device: torch.device, | |
| loras: tuple[LoraPathStrengthAndSDOps, ...] = (), | |
| quantization: QuantizationPolicy | None = None, | |
| registry: Registry | None = None, | |
| torch_compile: bool = False, | |
| ) -> None: | |
| self._dtype = dtype | |
| self._device = device | |
| self._quantization = quantization | |
| self._torch_compile = torch_compile | |
| self._transformer_builder = Builder( | |
| model_path=checkpoint_path, | |
| model_class_configurator=LTXModelConfigurator, | |
| model_sd_ops=LTXV_MODEL_COMFY_RENAMING_MAP, | |
| loras=tuple(loras), | |
| registry=registry or DummyRegistry(), | |
| ) | |
| def _build_transformer(self, *, device: torch.device | None = None, **kwargs: object) -> X0Model: | |
| target = device or self._device | |
| sd_ops = self._transformer_builder.model_sd_ops | |
| module_ops = self._transformer_builder.module_ops | |
| loras = self._transformer_builder.loras | |
| if self._torch_compile: | |
| module_ops = (*module_ops, COMPILE_TRANSFORMER) | |
| number_of_layers = self._transformer_builder.model_config()["transformer"]["num_layers"] | |
| sd_ops = modify_sd_ops_for_compilation(sd_ops, number_of_layers) | |
| loras = tuple( | |
| LoraPathStrengthAndSDOps( | |
| lora.path, | |
| lora.strength, | |
| modify_sd_ops_for_compilation( | |
| lora.sd_ops if lora.sd_ops is not None else SDOps(name="identity"), number_of_layers | |
| ), | |
| ) | |
| for lora in loras | |
| ) | |
| if self._quantization is not None: | |
| module_ops = (*module_ops, *self._quantization.module_ops) | |
| sd_ops = SDOps( | |
| name=f"sd_ops_chain_{sd_ops.name}+{self._quantization.sd_ops.name}", | |
| mapping=(*sd_ops.mapping, *self._quantization.sd_ops.mapping), | |
| ) | |
| builder = self._transformer_builder.with_module_ops(module_ops).with_sd_ops(sd_ops).with_loras(loras) | |
| return X0Model(builder.build(device=target, **kwargs)).to(target).eval() | |
| def _transformer_ctx( | |
| self, | |
| streaming_prefetch_count: int | None, | |
| **kwargs: object, | |
| ) -> AbstractContextManager: | |
| if streaming_prefetch_count is not None: | |
| return _streaming_model( | |
| self._build_transformer(device=torch.device("cpu"), **kwargs), | |
| layers_attr="velocity_model.transformer_blocks", | |
| target_device=self._device, | |
| prefetch_count=streaming_prefetch_count, | |
| ) | |
| return gpu_model(self._build_transformer(**kwargs)) | |
| def __call__( # noqa: PLR0913 | |
| self, | |
| denoiser: Denoiser, | |
| sigmas: torch.Tensor, | |
| noiser: Noiser, | |
| width: int, | |
| height: int, | |
| frames: int, | |
| fps: float, | |
| video: ModalitySpec | None = None, | |
| audio: ModalitySpec | None = None, | |
| stepper: DiffusionStepProtocol | None = None, | |
| loop: Callable[..., tuple[LatentState | None, LatentState | None]] | None = None, | |
| streaming_prefetch_count: int | None = None, | |
| max_batch_size: int = 1, | |
| ) -> tuple[LatentState | None, LatentState | None]: | |
| """Build transformer → run denoising loop → free transformer. | |
| Args: | |
| width: Output width in pixels. | |
| height: Output height in pixels. | |
| frames: Number of output frames. | |
| fps: Frame rate. | |
| loop: Denoising loop function. Must accept | |
| ``(sigmas, video_state, audio_state, stepper, transformer, denoiser)`` | |
| as the first six positional arguments. When ``None``, resolves to | |
| :func:`euler_denoising_loop` at call time. | |
| streaming_prefetch_count: When set, build the transformer on CPU and | |
| wrap with :class:`LayerStreamingWrapper` for memory-efficient | |
| inference, prefetching this many layers ahead. | |
| max_batch_size: Maximum batch size per transformer forward pass. | |
| Guided denoisers make up to 4 transformer calls per step. | |
| When set to a value > 1, the transformer batches multiple | |
| calls together, reducing layer-streaming PCIe transfers. | |
| Default ``1`` preserves sequential behavior. | |
| Returns ``(video_state | None, audio_state | None)`` with cleared | |
| conditionings and unpatchified latents for present modalities. | |
| """ | |
| if video is None and audio is None: | |
| raise ValueError("At least one of `video` or `audio` must be provided") | |
| if loop is None: | |
| loop = euler_denoising_loop | |
| if stepper is None: | |
| stepper = EulerDiffusionStep() | |
| pixel_shape = VideoPixelShape(batch=1, frames=frames, height=height, width=width, fps=fps) | |
| video_state: LatentState | None = None | |
| video_tools: LatentTools | None = None | |
| if video is not None: | |
| v_shape = VideoLatentShape.from_pixel_shape(pixel_shape) | |
| video_tools = VideoLatentTools(VideoLatentPatchifier(patch_size=1), v_shape, fps) | |
| video_state = _build_state(video, video_tools, noiser, self._dtype, self._device) | |
| audio_state: LatentState | None = None | |
| audio_tools: LatentTools | None = None | |
| if audio is not None: | |
| a_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape) | |
| audio_tools = AudioLatentTools(AudioPatchifier(patch_size=1), a_shape) | |
| audio_state = _build_state(audio, audio_tools, noiser, self._dtype, self._device) | |
| with self._transformer_ctx(streaming_prefetch_count, video_tools=video_tools) as base_transformer: | |
| transformer = BatchSplitAdapter(base_transformer, max_batch_size=max_batch_size) | |
| video_state, audio_state = loop( | |
| sigmas=sigmas, | |
| video_state=video_state, | |
| audio_state=audio_state, | |
| stepper=stepper, | |
| transformer=transformer, | |
| denoiser=denoiser, | |
| ) | |
| # Post-process: clear conditionings and unpatchify | |
| if video_state is not None and video_tools is not None: | |
| video_state = video_tools.clear_conditioning(video_state) | |
| video_state = video_tools.unpatchify(video_state) | |
| if audio_state is not None and audio_tools is not None: | |
| audio_state = audio_tools.clear_conditioning(audio_state) | |
| audio_state = audio_tools.unpatchify(audio_state) | |
| return video_state, audio_state | |
| # --------------------------------------------------------------------------- | |
| # PromptEncoder | |
| # --------------------------------------------------------------------------- | |
| class PromptEncoder: | |
| """Owns text encoder + embeddings processor lifecycle. | |
| Loads Gemma, encodes prompts, frees Gemma, then loads the embeddings | |
| processor to produce final outputs. | |
| With warm=True, models are built once and kept on GPU for fast repeated calls. | |
| """ | |
| def __init__( | |
| self, | |
| checkpoint_path: str, | |
| gemma_root: str, | |
| dtype: torch.dtype, | |
| device: torch.device, | |
| registry: Registry | None = None, | |
| warm: bool = False, | |
| use_bnb_4bit: bool = False, | |
| audio_only: bool = False, | |
| ) -> None: | |
| self._dtype = dtype | |
| self._device = device | |
| self._warm = warm | |
| self._use_bnb_4bit = use_bnb_4bit | |
| self._audio_only = audio_only | |
| module_ops = module_ops_from_gemma_root(gemma_root) | |
| model_folder = find_matching_file(gemma_root, "model*.safetensors").parent | |
| weight_paths = [str(p) for p in model_folder.rglob("*.safetensors")] | |
| self._gemma_root = gemma_root | |
| self._text_encoder_builder = Builder( | |
| model_path=tuple(weight_paths), | |
| model_class_configurator=GemmaTextEncoderConfigurator, | |
| model_sd_ops=GEMMA_LLM_KEY_OPS, | |
| module_ops=(GEMMA_MODEL_OPS, *module_ops), | |
| registry=registry or DummyRegistry(), | |
| ) | |
| self._embeddings_processor_builder = Builder( | |
| model_path=checkpoint_path, | |
| model_class_configurator=EmbeddingsProcessorConfigurator, | |
| model_sd_ops=EMBEDDINGS_PROCESSOR_KEY_OPS, | |
| registry=registry or DummyRegistry(), | |
| ) | |
| # Warm mode: build and keep models on GPU | |
| self._warm_text_encoder = None | |
| self._warm_embeddings_processor = None | |
| if warm: | |
| if use_bnb_4bit: | |
| self._warm_text_encoder = self._load_bnb_4bit_encoder(gemma_root) | |
| else: | |
| self._warm_text_encoder = self._text_encoder_builder.build( | |
| device=self._device, dtype=self._dtype | |
| ).eval() | |
| built_ep = self._embeddings_processor_builder.build( | |
| device=self._device, dtype=self._dtype | |
| ) | |
| # Audio-only mode: delete video components BEFORE .to(device). | |
| # This both frees ~4.8GB VRAM at load time and lets us strip | |
| # text_embedding_projection.video_aggregate_embed.* from the | |
| # checkpoint on disk (otherwise those tensors stay on the meta | |
| # device and .to(device) errors with "cannot copy out of meta"). | |
| if audio_only: | |
| import logging as _log | |
| ep = built_ep | |
| freed = 0 | |
| # 1. Replace video_connector with None and patch create_embeddings | |
| if ep.video_connector is not None: | |
| try: | |
| freed += sum(p.numel() * p.element_size() for p in ep.video_connector.parameters() if not p.is_meta) | |
| except Exception: | |
| pass | |
| del ep.video_connector | |
| ep.video_connector = None | |
| # 2. Replace video_aggregate_embed with a dummy that returns zeros | |
| fe = ep.feature_extractor | |
| if hasattr(fe, 'video_aggregate_embed') and fe.video_aggregate_embed is not None: | |
| try: | |
| freed += sum(p.numel() * p.element_size() for p in fe.video_aggregate_embed.parameters() if not p.is_meta) | |
| except Exception: | |
| pass | |
| out_features = fe.video_aggregate_embed.out_features | |
| del fe.video_aggregate_embed | |
| # Dummy that returns zeros with correct shape | |
| class _DummyVideoEmbed(torch.nn.Module): | |
| def __init__(self, out_f): | |
| super().__init__() | |
| self.out_features = out_f | |
| def forward(self, x): | |
| return torch.zeros(x.shape[0], x.shape[1], self.out_features, | |
| device=x.device, dtype=x.dtype) | |
| fe.video_aggregate_embed = _DummyVideoEmbed(out_features) | |
| # Now move the (post-strip) module onto the target device. | |
| self._warm_embeddings_processor = built_ep.to(self._device).eval() | |
| if audio_only and self._warm_embeddings_processor is not None: | |
| ep = self._warm_embeddings_processor | |
| # 3. Patch create_embeddings to skip video connector | |
| _orig_create = ep.create_embeddings | |
| def _audio_only_create(video_features, audio_features, additive_attention_mask, | |
| _ep=ep): | |
| # Skip video connector entirely — only run audio connector | |
| # Create binary mask from additive mask | |
| # additive_attention_mask: [B, 1, seq, seq] or [B, 1, 1, seq] | |
| m = additive_attention_mask | |
| while m.dim() > 2: | |
| m = m[:, 0] | |
| # m is now [B, seq] — binary: 0 = attend, -inf = mask | |
| binary_mask = (m >= -1.0).to(torch.int64) | |
| audio_encoded = None | |
| if _ep.audio_connector is not None: | |
| audio_encoded, _ = _ep.audio_connector(audio_features, additive_attention_mask) | |
| return video_features, audio_encoded, binary_mask | |
| ep.create_embeddings = _audio_only_create | |
| torch.cuda.empty_cache() | |
| import gc; gc.collect() | |
| _log.info(f"Audio-only mode: freed video components, saved {freed/1e9:.1f}GB VRAM") | |
| def _load_bnb_4bit_encoder(self, gemma_root: str): | |
| """Load Gemma with bitsandbytes 4-bit quantization for reduced VRAM. | |
| Auto-detects whether the checkpoint at ``gemma_root`` is already | |
| pre-quantized (has ``quantization_config`` in ``config.json``) — in | |
| that case we skip our explicit ``BitsAndBytesConfig`` and let | |
| transformers honour the checkpoint's own quantization metadata. | |
| Passing our own config on top of a pre-quantized checkpoint causes | |
| shape mismatches when transformers tries to quantize already-packed | |
| 4-bit weights a second time. | |
| """ | |
| import json | |
| import logging | |
| import os | |
| from transformers import Gemma3ForConditionalGeneration, BitsAndBytesConfig | |
| from ltx_core.text_encoders.gemma.tokenizer import LTXVGemmaTokenizer | |
| from ltx_core.text_encoders.gemma.encoders.base_encoder import GemmaTextEncoder | |
| # Inspect config.json for an existing quantization_config. | |
| prequantized = False | |
| cfg_path = os.path.join(gemma_root, "config.json") | |
| if os.path.exists(cfg_path): | |
| try: | |
| with open(cfg_path) as f: | |
| cfg = json.load(f) | |
| prequantized = "quantization_config" in cfg | |
| except Exception: | |
| pass | |
| from_kwargs = { | |
| "device_map": str(self._device), | |
| "torch_dtype": self._dtype, | |
| } | |
| if prequantized: | |
| logging.info( | |
| "Loading pre-quantized Gemma (bnb-4bit) — using checkpoint's own quantization_config" | |
| ) | |
| else: | |
| logging.info("Loading Gemma with runtime bitsandbytes 4-bit quantization...") | |
| from_kwargs["quantization_config"] = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_quant_type="nf4", | |
| bnb_4bit_compute_dtype=self._dtype, | |
| ) | |
| hf_model = Gemma3ForConditionalGeneration.from_pretrained(gemma_root, **from_kwargs) | |
| tokenizer = LTXVGemmaTokenizer( | |
| str(find_matching_file(gemma_root, "tokenizer.model").parent), 1024 | |
| ) | |
| encoder = GemmaTextEncoder(model=hf_model, tokenizer=tokenizer, dtype=self._dtype) | |
| mem_gb = torch.cuda.memory_allocated(self._device) / 1e9 | |
| logging.info(f"Gemma 4-bit loaded: {mem_gb:.1f}GB VRAM") | |
| return encoder | |
| def _text_encoder_ctx( | |
| self, | |
| streaming_prefetch_count: int | None, | |
| ) -> AbstractContextManager: | |
| if streaming_prefetch_count is not None: | |
| return _streaming_model( | |
| self._text_encoder_builder.build(device=torch.device("cpu"), dtype=self._dtype).eval(), | |
| layers_attr="model.model.language_model.layers", | |
| target_device=self._device, | |
| prefetch_count=streaming_prefetch_count, | |
| ) | |
| return gpu_model(self._text_encoder_builder.build(device=self._device, dtype=self._dtype).eval()) | |
| def _noop_ctx(self, model): | |
| """Context manager that yields model without freeing it.""" | |
| yield model | |
| def __call__( | |
| self, | |
| prompts: list[str], | |
| *, | |
| enhance_first_prompt: bool = False, | |
| enhance_prompt_image: str | None = None, | |
| enhance_prompt_seed: int = 42, | |
| streaming_prefetch_count: int | None = None, | |
| ) -> list[EmbeddingsProcessorOutput]: | |
| """Encode *prompts* through Gemma → embeddings processor, freeing each model after use.""" | |
| if self._warm and self._warm_text_encoder is not None: | |
| # Warm path: reuse cached models, no load/free overhead | |
| text_encoder = self._warm_text_encoder | |
| raw_outputs = [text_encoder.encode(p) for p in prompts] | |
| embeddings_processor = self._warm_embeddings_processor | |
| return [embeddings_processor.process_hidden_states(hs, mask) for hs, mask in raw_outputs] | |
| # Cold path: original load-use-free behavior | |
| with self._text_encoder_ctx(streaming_prefetch_count) as text_encoder: | |
| if enhance_first_prompt: | |
| prompts = list(prompts) | |
| prompts[0] = generate_enhanced_prompt( | |
| text_encoder, prompts[0], enhance_prompt_image, seed=enhance_prompt_seed | |
| ) | |
| raw_outputs = [text_encoder.encode(p) for p in prompts] | |
| with gpu_model( | |
| self._embeddings_processor_builder.build(device=self._device, dtype=self._dtype).to(self._device).eval() | |
| ) as embeddings_processor: | |
| return [embeddings_processor.process_hidden_states(hs, mask) for hs, mask in raw_outputs] | |
| # --------------------------------------------------------------------------- | |
| # ImageConditioner | |
| # --------------------------------------------------------------------------- | |
| class ImageConditioner: | |
| """Owns video encoder lifecycle. | |
| Builds the encoder, passes it to the user-supplied callable, then frees it. | |
| """ | |
| def __init__( | |
| self, | |
| checkpoint_path: str, | |
| dtype: torch.dtype, | |
| device: torch.device, | |
| registry: Registry | None = None, | |
| ) -> None: | |
| self._dtype = dtype | |
| self._device = device | |
| self._encoder_builder = Builder( | |
| model_path=checkpoint_path, | |
| model_class_configurator=VideoEncoderConfigurator, | |
| model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER, | |
| registry=registry or DummyRegistry(), | |
| ) | |
| def _build_encoder(self) -> VideoEncoder: | |
| return self._encoder_builder.build(device=self._device, dtype=self._dtype).to(self._device).eval() | |
| def __call__(self, fn: Callable[[VideoEncoder], T]) -> T: | |
| """Build video encoder → call *fn(encoder)* → free encoder.""" | |
| with gpu_model(self._build_encoder()) as encoder: | |
| return fn(encoder) | |
| # --------------------------------------------------------------------------- | |
| # VideoUpsampler | |
| # --------------------------------------------------------------------------- | |
| class VideoUpsampler: | |
| """Owns video encoder + spatial upsampler lifecycle.""" | |
| def __init__( | |
| self, | |
| checkpoint_path: str, | |
| upsampler_path: str, | |
| dtype: torch.dtype, | |
| device: torch.device, | |
| registry: Registry | None = None, | |
| ) -> None: | |
| self._dtype = dtype | |
| self._device = device | |
| self._encoder_builder = Builder( | |
| model_path=checkpoint_path, | |
| model_class_configurator=VideoEncoderConfigurator, | |
| model_sd_ops=VAE_ENCODER_COMFY_KEYS_FILTER, | |
| registry=registry or DummyRegistry(), | |
| ) | |
| self._upsampler_builder = Builder( | |
| model_path=upsampler_path, | |
| model_class_configurator=LatentUpsamplerConfigurator, | |
| registry=registry or DummyRegistry(), | |
| ) | |
| def __call__(self, latent: torch.Tensor) -> torch.Tensor: | |
| """Upsample *latent* using video encoder + spatial upsampler, then free both.""" | |
| with ( | |
| gpu_model( | |
| self._encoder_builder.build(device=self._device, dtype=self._dtype).to(self._device).eval() | |
| ) as encoder, | |
| gpu_model( | |
| self._upsampler_builder.build(device=self._device, dtype=self._dtype).to(self._device).eval() | |
| ) as upsampler, | |
| ): | |
| return upsample_video(latent=latent, video_encoder=encoder, upsampler=upsampler) | |
| # --------------------------------------------------------------------------- | |
| # VideoDecoder | |
| # --------------------------------------------------------------------------- | |
| class VideoDecoder: | |
| """Owns video decoder lifecycle. | |
| Returns an iterator that cleans up the decoder after all chunks are consumed. | |
| """ | |
| def __init__( | |
| self, | |
| checkpoint_path: str, | |
| dtype: torch.dtype, | |
| device: torch.device, | |
| registry: Registry | None = None, | |
| ) -> None: | |
| self._dtype = dtype | |
| self._device = device | |
| self._decoder_builder = Builder( | |
| model_path=checkpoint_path, | |
| model_class_configurator=VideoDecoderConfigurator, | |
| model_sd_ops=VAE_DECODER_COMFY_KEYS_FILTER, | |
| registry=registry or DummyRegistry(), | |
| ) | |
| def __call__( | |
| self, | |
| latent: torch.Tensor, | |
| tiling_config: TilingConfig | None = None, | |
| generator: torch.Generator | None = None, | |
| ) -> Iterator[torch.Tensor]: | |
| """Decode *latent* to pixel-space video chunks. Decoder freed after exhaustion.""" | |
| decoder = self._decoder_builder.build(device=self._device, dtype=self._dtype).to(self._device).eval() | |
| return _cleanup_iter(decoder.decode_video(latent, tiling_config, generator), decoder) | |
| # --------------------------------------------------------------------------- | |
| # AudioDecoder | |
| # --------------------------------------------------------------------------- | |
| class AudioDecoder: | |
| """Owns audio decoder + vocoder lifecycle. With warm=True, keeps models on GPU.""" | |
| def __init__( | |
| self, | |
| checkpoint_path: str, | |
| dtype: torch.dtype, | |
| device: torch.device, | |
| registry: Registry | None = None, | |
| warm: bool = False, | |
| ) -> None: | |
| self._dtype = dtype | |
| self._device = device | |
| self._warm = warm | |
| self._decoder_builder = Builder( | |
| model_path=checkpoint_path, | |
| model_class_configurator=AudioDecoderConfigurator, | |
| model_sd_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER, | |
| registry=registry or DummyRegistry(), | |
| ) | |
| self._vocoder_builder = Builder( | |
| model_path=checkpoint_path, | |
| model_class_configurator=VocoderConfigurator, | |
| model_sd_ops=VOCODER_COMFY_KEYS_FILTER, | |
| registry=registry or DummyRegistry(), | |
| ) | |
| self._warm_decoder = None | |
| self._warm_vocoder = None | |
| if warm: | |
| self._warm_decoder = self._decoder_builder.build(device=device, dtype=dtype).to(device).eval() | |
| self._warm_vocoder = self._vocoder_builder.build(device=device, dtype=dtype).to(device).eval() | |
| def __call__(self, latent: torch.Tensor) -> Audio: | |
| """Decode audio *latent* through VAE decoder + vocoder, then free both.""" | |
| if self._warm and self._warm_decoder is not None: | |
| return vae_decode_audio(latent, self._warm_decoder, self._warm_vocoder) | |
| with ( | |
| gpu_model( | |
| self._decoder_builder.build(device=self._device, dtype=self._dtype).to(self._device).eval() | |
| ) as decoder, | |
| gpu_model( | |
| self._vocoder_builder.build(device=self._device, dtype=self._dtype).to(self._device).eval() | |
| ) as vocoder, | |
| ): | |
| return vae_decode_audio(latent, decoder, vocoder) | |
| # --------------------------------------------------------------------------- | |
| # AudioEncoder | |
| # --------------------------------------------------------------------------- | |
| class AudioConditioner: | |
| """Owns audio encoder lifecycle. With warm=True, keeps encoder on GPU.""" | |
| def __init__( | |
| self, | |
| checkpoint_path: str, | |
| dtype: torch.dtype, | |
| device: torch.device, | |
| registry: Registry | None = None, | |
| warm: bool = False, | |
| ) -> None: | |
| self._dtype = dtype | |
| self._device = device | |
| self._warm = warm | |
| self._encoder_builder = Builder( | |
| model_path=checkpoint_path, | |
| model_class_configurator=AudioEncoderConfigurator, | |
| model_sd_ops=AUDIO_VAE_ENCODER_COMFY_KEYS_FILTER, | |
| registry=registry or DummyRegistry(), | |
| ) | |
| self._warm_encoder = None | |
| if warm: | |
| self._warm_encoder = self._encoder_builder.build(device=device, dtype=dtype).to(device).eval() | |
| def __call__(self, fn: Callable[[torch.nn.Module], T]) -> T: | |
| """Build audio encoder → call *fn(encoder)* → free encoder.""" | |
| if self._warm and self._warm_encoder is not None: | |
| return fn(self._warm_encoder) | |
| with gpu_model( | |
| self._encoder_builder.build(device=self._device, dtype=self._dtype).to(self._device).eval() | |
| ) as encoder: | |
| return fn(encoder) | |