Manmay's picture
DramaBox Space — initial app + vendored ltx2
08c5e28 verified
import argparse
from pathlib import Path
from typing import NamedTuple
from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps
from ltx_core.quantization import QuantizationPolicy
from ltx_pipelines.utils.constants import (
DEFAULT_IMAGE_CRF,
DEFAULT_LORA_STRENGTH,
DEFAULT_NEGATIVE_PROMPT,
LTX_2_3_HQ_PARAMS,
LTX_2_3_PARAMS,
PipelineParams,
)
class ImageConditioningInput(NamedTuple):
    """One ``--image`` conditioning entry.

    The fields follow the order of the command-line values accepted by
    ``ImageAction``: ``PATH FRAME_IDX STRENGTH [CRF]``.
    """

    # Path of the conditioning image file.
    path: str
    # Index of the target frame this image applies to.
    frame_idx: int
    # Conditioning strength.
    strength: float
    # Compression quality for the image; DEFAULT_IMAGE_CRF when not given.
    crf: int = DEFAULT_IMAGE_CRF
class VideoConditioningAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser, # noqa: ARG002
namespace: argparse.Namespace,
values: list[str],
option_string: str | None = None, # noqa: ARG002
) -> None:
path, strength_str = values
resolved_path = resolve_path(path)
strength = float(strength_str)
current = getattr(namespace, self.dest) or []
current.append((resolved_path, strength))
setattr(namespace, self.dest, current)
class VideoMaskConditioningAction(argparse.Action):
"""Parse ``--conditioning-attention-mask PATH STRENGTH``.
Stores a ``(mask_path, strength)`` tuple on the namespace. The mask video
should be grayscale with pixel values in [0, 1] controlling per-region
conditioning attention strength. The scalar *STRENGTH* is multiplied with
the spatial mask before it is applied.
"""
def __call__(
self,
parser: argparse.ArgumentParser, # noqa: ARG002
namespace: argparse.Namespace,
values: list[str],
option_string: str | None = None,
) -> None:
if len(values) != 2:
msg = f"{option_string} requires exactly 2 arguments (MASK_PATH STRENGTH), got {len(values)}"
raise argparse.ArgumentError(self, msg)
mask_path = resolve_path(values[0])
strength = float(values[1])
setattr(namespace, self.dest, (mask_path, strength))
class ImageAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser, # noqa: ARG002
namespace: argparse.Namespace,
values: list[str],
option_string: str | None = None,
) -> None:
if len(values) not in (3, 4):
msg = f"{option_string} requires 3 or 4 arguments (PATH FRAME_IDX STRENGTH [CRF]), got {len(values)}"
raise argparse.ArgumentError(self, msg)
conditioning = ImageConditioningInput(
path=resolve_path(values[0]),
frame_idx=int(values[1]),
strength=float(values[2]),
crf=int(values[3]) if len(values) > 3 else DEFAULT_IMAGE_CRF,
)
current = getattr(namespace, self.dest) or []
current.append(conditioning)
setattr(namespace, self.dest, current)
class LoraAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser, # noqa: ARG002
namespace: argparse.Namespace,
values: list[str],
option_string: str | None = None,
) -> None:
if len(values) > 2:
msg = f"{option_string} accepts at most 2 arguments (PATH and optional STRENGTH), got {len(values)} values"
raise argparse.ArgumentError(self, msg)
path = values[0]
strength_str = values[1] if len(values) > 1 else str(DEFAULT_LORA_STRENGTH)
resolved_path = resolve_path(path)
strength = float(strength_str)
current = getattr(namespace, self.dest) or []
current.append(LoraPathStrengthAndSDOps(resolved_path, strength, LTXV_LORA_COMFY_RENAMING_MAP))
setattr(namespace, self.dest, current)
def resolve_path(path: str) -> str:
    """Return ``path`` as an absolute, POSIX-style (forward-slash) string.

    A leading ``~`` is expanded first, then the path is made absolute and
    normalized with ``Path.resolve``. The path does not need to exist.
    """
    absolute = Path(path).expanduser().resolve()
    return absolute.as_posix()
