"""Shared argparse actions and parser factories for the LTX-2 pipeline command-line entry points."""

import argparse
from collections.abc import Callable
from pathlib import Path
from typing import NamedTuple, TypeVar

from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps
from ltx_core.quantization import QuantizationPolicy
from ltx_pipelines.utils.constants import (
    DEFAULT_IMAGE_CRF,
    DEFAULT_LORA_STRENGTH,
    DEFAULT_NEGATIVE_PROMPT,
    LTX_2_3_HQ_PARAMS,
    LTX_2_3_PARAMS,
    PipelineParams,
)

_T = TypeVar("_T")

# Default distilled-LoRA strengths for the two stages of the HQ pipeline parser.
_HQ_DISTILLED_LORA_STRENGTH_STAGE_1 = 0.25
_HQ_DISTILLED_LORA_STRENGTH_STAGE_2 = 0.5


class ImageConditioningInput(NamedTuple):
    """One ``--image`` request: image path, target frame index, strength and H.264 CRF."""

    path: str
    frame_idx: int
    strength: float
    crf: int = DEFAULT_IMAGE_CRF


def _convert_value(action: argparse.Action, raw: str, converter: Callable[[str], _T], label: str) -> _T:
    """Apply *converter* to one raw CLI token, reporting failure as an argparse usage error.

    Custom actions receive raw strings. ``ArgumentParser.parse_args`` only turns
    ``argparse.ArgumentError`` into a clean usage message; a bare ``ValueError`` raised
    from ``Action.__call__`` (e.g. ``float("abc")``) would otherwise escape as a traceback.

    Raises:
        argparse.ArgumentError: If *converter* raises ``ValueError``.
    """
    try:
        return converter(raw)
    except ValueError as exc:
        raise argparse.ArgumentError(action, f"invalid {label} value: {raw!r}") from exc


class VideoConditioningAction(argparse.Action):
    """Append a ``(resolved_path, strength)`` tuple for each ``PATH STRENGTH`` occurrence.

    Expects exactly two values per occurrence; results accumulate in a list on the namespace.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,
    ) -> None:
        if len(values) != 2:
            msg = f"{option_string} requires exactly 2 arguments (PATH STRENGTH), got {len(values)}"
            raise argparse.ArgumentError(self, msg)
        path, strength_str = values
        resolved_path = resolve_path(path)
        strength = _convert_value(self, strength_str, float, "STRENGTH")
        # Copy before appending so a non-empty parser default is never mutated in place.
        current = list(getattr(namespace, self.dest) or [])
        current.append((resolved_path, strength))
        setattr(namespace, self.dest, current)


class VideoMaskConditioningAction(argparse.Action):
    """Parse ``--conditioning-attention-mask PATH STRENGTH``.

    Stores a ``(mask_path, strength)`` tuple on the namespace.
    The mask video should be grayscale with pixel values in [0, 1]
    controlling per-region conditioning attention strength.
    The scalar *STRENGTH* is multiplied with the spatial mask before
    it is applied.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,
    ) -> None:
        if len(values) != 2:
            msg = f"{option_string} requires exactly 2 arguments (MASK_PATH STRENGTH), got {len(values)}"
            raise argparse.ArgumentError(self, msg)
        mask_path = resolve_path(values[0])
        strength = _convert_value(self, values[1], float, "STRENGTH")
        setattr(namespace, self.dest, (mask_path, strength))


class ImageAction(argparse.Action):
    """Append an ``ImageConditioningInput`` for each ``PATH FRAME_IDX STRENGTH [CRF]`` occurrence."""

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,
    ) -> None:
        if len(values) not in (3, 4):
            msg = f"{option_string} requires 3 or 4 arguments (PATH FRAME_IDX STRENGTH [CRF]), got {len(values)}"
            raise argparse.ArgumentError(self, msg)
        conditioning = ImageConditioningInput(
            path=resolve_path(values[0]),
            frame_idx=_convert_value(self, values[1], int, "FRAME_IDX"),
            strength=_convert_value(self, values[2], float, "STRENGTH"),
            crf=_convert_value(self, values[3], int, "CRF") if len(values) > 3 else DEFAULT_IMAGE_CRF,
        )
        # Copy before appending so a non-empty parser default is never mutated in place.
        current = list(getattr(namespace, self.dest) or [])
        current.append(conditioning)
        setattr(namespace, self.dest, current)


