Spaces:
Running on Zero
Running on Zero
File size: 5,420 Bytes
08c5e28 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 | import logging
from dataclasses import dataclass, field, replace
from safetensors import safe_open
from ltx_core.components.guiders import MultiModalGuiderParams
from ltx_core.types import SpatioTemporalScaleFactors
# =============================================================================
# Diffusion Schedule
# =============================================================================
# Sigma (noise-level) schedule for the distilled pipeline. The values decrease
# monotonically from 1.0 to 0.0 and were tuned to match the distillation process.
DISTILLED_SIGMA_VALUES = [
    1.0,
    0.99375,
    0.9875,
    0.98125,
    0.975,
    0.909375,
    0.725,
    0.421875,
    0.0,
]
# Reduced schedule for super-resolution stage 2: the last four entries of the
# distilled schedule (0.909375 down to 0.0). Slicing yields an independent list,
# so mutating one schedule never affects the other.
STAGE_2_DISTILLED_SIGMA_VALUES = DISTILLED_SIGMA_VALUES[-4:]
# =============================================================================
# Pipeline Parameters
# =============================================================================


def _default_guider_params(cfg_scale: float) -> MultiModalGuiderParams:
    """Build the baseline guider settings used by ``PipelineParams`` defaults.

    The video and audio defaults differ only in ``cfg_scale``. Every other field
    (stg_scale=1.0, rescale_scale=0.7, modality_scale=3.0, skip_step=0,
    stg_blocks=[29]) is identical for both modalities.
    """
    return MultiModalGuiderParams(
        cfg_scale=cfg_scale,
        stg_scale=1.0,
        rescale_scale=0.7,
        modality_scale=3.0,
        skip_step=0,
        stg_blocks=[29],
    )


@dataclass(frozen=True)
class PipelineParams:
    """Frozen bundle of generation settings for one pipeline run.

    Stage-2 dimensions are not stored. They are derived on access as twice the
    stage-1 dimensions (see ``stage_2_height`` / ``stage_2_width``).
    """

    seed: int = 10
    # Stage-1 resolution; the stage-2 properties below double both values.
    stage_1_height: int = 512
    stage_1_width: int = 768
    num_frames: int = 121
    frame_rate: float = 24.0
    num_inference_steps: int = 40
    # default_factory runs once per instance, so each PipelineParams gets its own
    # guider objects (and its own stg_blocks list).
    video_guider_params: MultiModalGuiderParams = field(
        default_factory=lambda: _default_guider_params(cfg_scale=3.0)
    )
    audio_guider_params: MultiModalGuiderParams = field(
        default_factory=lambda: _default_guider_params(cfg_scale=7.0)
    )

    @property
    def stage_2_height(self) -> int:
        """Height used by stage 2: ``stage_1_height`` doubled, as an int."""
        doubled = self.stage_1_height * 2
        return int(doubled)

    @property
    def stage_2_width(self) -> int:
        """Width used by stage 2: ``stage_1_width`` doubled, as an int."""
        doubled = self.stage_1_width * 2
        return int(doubled)
# Baseline defaults for LTX-2.0 non-distilled checkpoints. A different preset may
# be selected by detecting the model version from checkpoint metadata.
LTX_2_PARAMS = PipelineParams()

