dagloop5 committed (verified)
Commit 8d88e86 · Parent(s): 1be38e6

Update app.py

Files changed (1):
  1. app.py +209 -195

app.py CHANGED
@@ -46,38 +46,47 @@ import spaces
  import gradio as gr
  import numpy as np
  from huggingface_hub import hf_hub_download, snapshot_download
- from safetensors.torch import load_file, save_file
- from safetensors import safe_open
- import json
- import requests

  from ltx_core.components.diffusion_steps import EulerDiffusionStep
  from ltx_core.components.noisers import GaussianNoiser
- from ltx_core.components.protocols import DiffusionStepProtocol
- from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
  from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
  from ltx_core.model.upsampler import upsample_video
  from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number, decode_video as vae_decode_video
  from ltx_core.quantization import QuantizationPolicy
- from ltx_core.types import Audio, LatentState, AudioLatentShape, VideoPixelShape
  from ltx_pipelines.distilled import DistilledPipeline
- from ltx_pipelines.utils import ModelLedger, euler_denoising_loop
  from ltx_pipelines.utils.args import ImageConditioningInput
  from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
  from ltx_pipelines.utils.helpers import (
      cleanup_memory,
      combined_image_conditionings,
      denoise_video_only,
-     denoise_audio_video,
-     get_device,
      encode_prompts,
      simple_denoising_func,
  )
  from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
- from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
- from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
-
- from ltx_pipelines.utils.types import PipelineComponents

  # Force-patch xformers attention into the LTX attention module.
  from ltx_core.model.transformer import attention as _attn_mod
@@ -107,35 +116,9 @@ RESOLUTIONS = {
  }


- class LTX23DistilledA2VPipeline:
      """DistilledPipeline with optional audio conditioning."""

-     def __init__(
-         self,
-         distilled_checkpoint_path: str,
-         gemma_root: str,
-         spatial_upsampler_path: str,
-         loras: tuple,
-         quantization: QuantizationPolicy | None = None,
-     ):
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         self.dtype = torch.bfloat16
-
-         self.model_ledger = ModelLedger(
-             dtype=self.dtype,
-             device=self.device,
-             checkpoint_path=distilled_checkpoint_path,
-             spatial_upsampler_path=spatial_upsampler_path,
-             gemma_root_path=gemma_root,
-             loras=loras,
-             quantization=quantization,
-         )
-
-         self.pipeline_components = PipelineComponents(
-             dtype=self.dtype,
-             device=self.device,
-         )
-
      def __call__(
          self,
          prompt: str,
@@ -145,9 +128,24 @@ class LTX23DistilledA2VPipeline:
          num_frames: int,
          frame_rate: float,
          images: list[ImageConditioningInput],
          tiling_config: TilingConfig | None = None,
          enhance_prompt: bool = False,
      ):

          generator = torch.Generator(device=self.device).manual_seed(seed)
          noiser = GaussianNoiser(generator=generator)
@@ -158,18 +156,38 @@ class LTX23DistilledA2VPipeline:
              [prompt],
              self.model_ledger,
              enhance_first_prompt=enhance_prompt,
-             enhance_prompt_image=images[0][0] if len(images) > 0 else None,
          )
          video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding

-         # Stage 1: Initial low resolution video generation.
          video_encoder = self.model_ledger.video_encoder()
          transformer = self.model_ledger.transformer()
-         stage_1_sigmas = torch.Tensor(DISTILLED_SIGMA_VALUES).to(self.device)

-         def denoising_loop(
-             sigmas: torch.Tensor, video_state: LatentState, audio_state: LatentState, stepper: DiffusionStepProtocol
-         ) -> tuple[LatentState, LatentState]:
              return euler_denoising_loop(
                  sigmas=sigmas,
                  video_state=video_state,
@@ -178,15 +196,15 @@ class LTX23DistilledA2VPipeline:
                  denoise_fn=simple_denoising_func(
                      video_context=video_context,
                      audio_context=audio_context,
-                     transformer=transformer, # noqa: F821
                  ),
              )

          stage_1_output_shape = VideoPixelShape(
              batch=1,
              frames=num_frames,
-             width=width,
-             height=height,
              fps=frame_rate,
          )
          stage_1_conditionings = combined_image_conditionings(
@@ -197,8 +215,7 @@ class LTX23DistilledA2VPipeline:
              dtype=dtype,
              device=self.device,
          )
-
-         video_state, audio_state = denoise_audio_video(
              output_shape=stage_1_output_shape,
              conditionings=stage_1_conditionings,
              noiser=noiser,
@@ -208,6 +225,40 @@ class LTX23DistilledA2VPipeline:
              components=self.pipeline_components,
              dtype=dtype,
              device=self.device,
          )

          torch.cuda.synchronize()
@@ -216,12 +267,16 @@ class LTX23DistilledA2VPipeline:
          cleanup_memory()

          decoded_video = vae_decode_video(
-             video_state.latent, self.model_ledger.video_decoder(), tiling_config, generator
          )
-         decoded_audio = vae_decode_audio(
-             audio_state.latent, self.model_ledger.audio_decoder(), self.model_ledger.vocoder()
          )
-         return decoded_video, decoded_audio


  # Model repos
@@ -233,20 +288,11 @@ print("=" * 80)
  print("Downloading LTX-2.3 distilled model + Gemma...")
  print("=" * 80)

- # LoRA cache directory and currently-applied key
- LORA_CACHE_DIR = Path("lora_cache")
- LORA_CACHE_DIR.mkdir(exist_ok=True)
- current_lora_key: str | None = None
-
- PENDING_LORA_KEY: str | None = None
- PENDING_LORA_STATE: dict[str, torch.Tensor] | None = None
- PENDING_LORA_STATUS: str = "No LoRA state prepared yet."
-
  weights_dir = Path("weights")
  weights_dir.mkdir(exist_ok=True)
  checkpoint_path = hf_hub_download(
-     repo_id="TenStrip/LTX2.3-10Eros",
-     filename="10Eros_v1_bf16.safetensors",
      local_dir=str(weights_dir),
      local_dir_use_symlinks=False,
  )
@@ -264,7 +310,7 @@ print("=" * 80)
  pose_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2_3_NSFW_furry_concat_v2.safetensors")
  general_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_reasoning_I2V_V3.safetensors")
  motion_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="motion_helper.safetensors")
- dreamlay_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="DR34ML4Y_LTXXX_PREVIEW_RC1.safetensors") # m15510n4ry, bl0wj0b, d0ubl3_bj, d0gg1e, c0wg1rl
  mself_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="Furry Hyper Masturbation - LTX-2 I2V v1.safetensors") # Hyperfap
  dramatic_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX-2.3 - Orgasm.safetensors") # "[He | She] is having am orgasm." (am or an?)
  fluid_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_CREAMPIE_ANIMATION-V0.1.safetensors") # cum
@@ -274,7 +320,8 @@ voice_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="hentai_voice_ltx2
  realism_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="FurryenhancerLTX2.3V1.215.safetensors")
  transition_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX-2_takerpov_lora_v1.2.safetensors") # takerpov1, taker pov
  physics_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_Better_Physics_PhysLTX.safetensors")
- reasoning_lora_path = hf_hub_download(repo_id="TenStrip/LTX2.3_Distilled_Lora_1.1_Experiments", filename="ltx-2.3-22b-distilled-lora-1.1_fro90_ceil72_condsafe.safetensors")

  print(f"Pose LoRA: {pose_lora_path}")
  print(f"General LoRA: {general_lora_path}")
@@ -290,6 +337,7 @@ print(f"Realism LoRA: {realism_lora_path}")
  print(f"Transition LoRA: {transition_lora_path}")
  print(f"Physics LoRA: {physics_lora_path}")
  print(f"Reasoning LoRA: {reasoning_lora_path}")
  # ----------------------------------------------------------------

  print(f"Checkpoint: {checkpoint_path}")
@@ -307,7 +355,7 @@ pipeline = LTX23DistilledA2VPipeline(
  )
  # ----------------------------------------------------------------

- def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float, voice_strength: float, realism_strength: float, transition_strength: float, physics_strength: float, reasoning_strength: float) -> tuple[str, str]:
      rp = round(float(pose_strength), 2)
      rg = round(float(general_strength), 2)
      rm = round(float(motion_strength), 2)
@@ -322,12 +370,12 @@ def _make_lora_key(pose_strength: float, general_strength: float, motion_strengt
      rt = round(float(transition_strength), 2)
      ry = round(float(physics_strength), 2)
      ri = round(float(reasoning_strength), 2)
-     key_str = f"{pose_lora_path}:{rp}|{general_lora_path}:{rg}|{motion_lora_path}:{rm}|{dreamlay_lora_path}:{rd}|{mself_lora_path}:{rs}|{dramatic_lora_path}:{rr}|{fluid_lora_path}:{rf}|{liquid_lora_path}:{rl}|{demopose_lora_path}:{ro}|{voice_lora_path}:{rv}|{realism_lora_path}:{re}|{transition_lora_path}:{rt}|{physics_lora_path}:{ry}|{reasoning_lora_path}:{ri}"
      key = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
      return key, key_str

-
- def prepare_lora_cache(
      pose_strength: float,
      general_strength: float,
      motion_strength: float,
@@ -342,34 +390,10 @@ def prepare_lora_cache(
      transition_strength: float,
      physics_strength: float,
      reasoning_strength: float,
-     progress=gr.Progress(track_tqdm=True),
  ):
-     """
-     CPU-only step:
-     - checks cache
-     - loads cached fused transformer state_dict, or
-     - builds fused transformer on CPU and saves it
-     The resulting state_dict is stored in memory and can be applied later.
-     """
-     global PENDING_LORA_KEY, PENDING_LORA_STATE, PENDING_LORA_STATUS
-
-     ledger = pipeline.model_ledger
-     key, _ = _make_lora_key(pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength)
-     cache_path = LORA_CACHE_DIR / f"{key}.safetensors"
-
-     progress(0.05, desc="Preparing LoRA state")
-     if cache_path.exists():
-         try:
-             progress(0.20, desc="Loading cached fused state")
-             state = load_file(str(cache_path))
-             PENDING_LORA_KEY = key
-             PENDING_LORA_STATE = state
-             PENDING_LORA_STATUS = f"Loaded cached LoRA state: {cache_path.name}"
-             return PENDING_LORA_STATUS
-         except Exception as e:
-             print(f"[LoRA] Cache load failed: {type(e).__name__}: {e}")
-
-     entries = [
          (pose_lora_path, round(float(pose_strength), 2)),
          (general_lora_path, round(float(general_strength), 2)),
          (motion_lora_path, round(float(motion_strength), 2)),
@@ -384,91 +408,68 @@ def prepare_lora_cache(
          (transition_lora_path, round(float(transition_strength), 2)),
          (physics_lora_path, round(float(physics_strength), 2)),
          (reasoning_lora_path, round(float(reasoning_strength), 2)),
-     ]
-     loras_for_builder = [
-         LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
-         for path, strength in entries
-         if path is not None and float(strength) != 0.0
      ]

-     if not loras_for_builder:
-         PENDING_LORA_KEY = None
-         PENDING_LORA_STATE = None
-         PENDING_LORA_STATUS = "No non-zero LoRA strengths selected; nothing to prepare."
-         return PENDING_LORA_STATUS

-     tmp_ledger = None
-     new_transformer_cpu = None
-     try:
-         progress(0.35, desc="Building fused CPU transformer")
-         tmp_ledger = pipeline.model_ledger.__class__(
-             dtype=ledger.dtype,
-             device=torch.device("cpu"),
-             checkpoint_path=str(checkpoint_path),
-             spatial_upsampler_path=str(spatial_upsampler_path),
-             gemma_root_path=str(gemma_root),
-             loras=tuple(loras_for_builder),
-             quantization=getattr(ledger, "quantization", None),
-         )
-         new_transformer_cpu = tmp_ledger.transformer()

-         progress(0.70, desc="Extracting fused state_dict")
-         state = {
-             k: v.detach().cpu().contiguous()
-             for k, v in new_transformer_cpu.state_dict().items()
-         }
-         save_file(state, str(cache_path))

-         PENDING_LORA_KEY = key
-         PENDING_LORA_STATE = state
-         PENDING_LORA_STATUS = f"Built and cached LoRA state: {cache_path.name}"
-         return PENDING_LORA_STATUS

-     except Exception as e:
-         import traceback
-         print(f"[LoRA] Prepare failed: {type(e).__name__}: {e}")
-         print(traceback.format_exc())
-         PENDING_LORA_KEY = None
-         PENDING_LORA_STATE = None
-         PENDING_LORA_STATUS = f"LoRA prepare failed: {type(e).__name__}: {e}"
-         return PENDING_LORA_STATUS
-
-     finally:
-         try:
-             del new_transformer_cpu
-         except Exception:
-             pass
-         try:
-             del tmp_ledger
-         except Exception:
-             pass
-         gc.collect()
-
-
- def apply_prepared_lora_state_to_pipeline():
-     """
-     Fast step: copy the already prepared CPU state into the live transformer.
-     This is the only part that should remain near generation time.
-     """
-     global current_lora_key, PENDING_LORA_KEY, PENDING_LORA_STATE
-
-     if PENDING_LORA_STATE is None or PENDING_LORA_KEY is None:
-         print("[LoRA] No prepared LoRA state available; skipping.")
-         return False
-
-     if current_lora_key == PENDING_LORA_KEY:
-         print("[LoRA] Prepared LoRA state already active; skipping.")
-         return True
-
-     existing_transformer = _transformer
      with torch.no_grad():
-         missing, unexpected = existing_transformer.load_state_dict(PENDING_LORA_STATE, strict=False)
          if missing or unexpected:
-             print(f"[LoRA] load_state_dict mismatch: missing={len(missing)}, unexpected={len(unexpected)}")

-     current_lora_key = PENDING_LORA_KEY
-     print("[LoRA] Prepared LoRA state applied to the pipeline.")
-     return True

  # ---- REPLACE PRELOAD BLOCK START ----
  # Preload all models for ZeroGPU tensor packing.
@@ -489,6 +490,13 @@ _orig_gemma_embeddings_factory = ledger.gemma_embeddings_processor

  # Call the original factories once to create the cached instances we will serve by default.
  _transformer = _orig_transformer_factory()
  _video_encoder = _orig_video_encoder_factory()
  _video_decoder = _orig_video_decoder_factory()
  _audio_encoder = _orig_audio_encoder_factory()
@@ -559,6 +567,7 @@ def on_highres_toggle(first_image, last_image, high_res):
  def get_gpu_duration(
      first_image,
      last_image,
      prompt: str,
      duration: float,
      gpu_duration: float,
@@ -581,6 +590,7 @@ def get_gpu_duration(
      transition_strength: float = 0.0,
      physics_strength: float = 0.0,
      reasoning_strength: float = 0.0,
      progress=None,
  ):
      return int(gpu_duration)
@@ -590,6 +600,7 @@ def get_gpu_duration(
  def generate_video(
      first_image,
      last_image,
      prompt: str,
      duration: float,
      gpu_duration: float,
@@ -612,6 +623,7 @@ def generate_video(
      transition_strength: float = 0.0,
      physics_strength: float = 0.0,
      reasoning_strength: float = 0.0,
      progress=gr.Progress(track_tqdm=True),
  ):
      try:
@@ -651,8 +663,13 @@ def generate_video(

          log_memory("before pipeline call")

-         apply_prepared_lora_state_to_pipeline()
-
          video, audio = pipeline(
              prompt=prompt,
              seed=current_seed,
@@ -661,6 +678,7 @@ def generate_video(
              num_frames=num_frames,
              frame_rate=frame_rate,
              images=images,
              tiling_config=tiling_config,
              enhance_prompt=enhance_prompt,
          )
@@ -695,6 +713,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
      with gr.Row():
          first_image = gr.Image(label="First Frame (Optional)", type="pil")
          last_image = gr.Image(label="Last Frame (Optional)", type="pil")
      prompt = gr.Textbox(
          label="Prompt",
          info="for best results - make it as elaborate as possible",
@@ -771,15 +790,13 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
          minimum=0.0, maximum=2.0, value=0.0, step=0.01
      )
      reasoning_strength = gr.Slider(
-         label="Distilled strength",
          minimum=0.0, maximum=2.0, value=0.0, step=0.01
      )
-     prepare_lora_btn = gr.Button("Prepare / Load LoRA Cache", variant="secondary")
-     lora_status = gr.Textbox(
-         label="LoRA Cache Status",
-         value="No LoRA state prepared yet.",
-         interactive=False,
-     )

      with gr.Column():
          output_video = gr.Video(label="Generated Video", autoplay=False)
@@ -796,6 +813,7 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
          [
              None,
              "pinkknit.jpg",
              "The camera falls downward through darkness as if dropped into a tunnel. "
              "As it slows, five friends wearing pink knitted hats and sunglasses lean "
              "over and look down toward the camera with curious expressions. The lens "
@@ -823,12 +841,13 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
              0.0,
              0.0,
              0.0,
          ],
      ],
      inputs=[
-         first_image, last_image, prompt, duration, gpu_duration,
          enhance_prompt, seed, randomize_seed, height, width,
-         pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength,
      ],
  )
 
@@ -850,18 +869,13 @@ with gr.Blocks(title="LTX-2.3 Distilled") as demo:
      outputs=[width, height],
  )

- prepare_lora_btn.click(
-     fn=prepare_lora_cache,
-     inputs=[pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength],
-     outputs=[lora_status],
- )

  generate_btn.click(
      fn=generate_video,
      inputs=[
-         first_image, last_image, prompt, duration, gpu_duration, enhance_prompt,
          seed, randomize_seed, height, width,
-         pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength,
      ],
      outputs=[output_video, seed],
  )
@@ -872,4 +886,4 @@ css = """
  """

  if __name__ == "__main__":
-     demo.launch(theme=gr.themes.Citrus(), css=css)
 
  import gradio as gr
  import numpy as np
  from huggingface_hub import hf_hub_download, snapshot_download
+ from safetensors.torch import load_file
+ from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
+ from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
+
+ try:
+     from ltx_core.loader.fuse_loras import apply_loras
+ except ImportError:
+     from ltx_core.loader.fuse_loras import fuse_lora_weights
+
+     def apply_loras(model_sd, loras, dtype=None):
+         # fuse_lora_weights is the lower-level helper the repo uses internally;
+         # this wrapper turns its output into a regular state_dict.
+         return {
+             k: v
+             for k, v in fuse_lora_weights(
+                 model_sd,
+                 loras,
+                 dtype=dtype,
+                 preserve_input_device=False,
+             )
+         }

  from ltx_core.components.diffusion_steps import EulerDiffusionStep
  from ltx_core.components.noisers import GaussianNoiser
  from ltx_core.model.audio_vae import encode_audio as vae_encode_audio
  from ltx_core.model.upsampler import upsample_video
  from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number, decode_video as vae_decode_video
  from ltx_core.quantization import QuantizationPolicy
+ from ltx_core.types import Audio, AudioLatentShape, VideoPixelShape
  from ltx_pipelines.distilled import DistilledPipeline
+ from ltx_pipelines.utils import euler_denoising_loop
  from ltx_pipelines.utils.args import ImageConditioningInput
  from ltx_pipelines.utils.constants import DISTILLED_SIGMA_VALUES, STAGE_2_DISTILLED_SIGMA_VALUES
  from ltx_pipelines.utils.helpers import (
      cleanup_memory,
      combined_image_conditionings,
      denoise_video_only,
      encode_prompts,
      simple_denoising_func,
  )
  from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video

  # Force-patch xformers attention into the LTX attention module.
  from ltx_core.model.transformer import attention as _attn_mod
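A note on the try/except shim above: it tolerates two revisions of ltx_core.loader.fuse_loras, falling back to wrapping the lower-level fuse_lora_weights helper when apply_loras is absent. The self-contained sketch below mirrors that fallback pattern; its fuse_lora_weights stub is hypothetical and only mimics the (name, tensor) pair-generator contract the commit's wrapper assumes.

import torch

def fuse_lora_weights(model_sd, loras, dtype=None, preserve_input_device=False):
    # Hypothetical stand-in: yields (param_name, fused_tensor) pairs,
    # the output shape the wrapper in the commit expects.
    for name, weight in model_sd.items():
        fused = weight.clone()
        for delta_sd, strength in loras:
            fused = fused + strength * delta_sd.get(name, torch.zeros_like(weight))
        yield name, fused.to(dtype) if dtype is not None else fused

def apply_loras(model_sd, loras, dtype=None):
    # Same move as the shim: materialize the pair generator into a state_dict.
    return dict(fuse_lora_weights(model_sd, loras, dtype=dtype))

base = {"w": torch.ones(2, 2)}
lora = ({"w": torch.full((2, 2), 0.5)}, 0.2)  # (per-param delta, strength)
print(apply_loras(base, [lora])["w"][0, 0].item())  # about 1.1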
 
  }


+ class LTX23DistilledA2VPipeline(DistilledPipeline):
      """DistilledPipeline with optional audio conditioning."""

      def __call__(
          self,
          prompt: str,

          num_frames: int,
          frame_rate: float,
          images: list[ImageConditioningInput],
+         audio_path: str | None = None,
          tiling_config: TilingConfig | None = None,
          enhance_prompt: bool = False,
      ):
+         # Standard path when no audio input is provided.
+         print(prompt)
+         if audio_path is None:
+             return super().__call__(
+                 prompt=prompt,
+                 seed=seed,
+                 height=height,
+                 width=width,
+                 num_frames=num_frames,
+                 frame_rate=frame_rate,
+                 images=images,
+                 tiling_config=tiling_config,
+                 enhance_prompt=enhance_prompt,
+             )

          generator = torch.Generator(device=self.device).manual_seed(seed)
          noiser = GaussianNoiser(generator=generator)
 
              [prompt],
              self.model_ledger,
              enhance_first_prompt=enhance_prompt,
+             enhance_prompt_image=images[0].path if len(images) > 0 else None,
          )
          video_context, audio_context = ctx_p.video_encoding, ctx_p.audio_encoding

+         video_duration = num_frames / frame_rate
+         decoded_audio = decode_audio_from_file(audio_path, self.device, 0.0, video_duration)
+         if decoded_audio is None:
+             raise ValueError(f"Could not extract audio stream from {audio_path}")
+
+         encoded_audio_latent = vae_encode_audio(decoded_audio, self.model_ledger.audio_encoder())
+         audio_shape = AudioLatentShape.from_duration(batch=1, duration=video_duration, channels=8, mel_bins=16)
+         expected_frames = audio_shape.frames
+         actual_frames = encoded_audio_latent.shape[2]
+
+         if actual_frames > expected_frames:
+             encoded_audio_latent = encoded_audio_latent[:, :, :expected_frames, :]
+         elif actual_frames < expected_frames:
+             pad = torch.zeros(
+                 encoded_audio_latent.shape[0],
+                 encoded_audio_latent.shape[1],
+                 expected_frames - actual_frames,
+                 encoded_audio_latent.shape[3],
+                 device=encoded_audio_latent.device,
+                 dtype=encoded_audio_latent.dtype,
+             )
+             encoded_audio_latent = torch.cat([encoded_audio_latent, pad], dim=2)
+
          video_encoder = self.model_ledger.video_encoder()
          transformer = self.model_ledger.transformer()
+         stage_1_sigmas = torch.tensor(DISTILLED_SIGMA_VALUES, device=self.device)

+         def denoising_loop(sigmas, video_state, audio_state, stepper):
              return euler_denoising_loop(
                  sigmas=sigmas,
                  video_state=video_state,

                  denoise_fn=simple_denoising_func(
                      video_context=video_context,
                      audio_context=audio_context,
+                     transformer=transformer,
                  ),
              )

          stage_1_output_shape = VideoPixelShape(
              batch=1,
              frames=num_frames,
+             width=width // 2,
+             height=height // 2,
              fps=frame_rate,
          )
          stage_1_conditionings = combined_image_conditionings(

              dtype=dtype,
              device=self.device,
          )
+         video_state = denoise_video_only(
              output_shape=stage_1_output_shape,
              conditionings=stage_1_conditionings,
              noiser=noiser,

              components=self.pipeline_components,
              dtype=dtype,
              device=self.device,
+             initial_audio_latent=encoded_audio_latent,
+         )
+
+         torch.cuda.synchronize()
+         cleanup_memory()
+
+         upscaled_video_latent = upsample_video(
+             latent=video_state.latent[:1],
+             video_encoder=video_encoder,
+             upsampler=self.model_ledger.spatial_upsampler(),
+         )
+         stage_2_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=self.device)
+         stage_2_output_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
+         stage_2_conditionings = combined_image_conditionings(
+             images=images,
+             height=stage_2_output_shape.height,
+             width=stage_2_output_shape.width,
+             video_encoder=video_encoder,
+             dtype=dtype,
+             device=self.device,
+         )
+         video_state = denoise_video_only(
+             output_shape=stage_2_output_shape,
+             conditionings=stage_2_conditionings,
+             noiser=noiser,
+             sigmas=stage_2_sigmas,
+             stepper=stepper,
+             denoising_loop_fn=denoising_loop,
+             components=self.pipeline_components,
+             dtype=dtype,
+             device=self.device,
+             noise_scale=stage_2_sigmas[0],
+             initial_video_latent=upscaled_video_latent,
+             initial_audio_latent=encoded_audio_latent,
+         )

          torch.cuda.synchronize()

          cleanup_memory()

          decoded_video = vae_decode_video(
+             video_state.latent,
+             self.model_ledger.video_decoder(),
+             tiling_config,
+             generator,
          )
+         original_audio = Audio(
+             waveform=decoded_audio.waveform.squeeze(0),
+             sampling_rate=decoded_audio.sampling_rate,
          )
+         return decoded_video, original_audio
281
 
282
  # Model repos
 
288
  print("Downloading LTX-2.3 distilled model + Gemma...")
289
  print("=" * 80)
290
 
 
 
 
 
 
 
 
 
 
291
  weights_dir = Path("weights")
292
  weights_dir.mkdir(exist_ok=True)
293
  checkpoint_path = hf_hub_download(
294
+ repo_id="SulphurAI/Sulphur-2-base",
295
+ filename="sulphur_distil_bf16.safetensors",
296
  local_dir=str(weights_dir),
297
  local_dir_use_symlinks=False,
298
  )
 
  pose_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2_3_NSFW_furry_concat_v2.safetensors")
  general_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_reasoning_I2V_V3.safetensors")
  motion_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="motion_helper.safetensors")
+ dreamlay_lora_path = hf_hub_download(repo_id="lynaNSFW/DR34ML4Y_AIO_NSFW_LTX23", filename="DR34ML4Y_LTXXX_V1.safetensors") # m15510n4ry, bl0wj0b, d0ubl3_bj, d0gg1e, c0wg1rl
  mself_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="Furry Hyper Masturbation - LTX-2 I2V v1.safetensors") # Hyperfap
  dramatic_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX-2.3 - Orgasm.safetensors") # "[He | She] is having am orgasm." (am or an?)
  fluid_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_CREAMPIE_ANIMATION-V0.1.safetensors") # cum

  realism_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="FurryenhancerLTX2.3V1.215.safetensors")
  transition_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX-2_takerpov_lora_v1.2.safetensors") # takerpov1, taker pov
  physics_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_Better_Physics_PhysLTX.safetensors")
+ reasoning_lora_path = hf_hub_download(repo_id="LiconStudio/Ltx2.3-VBVR-lora-I2V", filename="Ltx2.3-Licon-VBVR-I2V-390K-R32.safetensors")
+ twostep_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="LTX2.3_Multi_step_video_reasoning_V0.1.safetensors")

  print(f"Pose LoRA: {pose_lora_path}")
  print(f"General LoRA: {general_lora_path}")

  print(f"Transition LoRA: {transition_lora_path}")
  print(f"Physics LoRA: {physics_lora_path}")
  print(f"Reasoning LoRA: {reasoning_lora_path}")
+ print(f"Twostep LoRA: {twostep_lora_path}")
  # ----------------------------------------------------------------

  print(f"Checkpoint: {checkpoint_path}")
 
  )
  # ----------------------------------------------------------------

+ def _make_lora_key(pose_strength: float, general_strength: float, motion_strength: float, dreamlay_strength: float, mself_strength: float, dramatic_strength: float, fluid_strength: float, liquid_strength: float, demopose_strength: float, voice_strength: float, realism_strength: float, transition_strength: float, physics_strength: float, reasoning_strength: float, twostep_strength: float) -> tuple[str, str]:
      rp = round(float(pose_strength), 2)
      rg = round(float(general_strength), 2)
      rm = round(float(motion_strength), 2)

      rt = round(float(transition_strength), 2)
      ry = round(float(physics_strength), 2)
      ri = round(float(reasoning_strength), 2)
+     rw = round(float(twostep_strength), 2)
+     key_str = f"{pose_lora_path}:{rp}|{general_lora_path}:{rg}|{motion_lora_path}:{rm}|{dreamlay_lora_path}:{rd}|{mself_lora_path}:{rs}|{dramatic_lora_path}:{rr}|{fluid_lora_path}:{rf}|{liquid_lora_path}:{rl}|{demopose_lora_path}:{ro}|{voice_lora_path}:{rv}|{realism_lora_path}:{re}|{transition_lora_path}:{rt}|{physics_lora_path}:{ry}|{reasoning_lora_path}:{ri}|{twostep_lora_path}:{rw}"
      key = hashlib.sha256(key_str.encode("utf-8")).hexdigest()
      return key, key_str

+ def _collect_lora_specs(
      pose_strength: float,
      general_strength: float,
      motion_strength: float,

      transition_strength: float,
      physics_strength: float,
      reasoning_strength: float,
+     twostep_strength: float,
  ):
+     # Keep all 15 adapters in the active list; zero strength means no effect.
+     return [
          (pose_lora_path, round(float(pose_strength), 2)),
          (general_lora_path, round(float(general_strength), 2)),
          (motion_lora_path, round(float(motion_strength), 2)),

          (transition_lora_path, round(float(transition_strength), 2)),
          (physics_lora_path, round(float(physics_strength), 2)),
          (reasoning_lora_path, round(float(reasoning_strength), 2)),
+         (twostep_lora_path, round(float(twostep_strength), 2)),
      ]

+ def apply_current_loras_to_transformer(
+     pose_strength: float,
+     general_strength: float,
+     motion_strength: float,
+     dreamlay_strength: float,
+     mself_strength: float,
+     dramatic_strength: float,
+     fluid_strength: float,
+     liquid_strength: float,
+     demopose_strength: float,
+     voice_strength: float,
+     realism_strength: float,
+     transition_strength: float,
+     physics_strength: float,
+     reasoning_strength: float,
+     twostep_strength: float,
+ ):
+     global ACTIVE_LORA_KEY

+     key, _ = _make_lora_key(
+         pose_strength, general_strength, motion_strength, dreamlay_strength,
+         mself_strength, dramatic_strength, fluid_strength, liquid_strength,
+         demopose_strength, voice_strength, realism_strength, transition_strength,
+         physics_strength, reasoning_strength, twostep_strength
+     )

+     if key == ACTIVE_LORA_KEY:
+         return "LoRAs already active."
+
+     if key in LORA_STATE_CACHE:
+         fused_state = LORA_STATE_CACHE[key]
+     else:
+         loras = [
+             LoraPathStrengthAndSDOps(path, strength, LTXV_LORA_COMFY_RENAMING_MAP)
+             for path, strength in _collect_lora_specs(
+                 pose_strength, general_strength, motion_strength, dreamlay_strength,
+                 mself_strength, dramatic_strength, fluid_strength, liquid_strength,
+                 demopose_strength, voice_strength, realism_strength, transition_strength,
+                 physics_strength, reasoning_strength, twostep_strength,
+             )
+         ]
+
+         fused_state = apply_loras(
+             BASE_TRANSFORMER_STATE,
+             loras,
+             dtype=pipeline.model_ledger.dtype,
+         )
+         LORA_STATE_CACHE[key] = fused_state

      with torch.no_grad():
+         missing, unexpected = _transformer.load_state_dict(fused_state, strict=False)
          if missing or unexpected:
+             print(
+                 f"[LoRA] state_dict mismatch: missing={len(missing)}, unexpected={len(unexpected)}"
+             )

+     ACTIVE_LORA_KEY = key
+     return f"Applied LoRAs: {key[:12]}"
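The keying scheme shared by _make_lora_key and LORA_STATE_CACHE reduces a whole slider configuration to one stable hash: strengths are rounded to two decimals, joined with their adapter paths, and SHA-256 hashed, so any repeated configuration hits the in-memory cache instead of re-fusing weights. A condensed sketch of the same recipe:

import hashlib

def make_key(entries: list[tuple[str, float]]) -> tuple[str, str]:
    # Round to 2 decimals so near-identical slider values share a cache entry,
    # then hash the joined "path:strength" string for a fixed-length key.
    key_str = "|".join(f"{path}:{round(float(s), 2)}" for path, s in entries)
    return hashlib.sha256(key_str.encode("utf-8")).hexdigest(), key_str

key, raw = make_key([("pose.safetensors", 1.0), ("motion.safetensors", 0.25)])
print(key[:12], "<-", raw)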
 
  # ---- REPLACE PRELOAD BLOCK START ----
  # Preload all models for ZeroGPU tensor packing.

  # Call the original factories once to create the cached instances we will serve by default.
  _transformer = _orig_transformer_factory()
+ BASE_TRANSFORMER_STATE = {
+     k: v.detach().cpu().contiguous()
+     for k, v in _transformer.state_dict().items()
+ }
+
+ ACTIVE_LORA_KEY: str | None = None
+ LORA_STATE_CACHE: dict[str, dict[str, torch.Tensor]] = {}
  _video_encoder = _orig_video_encoder_factory()
  _video_decoder = _orig_video_decoder_factory()
  _audio_encoder = _orig_audio_encoder_factory()
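BASE_TRANSFORMER_STATE freezes a pristine CPU copy of the transformer weights right after the factory call, so each LoRA fusion starts from the unmodified base rather than compounding on previously applied adapters. A toy illustration of that snapshot-and-restore pattern (nn.Linear stands in for the transformer):

import torch
from torch import nn

model = nn.Linear(4, 4)
# Snapshot once: detach().cpu().contiguous() decouples the copy from the live module.
base_state = {k: v.detach().cpu().contiguous() for k, v in model.state_dict().items()}

with torch.no_grad():
    model.weight.add_(1.0)  # stand-in for loading a fused LoRA state

model.load_state_dict(base_state)  # the base weights are recoverable at any time
print(torch.equal(model.weight.cpu(), base_state["weight"]))  # True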
 
  def get_gpu_duration(
      first_image,
      last_image,
+     input_audio,
      prompt: str,
      duration: float,
      gpu_duration: float,

      transition_strength: float = 0.0,
      physics_strength: float = 0.0,
      reasoning_strength: float = 0.0,
+     twostep_strength: float = 0.0,
      progress=None,
  ):
      return int(gpu_duration)

  def generate_video(
      first_image,
      last_image,
+     input_audio,
      prompt: str,
      duration: float,
      gpu_duration: float,

      transition_strength: float = 0.0,
      physics_strength: float = 0.0,
      reasoning_strength: float = 0.0,
+     twostep_strength: float = 0.0,
      progress=gr.Progress(track_tqdm=True),
  ):
      try:
 

          log_memory("before pipeline call")

+         apply_current_loras_to_transformer(
+             pose_strength, general_strength, motion_strength, dreamlay_strength,
+             mself_strength, dramatic_strength, fluid_strength, liquid_strength,
+             demopose_strength, voice_strength, realism_strength, transition_strength,
+             physics_strength, reasoning_strength, twostep_strength,
+         )
+
          video, audio = pipeline(
              prompt=prompt,
              seed=current_seed,

              num_frames=num_frames,
              frame_rate=frame_rate,
              images=images,
+             audio_path=input_audio,
              tiling_config=tiling_config,
              enhance_prompt=enhance_prompt,
          )
 
      with gr.Row():
          first_image = gr.Image(label="First Frame (Optional)", type="pil")
          last_image = gr.Image(label="Last Frame (Optional)", type="pil")
+         input_audio = gr.Audio(label="Audio Input (Optional)", type="filepath")
      prompt = gr.Textbox(
          label="Prompt",
          info="for best results - make it as elaborate as possible",

          minimum=0.0, maximum=2.0, value=0.0, step=0.01
      )
      reasoning_strength = gr.Slider(
+         label="Official Reasoning strength",
+         minimum=0.0, maximum=2.0, value=0.0, step=0.01
+     )
+     twostep_strength = gr.Slider(
+         label="Two Step Reasoning strength",
          minimum=0.0, maximum=2.0, value=0.0, step=0.01
      )

      with gr.Column():
          output_video = gr.Video(label="Generated Video", autoplay=False)
 
          [
              None,
              "pinkknit.jpg",
+             None,
              "The camera falls downward through darkness as if dropped into a tunnel. "
              "As it slows, five friends wearing pink knitted hats and sunglasses lean "
              "over and look down toward the camera with curious expressions. The lens "

              0.0,
              0.0,
              0.0,
+             0.0,
          ],
      ],
      inputs=[
+         first_image, last_image, input_audio, prompt, duration, gpu_duration,
          enhance_prompt, seed, randomize_seed, height, width,
+         pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength, twostep_strength,
      ],
  )
853
 
 
      outputs=[width, height],
  )

  generate_btn.click(
      fn=generate_video,
      inputs=[
+         first_image, last_image, input_audio, prompt, duration, gpu_duration, enhance_prompt,
          seed, randomize_seed, height, width,
+         pose_strength, general_strength, motion_strength, dreamlay_strength, mself_strength, dramatic_strength, fluid_strength, liquid_strength, demopose_strength, voice_strength, realism_strength, transition_strength, physics_strength, reasoning_strength, twostep_strength,
      ],
      outputs=[output_video, seed],
  )

  """

  if __name__ == "__main__":
+     demo.launch(theme=gr.themes.Citrus(), css=css)
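For reference, the new input_audio component relies on Gradio's filepath mode: with type="filepath" the handler receives a local path string, or None when no audio is supplied, which is exactly what generate_video forwards as audio_path. A minimal sketch of that wiring, with a hypothetical handler in place of generate_video:

import gradio as gr

def handler(audio_path):
    # audio_path is a str (temp file path) or None, matching generate_video's use.
    return "no audio" if audio_path is None else f"got: {audio_path}"

with gr.Blocks() as demo:
    audio = gr.Audio(label="Audio Input (Optional)", type="filepath")
    status = gr.Textbox(label="Status")
    audio.change(fn=handler, inputs=[audio], outputs=[status])

# demo.launch()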