from __future__ import annotations import logging from collections.abc import Iterator import torch from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams from ltx_core.components.noisers import GaussianNoiser from ltx_core.components.schedulers import LTX2Scheduler from ltx_core.conditioning.types.noise_mask_cond import TemporalRegionMask from ltx_core.loader import LoraPathStrengthAndSDOps from ltx_core.loader.registry import Registry from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number from ltx_core.quantization import QuantizationPolicy from ltx_core.types import ( SpatioTemporalScaleFactors, ) from ltx_pipelines.utils.args import video_editing_arg_parser from ltx_pipelines.utils.blocks import ( AudioConditioner, AudioDecoder, DiffusionStage, ImageConditioner, PromptEncoder, VideoDecoder, ) from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, detect_params from ltx_pipelines.utils.denoisers import GuidedDenoiser, SimpleDenoiser from ltx_pipelines.utils.helpers import ( audio_latent_from_file, get_device, video_latent_from_file, ) from ltx_pipelines.utils.media_io import ( encode_video, get_videostream_metadata, ) from ltx_pipelines.utils.types import ModalitySpec class RetakePipeline: """Regenerate a time region (retake) of an existing video. Given a source video file and a time window ``[start_time, end_time]`` (in seconds), this pipeline keeps the video/audio outside that window unchanged and *regenerates* the content inside the window from a text prompt using the LTX-2 diffusion model. Parameters ---------- checkpoint_path : str Path to the LTX-2 model checkpoint. gemma_root : str Root directory containing Gemma text-encoder weights. loras : list[LoraPathStrengthAndSDOps] Optional LoRA configs applied to the transformer. device : torch.device Target device (default: CUDA if available). quantization : QuantizationPolicy | None Optional quantization policy for the transformer. distilled : bool Set to ``True`` if using distilled model or passing distillation lora with full model. If set to ``True``, distilled sigma schedule (``DISTILLED_SIGMA_VALUES``) and a simple (non-guided) denoising function will be used during ``__call__``. """ def __init__( self, checkpoint_path: str, gemma_root: str, loras: list[LoraPathStrengthAndSDOps], device: torch.device | None = None, quantization: QuantizationPolicy | None = None, registry: Registry | None = None, distilled: bool = True, torch_compile: bool = False, ): self.device = device or get_device() self.dtype = torch.bfloat16 self.distilled = distilled self.prompt_encoder = PromptEncoder( checkpoint_path=checkpoint_path, gemma_root=gemma_root, dtype=self.dtype, device=self.device, registry=registry, ) self.image_conditioner = ImageConditioner( checkpoint_path=checkpoint_path, dtype=self.dtype, device=self.device, registry=registry, ) self.audio_conditioner = AudioConditioner( checkpoint_path=checkpoint_path, dtype=self.dtype, device=self.device, registry=registry, ) self.stage = DiffusionStage( checkpoint_path=checkpoint_path, dtype=self.dtype, device=self.device, loras=tuple(loras), quantization=quantization, registry=registry, torch_compile=torch_compile, ) self.video_decoder = VideoDecoder( checkpoint_path=checkpoint_path, dtype=self.dtype, device=self.device, registry=registry, ) self.audio_decoder = AudioDecoder( checkpoint_path=checkpoint_path, dtype=self.dtype, device=self.device, registry=registry, ) # --------------------------------------------------------------------- # # Public entry point # # --------------------------------------------------------------------- # def __call__( # noqa: PLR0913 self, video_path: str, prompt: str, start_time: float, end_time: float, seed: int, *, negative_prompt: str = "", num_inference_steps: int = 40, video_guider_params: MultiModalGuiderParams | None = None, audio_guider_params: MultiModalGuiderParams | None = None, regenerate_video: bool = True, regenerate_audio: bool = True, enhance_prompt: bool = False, tiling_config: TilingConfig | None = None, streaming_prefetch_count: int | None = None, max_batch_size: int = 1, ) -> tuple[Iterator[torch.Tensor], torch.Tensor]: """Regenerate ``[start_time, end_time]`` of the source video (retake). Parameters ---------- video_path : str Path to the source video file (must contain video; audio is optional). prompt : str Text prompt describing the *regenerated* section. start_time, end_time : float Time window (in seconds) of the section to regenerate. seed : int Random seed for reproducibility. negative_prompt : str Negative prompt for CFG guidance (ignored in distilled mode). num_inference_steps : int Number of Euler denoising steps (ignored in distilled mode which uses a fixed 8-step schedule). video_guider_params, audio_guider_params : MultiModalGuiderParams | None Guidance parameters for video and audio modalities. Ignored in distilled mode. regenerate_video : bool If ``True`` (default), regenerate video inside ``[start_time, end_time]``. If ``False``, video is preserved as-is (no regeneration). regenerate_audio : bool If True, regenerate audio in the [start_time, end_time] window; if False, audio is preserved as-is (no regeneration). enhance_prompt : bool Whether to enhance the prompt via the text encoder. Returns ------- tuple[Iterator[torch.Tensor], torch.Tensor] ``(video_frames_iterator, audio_waveform)`` """ if start_time >= end_time: raise ValueError(f"start_time ({start_time}) must be less than end_time ({end_time})") generator = torch.Generator(device=self.device).manual_seed(seed) noiser = GaussianNoiser(generator=generator) dtype = self.dtype output_shape = get_videostream_metadata(video_path) initial_video_latent = self.image_conditioner( lambda enc: video_latent_from_file( video_encoder=enc, file_path=video_path, output_shape=output_shape, dtype=dtype, device=self.device, ) ) initial_audio_latent = self.audio_conditioner( lambda enc: audio_latent_from_file( audio_encoder=enc, file_path=video_path, output_shape=output_shape, dtype=dtype, device=self.device, ) ) prompts_to_encode = [prompt] if self.distilled else [prompt, negative_prompt] contexts = self.prompt_encoder( prompts_to_encode, enhance_first_prompt=enhance_prompt, enhance_prompt_seed=seed, streaming_prefetch_count=streaming_prefetch_count, ) v_context_p, a_context_p = contexts[0].video_encoding, contexts[0].audio_encoding video_modality_spec = ModalitySpec( context=v_context_p, conditionings=[TemporalRegionMask(start_time=start_time, end_time=end_time, fps=output_shape.fps)] if regenerate_video else [], initial_latent=initial_video_latent, frozen=not regenerate_video, ) audio_modality_spec = ModalitySpec( context=a_context_p, conditionings=[TemporalRegionMask(start_time=start_time, end_time=end_time, fps=output_shape.fps)] if (initial_audio_latent is not None and regenerate_audio) else [], initial_latent=initial_audio_latent, frozen=initial_audio_latent is not None and not regenerate_audio, ) # Build denoiser if self.distilled: sigmas = torch.tensor(DISTILLED_SIGMA_VALUES).to(dtype=torch.float32, device=self.device) denoiser = SimpleDenoiser( v_context=v_context_p, a_context=a_context_p, ) else: sigmas = LTX2Scheduler().execute(steps=num_inference_steps).to(dtype=torch.float32, device=self.device) v_context_n, a_context_n = contexts[1].video_encoding, contexts[1].audio_encoding video_guider = MultiModalGuider( params=video_guider_params, negative_context=v_context_n, ) audio_guider = MultiModalGuider( params=audio_guider_params, negative_context=a_context_n, ) denoiser = GuidedDenoiser( v_context=v_context_p, a_context=a_context_p, video_guider=video_guider, audio_guider=audio_guider, ) # Run diffusion stage video_state, audio_state = self.stage( denoiser=denoiser, sigmas=sigmas, noiser=noiser, width=output_shape.width, height=output_shape.height, frames=output_shape.frames, fps=output_shape.fps, video=video_modality_spec, audio=audio_modality_spec, streaming_prefetch_count=streaming_prefetch_count, max_batch_size=max_batch_size, ) # Decode decoded_video = self.video_decoder(video_state.latent, tiling_config, generator) decoded_audio = self.audio_decoder(audio_state.latent) return decoded_video, decoded_audio @torch.inference_mode() def main() -> None: """CLI entry point for retake (regenerate a time region).""" logging.getLogger().setLevel(logging.INFO) parser = video_editing_arg_parser(distilled=True) parser.description = "Retake: regenerate a time region of a video with LTX-2." args = parser.parse_args() if args.start_time >= args.end_time: raise ValueError("start_time must be less than end_time") # Validate frame count (8k+1) and resolution (multiples of 32) at CLI stage video_scale = SpatioTemporalScaleFactors.default() src = get_videostream_metadata(args.video_path) if (src.frames - 1) % video_scale.time != 0: snapped = ((src.frames - 1) // video_scale.time) * video_scale.time + 1 raise ValueError( f"Video frame count must satisfy 8k+1 (e.g. 97, 193). Got {src.frames}; use a video with {snapped} frames." ) if src.width % 32 != 0 or src.height % 32 != 0: raise ValueError(f"Video width and height must be multiples of 32. Got {src.width}x{src.height}.") pipeline = RetakePipeline( checkpoint_path=args.distilled_checkpoint_path, gemma_root=args.gemma_root, loras=tuple(args.lora) if args.lora else (), quantization=args.quantization, distilled=args.distilled, torch_compile=args.compile, ) params = detect_params(args.distilled_checkpoint_path) tiling_config = TilingConfig.default() video_iter, audio = pipeline( video_path=args.video_path, prompt=args.prompt, start_time=args.start_time, end_time=args.end_time, seed=args.seed, video_guider_params=params.video_guider_params, audio_guider_params=params.audio_guider_params, tiling_config=tiling_config, streaming_prefetch_count=args.streaming_prefetch_count, max_batch_size=args.max_batch_size, ) video_chunks_number = get_video_chunks_number(src.frames, tiling_config) encode_video( video=video_iter, fps=int(src.fps), audio=audio, output_path=args.output_path, video_chunks_number=video_chunks_number, ) if __name__ == "__main__": main()