# LTX-2.3 non-distilled defaults. These are the LTX-2.0 values with two
# changes: 30 inference steps, and both the video and audio guiders using
# stg_blocks=[28] instead of [29]. Each guider receives its own new list.
LTX_2_3_PARAMS = replace(
    LTX_2_PARAMS,
    num_inference_steps=30,
    **{
        guider_name: replace(getattr(LTX_2_PARAMS, guider_name), stg_blocks=[28])
        for guider_name in ("video_guider_params", "audio_guider_params")
    },
)
# "HQ" preset for LTX-2.3 (intended use inferred from the name only; confirm):
# - 15 inference steps.
# - Stage 1 at 544x960, which makes stage 2 1088x1920 via the 2x properties.
# - stg_scale=0.0 and an empty stg_blocks list for both guiders.
# - Per-modality rescale_scale: 0.45 for video, 1.0 for audio.
LTX_2_3_HQ_PARAMS = PipelineParams(
    stage_1_height=1088 // 2,
    stage_1_width=1920 // 2,
    num_inference_steps=15,
    video_guider_params=MultiModalGuiderParams(
        cfg_scale=3.0,
        modality_scale=3.0,
        rescale_scale=0.45,
        stg_scale=0.0,
        stg_blocks=[],
        skip_step=0,
    ),
    audio_guider_params=MultiModalGuiderParams(
        cfg_scale=7.0,
        modality_scale=3.0,
        rescale_scale=1.0,
        stg_scale=0.0,
        stg_blocks=[],
        skip_step=0,
    ),
)
# Default strength for LoRA weights when the caller does not supply one
# (presumably a multiplier on the LoRA delta; confirm at the use site).
DEFAULT_LORA_STRENGTH: float = 1.0
# Default CRF for image handling (presumably the constant-rate-factor used when
# encoding conditioning images; confirm at the use site).
DEFAULT_IMAGE_CRF: int = 33
# Channel count of the video latent tensor (presumably the VAE latent channels;
# confirm against the model definition).
VIDEO_LATENT_CHANNELS: int = 128
# Default spatial/temporal scale factors from ltx_core (presumably the
# pixel-to-latent downsampling ratios; confirm against ltx_core).
VIDEO_SCALE_FACTORS: SpatioTemporalScaleFactors = SpatioTemporalScaleFactors.default()
# Any checkpoint whose metadata "model_version" starts with this prefix gets the
# LTX-2.3 defaults.
_LTX_2_3_MODEL_VERSION_PREFIX = "2.3"


def detect_params(checkpoint_path: str) -> PipelineParams:
    """Detect pipeline params from checkpoint metadata.

    Opens ``checkpoint_path`` with ``safetensors.safe_open`` and reads the
    ``model_version`` key from the file's metadata header.

    Args:
        checkpoint_path: Path to a safetensors checkpoint file.

    Returns:
        ``LTX_2_3_PARAMS`` when ``model_version`` starts with "2.3". Otherwise
        ``LTX_2_PARAMS``, including when the key is absent or the metadata cannot
        be read at all. The unreadable case is logged as a warning that includes
        the underlying error.
    """
    logger = logging.getLogger(__name__)
    try:
        with safe_open(checkpoint_path, framework="pt") as f:
            # metadata() may return None when the header has no metadata section.
            metadata = f.metadata() or {}
    except Exception as exc:
        # Detection is best-effort: fall back to defaults instead of failing the
        # pipeline. Keep the reason, though, so that a missing file, a corrupt
        # header and a permissions problem can be told apart in the logs.
        logger.warning(
            "Could not read checkpoint metadata from %s, using LTX-2 defaults (%s: %s)",
            checkpoint_path,
            type(exc).__name__,
            exc,
        )
        return LTX_2_PARAMS

    version = metadata.get("model_version", "")
    if version.startswith(_LTX_2_3_MODEL_VERSION_PREFIX):
        logger.info("Using LTX_2_3_PARAMS for checkpoint (version=%s)", version)
        return LTX_2_3_PARAMS
    logger.info("Using LTX_2_PARAMS for checkpoint (version=%s)", version or "unknown")
    return LTX_2_PARAMS
# =============================================================================
# Prompts
# =============================================================================
# Default negative prompt: one comma-separated line listing undesired visual and
# audio traits. The literal is split at phrase boundaries purely for source
# readability; adjacent string literals concatenate into a single string.
DEFAULT_NEGATIVE_PROMPT = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, "
    "excessive noise, grainy texture, poor lighting, flickering, motion blur, "
    "distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, "
    "missing facial features, extra limbs, disfigured hands, wrong hand count, "
    "artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, "
    "background too sharp, background clutter, distracting reflections, harsh shadows, "
    "inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, "
    "unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, "
    "exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, "
    "distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, "
    "added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, "
    "unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, "
    "cinematic oversaturation, stylized filters, or AI artifacts."
)
|