QUANTIZATION_POLICIES = ("fp8-cast", "fp8-scaled-mm")
class QuantizationAction(argparse.Action):
def __call__(
self,
parser: argparse.ArgumentParser, # noqa: ARG002
namespace: argparse.Namespace,
values: list[str],
option_string: str | None = None,
) -> None:
if len(values) > 2:
msg = (
f"{option_string} accepts at most 2 arguments (POLICY and optional AMAX_PATH), got {len(values)} values"
)
raise argparse.ArgumentError(self, msg)
policy_name = values[0]
if policy_name not in QUANTIZATION_POLICIES:
msg = f"Unknown quantization policy '{policy_name}'. Choose from: {', '.join(QUANTIZATION_POLICIES)}"
raise argparse.ArgumentError(self, msg)
if policy_name == "fp8-cast":
if len(values) > 1:
msg = f"{option_string} fp8-cast does not accept additional arguments"
raise argparse.ArgumentError(self, msg)
policy = QuantizationPolicy.fp8_cast()
elif policy_name == "fp8-scaled-mm":
amax_path = resolve_path(values[1]) if len(values) > 1 else None
policy = QuantizationPolicy.fp8_scaled_mm(amax_path)
setattr(namespace, self.dest, policy)
def detect_checkpoint_path(distilled: bool = False) -> str:
    """Pre-parse argv to extract the checkpoint path before building the full parser.

    Only the checkpoint flag is declared (``--distilled-checkpoint-path`` when
    ``distilled`` is true, otherwise ``--checkpoint-path``); every other
    argument on the command line is ignored via ``parse_known_args``. The flag
    is required, so argparse exits with a usage error when it is missing.
    """
    if distilled:
        flag, dest = "--distilled-checkpoint-path", "distilled_checkpoint_path"
    else:
        flag, dest = "--checkpoint-path", "checkpoint_path"

    # add_help=False keeps -h/--help for the full parser built later.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument(flag, type=resolve_path, required=True)
    namespace, _unparsed = pre_parser.parse_known_args()
    return getattr(namespace, dest)
