File size: 5,420 Bytes
08c5e28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import logging
from dataclasses import dataclass, field, replace

from safetensors import safe_open

from ltx_core.components.guiders import MultiModalGuiderParams
from ltx_core.types import SpatioTemporalScaleFactors

# =============================================================================
# Diffusion Schedule
# =============================================================================

# Noise schedule for the distilled pipeline. These sigma values control noise
# levels at each denoising step and were tuned to match the distillation process.
# The list has 9 entries, strictly decreasing from 1.0 to 0.0.
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]

# Reduced schedule for super-resolution stage 2 (subset of distilled values).
# It is identical to the last four entries of DISTILLED_SIGMA_VALUES but is
# written out as an independent literal, so the two lists must be kept in sync
# by hand if the distilled schedule is ever retuned.
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]


# =============================================================================
# Pipeline Parameters
# =============================================================================


def _default_guider_params(cfg_scale: float) -> MultiModalGuiderParams:
    """Build the stock guider settings; only ``cfg_scale`` differs per modality.

    A fresh object (and a fresh ``stg_blocks`` list) is created on every call so
    that no two ``PipelineParams`` instances share mutable state.
    """
    return MultiModalGuiderParams(
        cfg_scale=cfg_scale,
        stg_scale=1.0,
        rescale_scale=0.7,
        modality_scale=3.0,
        skip_step=0,
        stg_blocks=[29],
    )


@dataclass(frozen=True)
class PipelineParams:
    """Immutable bundle of default generation settings for the two-stage pipeline.

    Stage-1 dimensions are stored explicitly; stage-2 dimensions are derived
    from them (twice the stage-1 size) through the ``stage_2_*`` properties.
    Because the dataclass is frozen, variants are created with
    ``dataclasses.replace`` rather than by mutating an instance.
    """

    seed: int = 10
    stage_1_height: int = 512
    stage_1_width: int = 768
    num_frames: int = 121
    frame_rate: float = 24.0
    num_inference_steps: int = 40
    # Video and audio guidance share every default except cfg_scale.
    video_guider_params: MultiModalGuiderParams = field(
        default_factory=lambda: _default_guider_params(cfg_scale=3.0),
    )
    audio_guider_params: MultiModalGuiderParams = field(
        default_factory=lambda: _default_guider_params(cfg_scale=7.0),
    )

    @property
    def stage_2_height(self) -> int:
        """Stage-2 height: twice ``stage_1_height``, coerced to ``int``."""
        doubled_height = self.stage_1_height * 2
        return int(doubled_height)

    @property
    def stage_2_width(self) -> int:
        """Stage-2 width: twice ``stage_1_width``, coerced to ``int``."""
        doubled_width = self.stage_1_width * 2
        return int(doubled_width)


# Default params for LTX-2.0 non-distilled models. These can be overridden by detecting from checkpoint metadata.
# Every field takes the PipelineParams class default.
LTX_2_PARAMS = PipelineParams()

# Default params for LTX-2.3 non-distilled models. These override some of the LTX-2.0 defaults.
# Only num_inference_steps (30) and the guiders' stg_blocks ([28]) change; every
# other field is carried over from LTX_2_PARAMS via dataclasses.replace.
LTX_2_3_PARAMS = replace(
    LTX_2_PARAMS,
    num_inference_steps=30,
    video_guider_params=replace(LTX_2_PARAMS.video_guider_params, stg_blocks=[28]),
    audio_guider_params=replace(LTX_2_PARAMS.audio_guider_params, stg_blocks=[28]),
)
# NOTE(review): "HQ" presumably denotes a high-quality LTX-2.3 preset — confirm.
# Built from a fresh PipelineParams rather than derived from LTX_2_3_PARAMS, so
# fields not listed here (seed, num_frames, frame_rate) keep the class defaults.
# Stage 1 is 544x960 (height x width); the stage_2_* properties double that to
# 1088x1920. Both guiders use stg_scale=0.0 with an empty stg_blocks list.
# detect_params() never returns this preset.
LTX_2_3_HQ_PARAMS = PipelineParams(
    num_inference_steps=15,
    stage_1_height=1088 // 2,
    stage_1_width=1920 // 2,
    video_guider_params=MultiModalGuiderParams(
        cfg_scale=3.0,
        stg_scale=0.0,
        rescale_scale=0.45,
        modality_scale=3.0,
        skip_step=0,
        stg_blocks=[],
    ),
    audio_guider_params=MultiModalGuiderParams(
        cfg_scale=7.0,
        stg_scale=0.0,
        rescale_scale=1.0,
        modality_scale=3.0,
        skip_step=0,
        stg_blocks=[],
    ),
)

# Module-wide defaults shared by the pipelines.
# NOTE(review): the exact meaning and units of these values are determined by
# their consumers, which are not visible in this file — confirm before changing.
DEFAULT_LORA_STRENGTH = 1.0
DEFAULT_IMAGE_CRF = 33
# Whatever SpatioTemporalScaleFactors.default() returns, evaluated once at import time.
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
VIDEO_LATENT_CHANNELS = 128

# Prefix that detect_params() matches against the checkpoint's ``model_version``
# metadata entry in order to select LTX_2_3_PARAMS.
_LTX_2_3_MODEL_VERSION_PREFIX = "2.3"


def detect_params(checkpoint_path: str) -> PipelineParams:
    """Detect pipeline params from checkpoint metadata.

    Reads the ``model_version`` entry from the metadata of the safetensors file
    at ``checkpoint_path`` and selects the matching preset.

    Args:
        checkpoint_path: Path to a safetensors checkpoint file.

    Returns:
        ``LTX_2_3_PARAMS`` when the version starts with "2.3"; otherwise
        ``LTX_2_PARAMS``. ``LTX_2_PARAMS`` is also returned when the metadata
        cannot be read or carries no ``model_version`` entry, so an unreadable
        checkpoint does not raise here.
    """
    logger = logging.getLogger(__name__)

    try:
        with safe_open(checkpoint_path, framework="pt") as f:
            # metadata() may return None when the file has no metadata section.
            metadata = f.metadata() or {}
        version = metadata.get("model_version", "")
    except Exception:
        # Detection is deliberately best-effort: any failure falls back to the
        # LTX-2 defaults instead of aborting. The traceback is attached so that
        # a missing file, a corrupt header and an unsupported format can be told
        # apart in the logs rather than all collapsing into the same message.
        logger.warning(
            "Could not read checkpoint metadata from %s, using LTX-2 defaults",
            checkpoint_path,
            exc_info=True,
        )
        return LTX_2_PARAMS

    if version.startswith(_LTX_2_3_MODEL_VERSION_PREFIX):
        return LTX_2_3_PARAMS

    # An empty version means the metadata had no ``model_version`` entry.
    logger.info("Using LTX_2_PARAMS for checkpoint (version=%s)", version or "unknown")
    return LTX_2_PARAMS


# =============================================================================
# Prompts
# =============================================================================

# A single comma-separated string assembled from adjacent string literals; each
# fragment ends with a trailing space so the pieces join cleanly. The list covers
# both visual defects and audio/speech defects.
# NOTE(review): presumably used as the fallback negative prompt when the caller
# supplies none — confirm at the call sites. The text is consumed at runtime, so
# any wording change alters model conditioning.
DEFAULT_NEGATIVE_PROMPT = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)