class LoraAction(argparse.Action):
    """Append a ``LoraPathStrengthAndSDOps`` for each ``PATH [STRENGTH]`` occurrence.

    STRENGTH defaults to ``DEFAULT_LORA_STRENGTH``; every entry uses
    ``LTXV_LORA_COMFY_RENAMING_MAP`` as its state-dict ops.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,
    ) -> None:
        if len(values) > 2:
            msg = f"{option_string} accepts at most 2 arguments (PATH and optional STRENGTH), got {len(values)} values"
            raise argparse.ArgumentError(self, msg)
        path = values[0]
        strength_str = values[1] if len(values) > 1 else str(DEFAULT_LORA_STRENGTH)
        resolved_path = resolve_path(path)
        strength = _convert_value(self, strength_str, float, "STRENGTH")
        # Copy before appending so a non-empty parser default is never mutated in place.
        current = list(getattr(namespace, self.dest) or [])
        current.append(LoraPathStrengthAndSDOps(resolved_path, strength, LTXV_LORA_COMFY_RENAMING_MAP))
        setattr(namespace, self.dest, current)


def resolve_path(path: str) -> str:
    """Expand ``~`` and return the absolute, symlink-resolved path with forward slashes."""
    return Path(path).expanduser().resolve().as_posix()


QUANTIZATION_POLICIES = ("fp8-cast", "fp8-scaled-mm")


class QuantizationAction(argparse.Action):
    """Parse ``POLICY [AMAX_PATH]`` into a ``QuantizationPolicy`` stored on the namespace.

    ``fp8-cast`` takes no extra argument; ``fp8-scaled-mm`` takes an optional amax file path.
    """

    def __call__(
        self,
        parser: argparse.ArgumentParser,  # noqa: ARG002
        namespace: argparse.Namespace,
        values: list[str],
        option_string: str | None = None,
    ) -> None:
        if len(values) > 2:
            msg = (
                f"{option_string} accepts at most 2 arguments (POLICY and optional AMAX_PATH), got {len(values)} values"
            )
            raise argparse.ArgumentError(self, msg)

        policy_name = values[0]
        if policy_name not in QUANTIZATION_POLICIES:
            msg = f"Unknown quantization policy '{policy_name}'. Choose from: {', '.join(QUANTIZATION_POLICIES)}"
            raise argparse.ArgumentError(self, msg)

        if policy_name == "fp8-cast":
            if len(values) > 1:
                msg = f"{option_string} fp8-cast does not accept additional arguments"
                raise argparse.ArgumentError(self, msg)
            policy = QuantizationPolicy.fp8_cast()
        else:  # "fp8-scaled-mm" -- the only other validated policy
            amax_path = resolve_path(values[1]) if len(values) > 1 else None
            policy = QuantizationPolicy.fp8_scaled_mm(amax_path)

        setattr(namespace, self.dest, policy)


def detect_checkpoint_path(distilled: bool = False) -> str:
    """Pre-parse argv to extract the checkpoint path before building the full parser.

    Reads ``--distilled-checkpoint-path`` when *distilled* is true, otherwise
    ``--checkpoint-path``; all other arguments are ignored.

    NOTE(review): the pre-parser has ``add_help=False`` and marks the flag required, so
    running a script with only ``--help`` exits here with a missing-argument error
    instead of printing the full help.
    """
    pre = argparse.ArgumentParser(add_help=False)
    flag = "--distilled-checkpoint-path" if distilled else "--checkpoint-path"
    pre.add_argument(flag, type=resolve_path, required=True)
    known, _ = pre.parse_known_args()
    return known.distilled_checkpoint_path if distilled else known.checkpoint_path


def _positive_int(value: str) -> int:
    """argparse ``type=`` converter accepting only integers >= 1."""
    try:
        int_value = int(value)
    except ValueError as e:
        raise argparse.ArgumentTypeError(f"must be an integer, got {value}") from e
    if int_value < 1:
        raise argparse.ArgumentTypeError("must be >= 1")
    return int_value


def basic_arg_parser(
    params: PipelineParams = LTX_2_3_PARAMS,
    distilled: bool = False,
) -> argparse.ArgumentParser:
    """Build the parser with the arguments shared by every pipeline.

    Args:
        params: Source of the default values (inference steps, seed).
        distilled: Require ``--distilled-checkpoint-path`` instead of ``--checkpoint-path``.
    """
    parser = argparse.ArgumentParser()
    if distilled:
        parser.add_argument(
            "--distilled-checkpoint-path",
            type=resolve_path,
            required=True,
            help="Path to LTX-2 distilled model checkpoint (.safetensors file).",
        )
    else:
        parser.add_argument(
            "--checkpoint-path",
            type=resolve_path,
            required=True,
            help="Path to LTX-2 model checkpoint (.safetensors file).",
        )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=params.num_inference_steps,
        help=(
            f"Number of denoising steps in the diffusion sampling process. "
            f"Higher values improve quality but increase generation time (default: {params.num_inference_steps})."
        ),
    )
    parser.add_argument(
        "--gemma-root",
        type=resolve_path,
        required=True,
        help="Path to the root directory containing the Gemma text encoder model files.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt describing the desired video content to be generated by the model.",
    )
    parser.add_argument(
        "--output-path",
        type=resolve_path,
        required=True,
        help="Path to the output video file (MP4 format).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=params.seed,
        help=f"Random seed for reproducible generation (default: {params.seed}).",
    )
    parser.add_argument(
        "--lora",
        dest="lora",
        action=LoraAction,
        nargs="+",  # Accept 1-2 arguments per use (path and optional strength); validation is handled in LoraAction
        metavar=("PATH", "STRENGTH"),
        default=[],
        help=(
            "LoRA (Low-Rank Adaptation) model: path to model file and optional strength "
            f"(default strength: {DEFAULT_LORA_STRENGTH}). Can be specified multiple times. "
            "Example: --lora path/to/lora1.safetensors 0.8 --lora path/to/lora2.safetensors"
        ),
    )
    parser.add_argument("--enhance-prompt", action="store_true")

    # Layer streaming
    parser.add_argument(
        "--streaming-prefetch-count",
        type=_positive_int,
        default=None,
        metavar="N",
        help=(
            "Enable layer streaming prefetching N layers ahead. "
            "At most 1 + N layers reside on GPU at once. "
            "Must be >= 1. Example: --streaming-prefetch-count 2"
        ),
    )
    parser.add_argument(
        "--max-batch-size",
        type=_positive_int,
        default=1,
        metavar="N",
        help=(
            "Maximum batch size per transformer forward pass. "
            "Guided denoisers batch up to 4 guidance passes into a single call. "
            "Default 1 runs passes sequentially. Set to 4 to batch all passes "
            "together, which reduces layer-streaming PCIe transfers. "
            "Example: --max-batch-size 4"
        ),
    )
    parser.add_argument(
        "--quantization",
        dest="quantization",
        action=QuantizationAction,
        nargs="+",
        metavar=("POLICY", "AMAX_PATH"),
        default=None,
        help=(
            f"Quantization policy: {', '.join(QUANTIZATION_POLICIES)}. "
            "fp8-cast uses FP8 casting with upcasting during inference. "
            "fp8-scaled-mm uses FP8 scaled matrix multiplication (optionally provide amax calibration file path). "
            "Example: --quantization fp8-cast or --quantization fp8-scaled-mm /path/to/amax.json"
        ),
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable torch.compile for transformer blocks to optimize performance.",
    )
    return parser


def new_video_gen_arg_parser(
    params: PipelineParams = LTX_2_3_PARAMS,
    distilled: bool = False,
) -> argparse.ArgumentParser:
    """Extend ``basic_arg_parser`` with output geometry, frame rate and image conditioning."""
    parser = basic_arg_parser(params=params, distilled=distilled)
    parser.add_argument(
        "--height",
        type=int,
        default=params.stage_1_height,
        help=f"Height of the generated video in pixels, should be divisible by 32 (default: {params.stage_1_height}).",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=params.stage_1_width,
        help=f"Width of the generated video in pixels, should be divisible by 32 (default: {params.stage_1_width}).",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=params.num_frames,
        help=f"Number of frames to generate in the output video sequence, num-frames = (8 x K) + 1, "
        f"where K is a non-negative integer (default: {params.num_frames}).",
    )
    parser.add_argument(
        "--frame-rate",
        type=float,
        default=params.frame_rate,
        help=f"Frame rate of the generated video (fps) (default: {params.frame_rate}).",
    )
    parser.add_argument(
        "--image",
        dest="images",
        action=ImageAction,
        nargs="+",
        metavar="ARG",
        default=[],
        help=(
            "Image conditioning input: PATH FRAME_IDX STRENGTH [CRF]. "
            "PATH is the image file, FRAME_IDX is the target frame index, "
            "STRENGTH is the conditioning strength (all three required). "
            f"CRF is the optional H.264 compression quality (0=lossless, default: {DEFAULT_IMAGE_CRF}). "
            "Can be specified multiple times. Example: --image path/to/image1.jpg 0 0.8 "
            "--image path/to/image2.jpg 160 0.9 0"
        ),
    )
    return parser


def video_editing_arg_parser(
    distilled: bool = True,
) -> argparse.ArgumentParser:
    """Base argument parser for video-editing pipelines (retake, extension, inpainting, sticker movement).

    Uses the same actions and conventions as ``basic_arg_parser`` but adds only the
    arguments needed for editing (no height/width/num-frames; resolution comes from
    the input video). Defaults to the distilled checkpoint (``distilled=True``).
    """
    parser = basic_arg_parser(distilled=distilled)
    parser.add_argument("--video-path", type=resolve_path, required=True, help="Path to the source video.")
    parser.add_argument("--start-time", type=float, required=True, help="Start time of the region to regenerate (s).")
    parser.add_argument("--end-time", type=float, required=True, help="End time of the region to regenerate (s).")
    return parser


def default_1_stage_arg_parser(params: PipelineParams = LTX_2_3_PARAMS) -> argparse.ArgumentParser:
    """Extend ``new_video_gen_arg_parser`` with negative prompt and video/audio guidance controls.

    Guidance defaults come from ``params.video_guider_params`` and ``params.audio_guider_params``.
    """
    video_guider = params.video_guider_params
    audio_guider = params.audio_guider_params
    parser = new_video_gen_arg_parser(params=params)
    parser.add_argument(
        "--negative-prompt",
        type=str,
        default=DEFAULT_NEGATIVE_PROMPT,
        help=(
            "Negative prompt describing what should not appear in the generated video, "
            "used to guide the diffusion process away from unwanted content. "
            "Default: a comprehensive negative prompt covering common artifacts and quality issues."
        ),
    )
    parser.add_argument(
        "--video-cfg-guidance-scale",
        type=float,
        default=video_guider.cfg_scale,
        help=(
            f"Classifier-free guidance (CFG) scale controlling how strongly "
            f"the model adheres to the video prompt. Higher values increase prompt "
            f"adherence but may reduce diversity. 1.0 means no effect "
            f"(default: {video_guider.cfg_scale})."
        ),
    )
    parser.add_argument(
        "--video-stg-guidance-scale",
        type=float,
        default=video_guider.stg_scale,
        help=(
            f"STG (Spatio-Temporal Guidance) scale controlling how strongly "
            f"the model reacts to the perturbation of the video modality. Higher values increase "
            f"the effect but may reduce quality. 0.0 means no effect "
            f"(default: {video_guider.stg_scale})."
        ),
    )
    parser.add_argument(
        "--video-rescale-scale",
        type=float,
        default=video_guider.rescale_scale,
        help=(
            f"Rescale scale controlling how strongly "
            f"the model rescales the video modality after applying other guidance. Higher values tend to decrease "
            f"oversaturation effects. 0.0 means no effect (default: {video_guider.rescale_scale})."
        ),
    )
    parser.add_argument(
        "--video-stg-blocks",
        type=int,
        nargs="*",
        default=video_guider.stg_blocks,
        help=(f"Which transformer blocks to perturb for STG. Default: {video_guider.stg_blocks}."),
    )
    parser.add_argument(
        "--a2v-guidance-scale",
        type=float,
        default=video_guider.modality_scale,
        help=(
            f"A2V (Audio-to-Video) guidance scale controlling how strongly "
            f"the model reacts to the perturbation of the audio-to-video cross-attention. Higher values may increase "
            f"lipsync quality. 1.0 means no effect (default: {video_guider.modality_scale})."
        ),
    )
    parser.add_argument(
        "--video-skip-step",
        type=int,
        default=video_guider.skip_step,
        help=(
            "Video skip step N controls periodic skipping during the video diffusion process: "
            # argparse %-formats help strings, so a literal percent sign must be written as "%%".
            "only steps where step_index %% (N + 1) == 0 are processed, all others are skipped "
            f"(e.g., 0 = no skipping; 1 = skip every other step; 2 = skip 2 of every 3 steps; "
            f"default: {video_guider.skip_step})."
        ),
    )
    parser.add_argument(
        "--audio-cfg-guidance-scale",
        type=float,
        default=audio_guider.cfg_scale,
        help=(
            f"Audio CFG (Classifier-free guidance) scale controlling how strongly "
            f"the model adheres to the audio prompt. Higher values increase prompt "
            f"adherence but may reduce diversity. 1.0 means no effect "
            f"(default: {audio_guider.cfg_scale})."
        ),
    )
    parser.add_argument(
        "--audio-stg-guidance-scale",
        type=float,
        default=audio_guider.stg_scale,
        help=(
            f"Audio STG (Spatio-Temporal Guidance) scale controlling how strongly "
            f"the model reacts to the perturbation of the audio modality. Higher values increase "
            f"the effect but may reduce quality. 0.0 means no effect "
            f"(default: {audio_guider.stg_scale})."
        ),
    )
    parser.add_argument(
        "--audio-rescale-scale",
        type=float,
        default=audio_guider.rescale_scale,
        help=(
            f"Audio rescale scale controlling how strongly "
            f"the model rescales the audio modality after applying other guidance. "
            f"Experimental. 0.0 means no effect (default: {audio_guider.rescale_scale})."
        ),
    )
    parser.add_argument(
        "--audio-stg-blocks",
        type=int,
        nargs="*",
        default=audio_guider.stg_blocks,
        help=(f"Which transformer blocks to perturb for Audio STG. Default: {audio_guider.stg_blocks}."),
    )
    parser.add_argument(
        "--v2a-guidance-scale",
        type=float,
        default=audio_guider.modality_scale,
        help=(
            f"V2A (Video-to-Audio) guidance scale controlling how strongly "
            f"the model reacts to the perturbation of the video-to-audio cross-attention. Higher values may increase "
            f"lipsync quality. 1.0 means no effect (default: {audio_guider.modality_scale})."
        ),
    )
    parser.add_argument(
        "--audio-skip-step",
        type=int,
        default=audio_guider.skip_step,
        help=(
            "Audio skip step N controls periodic skipping during the audio diffusion process: "
            # argparse %-formats help strings, so a literal percent sign must be written as "%%".
            "only steps where step_index %% (N + 1) == 0 are processed, all others are skipped "
            f"(e.g., 0 = no skipping; 1 = skip every other step; 2 = skip 2 of every 3 steps; "
            f"default: {audio_guider.skip_step})."
        ),
    )
    return parser


def _use_stage_2_resolution_defaults(parser: argparse.ArgumentParser, params: PipelineParams) -> None:
    """Switch ``--height``/``--width`` defaults and help text to the stage-2 resolution."""
    parser.set_defaults(height=params.stage_2_height, width=params.stage_2_width)
    # argparse has no public API to edit an existing argument's help, so patch the Action objects.
    for action in parser._actions:
        if "--height" in action.option_strings:
            action.help = (
                f"Height of the generated video in pixels, should be divisible by 64 "
                f"(default: {params.stage_2_height})."
            )
        if "--width" in action.option_strings:
            action.help = (
                f"Width of the generated video in pixels, should be divisible by 64 (default: {params.stage_2_width})."
            )


def _add_spatial_upsampler_arg(parser: argparse.ArgumentParser) -> None:
    """Add the required ``--spatial-upsampler-path`` argument used by two-stage pipelines."""
    parser.add_argument(
        "--spatial-upsampler-path",
        type=resolve_path,
        required=True,
        help=(
            "Path to the spatial upsampler model used to increase the resolution "
            "of the generated video in the latent space."
        ),
    )


def default_2_stage_arg_parser(params: PipelineParams = LTX_2_3_PARAMS) -> argparse.ArgumentParser:
    """Extend ``default_1_stage_arg_parser`` for two-stage generation.

    Uses stage-2 height/width defaults and adds the required distilled LoRA and
    spatial upsampler arguments.
    """
    parser = default_1_stage_arg_parser(params=params)
    _use_stage_2_resolution_defaults(parser, params)
    parser.add_argument(
        "--distilled-lora",
        dest="distilled_lora",
        action=LoraAction,
        nargs="+",  # Accept 1-2 arguments per use (path and optional strength); validation is handled in LoraAction
        metavar=("PATH", "STRENGTH"),
        required=True,
        help=(
            "Distilled LoRA (Low-Rank Adaptation) model used in the second stage (upscaling and refinement): "
            f"path to model file and optional strength (default strength: {DEFAULT_LORA_STRENGTH}). "
            "The second stage upsamples the video by 2x resolution and refines it using a distilled "
            "denoising schedule (fewer steps, no CFG). The distilled LoRA is specifically trained "
            "for this refinement process to improve quality at higher resolutions. "
            "Example: --distilled-lora path/to/distilled_lora.safetensors 0.8"
        ),
    )
    _add_spatial_upsampler_arg(parser)
    return parser


def hq_2_stage_arg_parser(params: PipelineParams = LTX_2_3_HQ_PARAMS) -> argparse.ArgumentParser:
    """Extend ``default_2_stage_arg_parser`` with per-stage distilled LoRA strengths."""
    parser = default_2_stage_arg_parser(params=params)
    parser.add_argument(
        "--distilled-lora-strength-stage-1",
        type=float,
        default=_HQ_DISTILLED_LORA_STRENGTH_STAGE_1,
        help=(
            f"Strength of the distilled LoRA used in the first stage "
            f"(default: {_HQ_DISTILLED_LORA_STRENGTH_STAGE_1})."
        ),
    )
    parser.add_argument(
        "--distilled-lora-strength-stage-2",
        type=float,
        default=_HQ_DISTILLED_LORA_STRENGTH_STAGE_2,
        help=(
            f"Strength of the distilled LoRA used in the second stage "
            f"(default: {_HQ_DISTILLED_LORA_STRENGTH_STAGE_2})."
        ),
    )
    return parser


def default_2_stage_distilled_arg_parser(params: PipelineParams = LTX_2_3_PARAMS) -> argparse.ArgumentParser:
    """Build the parser for the two-stage pipeline driven by a distilled checkpoint.

    Requires ``--distilled-checkpoint-path``, uses stage-2 height/width defaults and
    adds the spatial upsampler argument.
    """
    parser = new_video_gen_arg_parser(params=params, distilled=True)
    _use_stage_2_resolution_defaults(parser, params)
    _add_spatial_upsampler_arg(parser)
    return parser