def basic_arg_parser(
    params: PipelineParams = LTX_2_3_PARAMS,
    distilled: bool = False,
) -> argparse.ArgumentParser:
    """Build the argument parser shared by all pipelines.

    Args:
        params: Supplies the defaults for ``--num-inference-steps`` and ``--seed``.
        distilled: When true the required checkpoint flag is
            ``--distilled-checkpoint-path``; otherwise ``--checkpoint-path``.

    Returns:
        A parser with the checkpoint, sampling, text-encoder, prompt, output,
        LoRA, layer-streaming, batching, quantization and compile options.
    """

    def _positive_int(value: str) -> int:
        # argparse ``type`` callable: accept only integers >= 1.
        try:
            parsed = int(value)
        except ValueError as e:
            raise argparse.ArgumentTypeError(f"must be an integer, got {value}") from e
        if parsed < 1:
            raise argparse.ArgumentTypeError("must be >= 1")
        return parsed

    parser = argparse.ArgumentParser()

    # Exactly one checkpoint flag is registered, chosen by ``distilled``.
    if distilled:
        checkpoint_flag = "--distilled-checkpoint-path"
        checkpoint_help = "Path to LTX-2 distilled model checkpoint (.safetensors file)."
    else:
        checkpoint_flag = "--checkpoint-path"
        checkpoint_help = "Path to LTX-2 model checkpoint (.safetensors file)."
    parser.add_argument(checkpoint_flag, type=resolve_path, required=True, help=checkpoint_help)

    steps_help = (
        "Number of denoising steps in the diffusion sampling process. "
        f"Higher values improve quality but increase generation time (default: {params.num_inference_steps})."
    )
    parser.add_argument(
        "--num-inference-steps",
        type=int,
        default=params.num_inference_steps,
        help=steps_help,
    )

    parser.add_argument(
        "--gemma-root",
        type=resolve_path,
        required=True,
        help="Path to the root directory containing the Gemma text encoder model files.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt describing the desired video content to be generated by the model.",
    )
    parser.add_argument(
        "--output-path",
        type=resolve_path,
        required=True,
        help="Path to the output video file (MP4 format).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=params.seed,
        help=f"Random seed for reproducible generation (default: {params.seed}).",
    )

    lora_help = (
        "LoRA (Low-Rank Adaptation) model: path to model file and optional strength "
        f"(default strength: {DEFAULT_LORA_STRENGTH}). Can be specified multiple times. "
        "Example: --lora path/to/lora1.safetensors 0.8 --lora path/to/lora2.safetensors"
    )
    parser.add_argument(
        "--lora",
        dest="lora",
        action=LoraAction,
        # One or two values per use (path and optional strength); LoraAction validates the count.
        nargs="+",
        metavar=("PATH", "STRENGTH"),
        default=[],
        help=lora_help,
    )

    parser.add_argument("--enhance-prompt", action="store_true")

    # Layer streaming
    prefetch_help = (
        "Enable layer streaming prefetching N layers ahead. "
        "At most 1 + N layers reside on GPU at once. "
        "Must be >= 1. Example: --streaming-prefetch-count 2"
    )
    parser.add_argument(
        "--streaming-prefetch-count",
        type=_positive_int,
        default=None,
        metavar="N",
        help=prefetch_help,
    )

    batch_help = (
        "Maximum batch size per transformer forward pass. "
        "Guided denoisers batch up to 4 guidance passes into a single call. "
        "Default 1 runs passes sequentially. Set to 4 to batch all passes "
        "together, which reduces layer-streaming PCIe transfers. "
        "Example: --max-batch-size 4"
    )
    parser.add_argument(
        "--max-batch-size",
        type=_positive_int,
        default=1,
        metavar="N",
        help=batch_help,
    )

    quantization_help = (
        f"Quantization policy: {', '.join(QUANTIZATION_POLICIES)}. "
        "fp8-cast uses FP8 casting with upcasting during inference. "
        "fp8-scaled-mm uses FP8 scaled matrix multiplication (optionally provide amax calibration file path). "
        "Example: --quantization fp8-cast or --quantization fp8-scaled-mm /path/to/amax.json"
    )
    parser.add_argument(
        "--quantization",
        dest="quantization",
        action=QuantizationAction,
        nargs="+",
        metavar=("POLICY", "AMAX_PATH"),
        default=None,
        help=quantization_help,
    )

    parser.add_argument(
        "--compile",
        action="store_true",
        help="Enable torch.compile for transformer blocks to optimize performance.",
    )
    return parser
def new_video_gen_arg_parser(
    params: PipelineParams = LTX_2_3_PARAMS,
    distilled: bool = False,
) -> argparse.ArgumentParser:
    """Extend ``basic_arg_parser`` with the options for generating a new video.

    Adds ``--height``, ``--width``, ``--num-frames`` and ``--frame-rate`` (with
    defaults read from ``params``) and the repeatable ``--image`` conditioning
    option, which is stored under ``images``.
    """
    parser = basic_arg_parser(params=params, distilled=distilled)

    # (flag, type, default, help) for the plain scalar options, in registration order.
    scalar_options = (
        (
            "--height",
            int,
            params.stage_1_height,
            f"Video height in pixels, divisible by 32 (default: {params.stage_1_height}).",
        ),
        (
            "--width",
            int,
            params.stage_1_width,
            f"Width of the generated video in pixels, should be divisible by 32 (default: {params.stage_1_width}).",
        ),
        (
            "--num-frames",
            int,
            params.num_frames,
            "Number of frames to generate in the output video sequence, num-frames = (8 x K) + 1, "
            f"where k is a non-negative integer (default: {params.num_frames}).",
        ),
        (
            "--frame-rate",
            float,
            params.frame_rate,
            f"Frame rate of the generated video (fps) (default: {params.frame_rate}).",
        ),
    )
    for flag, value_type, default_value, help_text in scalar_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    image_help = (
        "Image conditioning input: PATH FRAME_IDX STRENGTH [CRF]. "
        "PATH is the image file, FRAME_IDX is the target frame index, "
        "STRENGTH is the conditioning strength (all three required). "
        f"CRF is the optional H.264 compression quality (0=lossless, default: {DEFAULT_IMAGE_CRF}). "
        "Can be specified multiple times. Example: --image path/to/image1.jpg 0 0.8 "
        "--image path/to/image2.jpg 160 0.9 0"
    )
    parser.add_argument(
        "--image",
        dest="images",
        action=ImageAction,
        nargs="+",
        metavar="ARG",
        default=[],
        help=image_help,
    )
    return parser
def video_editing_arg_parser(
    distilled: bool = True,
) -> argparse.ArgumentParser:
    """Base argument parser for video-editing pipelines (retake, extension, inpainting, sticker movement).

    Starts from ``basic_arg_parser`` and adds only what editing needs: the
    source video and the time window to regenerate. There are no
    height/width/num-frames options because the resolution comes from the
    input video. The distilled checkpoint flag is used by default.
    """
    parser = basic_arg_parser(distilled=distilled)
    parser.add_argument("--video-path", type=resolve_path, required=True, help="Path to the source video.")
    # Both window boundaries are required floats, given in seconds.
    for flag, boundary in (("--start-time", "Start"), ("--end-time", "End")):
        parser.add_argument(
            flag,
            type=float,
            required=True,
            help=f"{boundary} time of the region to regenerate (s).",
        )
    return parser
def default_1_stage_arg_parser(params: PipelineParams = LTX_2_3_PARAMS) -> argparse.ArgumentParser:
    """Build the parser for the single-stage pipeline.

    Extends ``new_video_gen_arg_parser`` with ``--negative-prompt`` and the
    per-modality guidance options. Video options take their defaults from
    ``params.video_guider_params`` and audio options from
    ``params.audio_guider_params``.

    Args:
        params: Pipeline parameter set providing all defaults.

    Returns:
        The configured parser.
    """
    video_guider = params.video_guider_params
    audio_guider = params.audio_guider_params
    parser = new_video_gen_arg_parser(params=params)
    parser.add_argument(
        "--negative-prompt",
        type=str,
        default=DEFAULT_NEGATIVE_PROMPT,
        help=(
            "Negative prompt describing what should not appear in the generated video, "
            "used to guide the diffusion process away from unwanted content. "
            "Default: a comprehensive negative prompt covering common artifacts and quality issues."
        ),
    )

    # --- Video guidance ---
    parser.add_argument(
        "--video-cfg-guidance-scale",
        type=float,
        default=video_guider.cfg_scale,
        help=(
            "Classifier-free guidance (CFG) scale controlling how strongly "
            "the model adheres to the video prompt. Higher values increase prompt "
            "adherence but may reduce diversity. 1.0 means no effect "
            f"(default: {video_guider.cfg_scale})."
        ),
    )
    parser.add_argument(
        "--video-stg-guidance-scale",
        type=float,
        default=video_guider.stg_scale,
        help=(
            "STG (Spatio-Temporal Guidance) scale controlling how strongly "
            "the model reacts to the perturbation of the video modality. Higher values increase "
            "the effect but may reduce quality. 0.0 means no effect "
            f"(default: {video_guider.stg_scale})."
        ),
    )
    parser.add_argument(
        "--video-rescale-scale",
        type=float,
        default=video_guider.rescale_scale,
        help=(
            "Rescale scale controlling how strongly "
            "the model rescales the video modality after applying other guidance. Higher values tend to decrease "
            f"oversaturation effects. 0.0 means no effect (default: {video_guider.rescale_scale})."
        ),
    )
    parser.add_argument(
        "--video-stg-blocks",
        type=int,
        nargs="*",
        default=video_guider.stg_blocks,
        help=(f"Which transformer blocks to perturb for STG. Default: {video_guider.stg_blocks}."),
    )
    parser.add_argument(
        "--a2v-guidance-scale",
        type=float,
        default=video_guider.modality_scale,
        help=(
            "A2V (Audio-to-Video) guidance scale controlling how strongly "
            "the model reacts to the perturbation of the audio-to-video cross-attention. Higher values may increase "
            f"lipsync quality. 1.0 means no effect (default: {video_guider.modality_scale})."
        ),
    )
    parser.add_argument(
        "--video-skip-step",
        type=int,
        default=video_guider.skip_step,
        help=(
            "Video skip step N controls periodic skipping during the video diffusion process: "
            # argparse %-formats help text, so the literal modulo sign must be written as "%%";
            # a bare "%" makes --help fail with "unsupported format character".
            "only steps where step_index %% (N + 1) == 0 are processed, all others are skipped "
            "(e.g., 0 = no skipping; 1 = skip every other step; 2 = skip 2 of every 3 steps; "
            f"default: {video_guider.skip_step})."
        ),
    )

    # --- Audio guidance ---
    parser.add_argument(
        "--audio-cfg-guidance-scale",
        type=float,
        default=audio_guider.cfg_scale,
        help=(
            "Audio CFG (Classifier-free guidance) scale controlling how strongly "
            "the model adheres to the audio prompt. Higher values increase prompt "
            "adherence but may reduce diversity. 1.0 means no effect "
            f"(default: {audio_guider.cfg_scale})."
        ),
    )
    parser.add_argument(
        "--audio-stg-guidance-scale",
        type=float,
        default=audio_guider.stg_scale,
        help=(
            "Audio STG (Spatio-Temporal Guidance) scale controlling how strongly "
            "the model reacts to the perturbation of the audio modality. Higher values increase "
            "the effect but may reduce quality. 0.0 means no effect "
            f"(default: {audio_guider.stg_scale})."
        ),
    )
    parser.add_argument(
        "--audio-rescale-scale",
        type=float,
        default=audio_guider.rescale_scale,
        help=(
            "Audio rescale scale controlling how strongly "
            "the model rescales the audio modality after applying other guidance. "
            f"Experimental. 0.0 means no effect (default: {audio_guider.rescale_scale})."
        ),
    )
    parser.add_argument(
        "--audio-stg-blocks",
        type=int,
        nargs="*",
        default=audio_guider.stg_blocks,
        help=(f"Which transformer blocks to perturb for Audio STG. Default: {audio_guider.stg_blocks}."),
    )
    parser.add_argument(
        "--v2a-guidance-scale",
        type=float,
        default=audio_guider.modality_scale,
        help=(
            "V2A (Video-to-Audio) guidance scale controlling how strongly "
            "the model reacts to the perturbation of the video-to-audio cross-attention. Higher values may increase "
            f"lipsync quality. 1.0 means no effect (default: {audio_guider.modality_scale})."
        ),
    )
    parser.add_argument(
        "--audio-skip-step",
        type=int,
        default=audio_guider.skip_step,
        help=(
            "Audio skip step N controls periodic skipping during the audio diffusion process: "
            # Escaped "%%" for the same reason as in --video-skip-step.
            "only steps where step_index %% (N + 1) == 0 are processed, all others are skipped "
            "(e.g., 0 = no skipping; 1 = skip every other step; 2 = skip 2 of every 3 steps; "
            f"default: {audio_guider.skip_step})."
        ),
    )
    return parser
def default_2_stage_arg_parser(params: PipelineParams = LTX_2_3_PARAMS) -> argparse.ArgumentParser:
    """Build the parser for the two-stage pipeline.

    Starts from ``default_1_stage_arg_parser``, switches the ``--height`` and
    ``--width`` defaults (and their help text) to the stage-2 values from
    ``params``, and adds the required ``--distilled-lora`` and
    ``--spatial-upsampler-path`` options.
    """
    parser = default_1_stage_arg_parser(params=params)
    parser.set_defaults(height=params.stage_2_height, width=params.stage_2_width)

    # Update help text to reflect 2-stage defaults. argparse has no public way to
    # edit an existing option, hence the walk over parser._actions.
    help_overrides = {
        "--height": (
            "Height of the generated video in pixels, should be divisible by 64 "
            f"(default: {params.stage_2_height})."
        ),
        "--width": (
            f"Width of the generated video in pixels, should be divisible by 64 (default: {params.stage_2_width})."
        ),
    }
    for action in parser._actions:
        for flag, help_text in help_overrides.items():
            if flag in action.option_strings:
                action.help = help_text

    distilled_lora_help = (
        "Distilled LoRA (Low-Rank Adaptation) model used in the second stage (upscaling and refinement): "
        f"path to model file and optional strength (default strength: {DEFAULT_LORA_STRENGTH}). "
        "The second stage upsamples the video by 2x resolution and refines it using a distilled "
        "denoising schedule (fewer steps, no CFG). The distilled LoRA is specifically trained "
        "for this refinement process to improve quality at higher resolutions. "
        "Example: --distilled-lora path/to/distilled_lora.safetensors 0.8"
    )
    parser.add_argument(
        "--distilled-lora",
        dest="distilled_lora",
        action=LoraAction,
        # One or two values per use (path and optional strength); LoraAction validates the count.
        nargs="+",
        metavar=("PATH", "STRENGTH"),
        required=True,
        help=distilled_lora_help,
    )

    upsampler_help = (
        "Path to the spatial upsampler model used to increase the resolution "
        "of the generated video in the latent space."
    )
    parser.add_argument(
        "--spatial-upsampler-path",
        type=resolve_path,
        required=True,
        help=upsampler_help,
    )
    return parser
def hq_2_stage_arg_parser(params: PipelineParams = LTX_2_3_HQ_PARAMS) -> argparse.ArgumentParser:
    """Build the parser for the HQ two-stage pipeline.

    Starts from ``default_2_stage_arg_parser`` and adds one distilled-LoRA
    strength option per stage (defaults 0.25 for stage 1 and 0.5 for stage 2).
    """
    parser = default_2_stage_arg_parser(params=params)
    # (stage number, ordinal used in the help text, default strength)
    stage_strengths = ((1, "first", 0.25), (2, "second", 0.5))
    for stage_number, ordinal, default_strength in stage_strengths:
        parser.add_argument(
            f"--distilled-lora-strength-stage-{stage_number}",
            type=float,
            default=default_strength,
            help=f"Strength of the distilled LoRA used in the {ordinal} stage (default: {default_strength}).",
        )
    return parser
def default_2_stage_distilled_arg_parser(params: PipelineParams = LTX_2_3_PARAMS) -> argparse.ArgumentParser:
    """Build the parser for the distilled two-stage pipeline.

    Starts from ``new_video_gen_arg_parser`` with the distilled checkpoint
    flag, switches the ``--height`` and ``--width`` defaults (and their help
    text) to the stage-2 values from ``params``, and adds the required
    ``--spatial-upsampler-path`` option.
    """
    parser = new_video_gen_arg_parser(params=params, distilled=True)
    parser.set_defaults(height=params.stage_2_height, width=params.stage_2_width)

    # Update help text to reflect 2-stage defaults. argparse has no public way to
    # edit an existing option, hence the walk over parser._actions.
    help_overrides = {
        "--height": (
            "Height of the generated video in pixels, should be divisible by 64 "
            f"(default: {params.stage_2_height})."
        ),
        "--width": (
            f"Width of the generated video in pixels, should be divisible by 64 (default: {params.stage_2_width})."
        ),
    }
    for action in parser._actions:
        for flag, help_text in help_overrides.items():
            if flag in action.option_strings:
                action.help = help_text

    upsampler_help = (
        "Path to the spatial upsampler model used to increase the resolution "
        "of the generated video in the latent space."
    )
    parser.add_argument(
        "--spatial-upsampler-path",
        type=resolve_path,
        required=True,
        help=upsampler_help,
    )
    return parser