dagloop5 committed
Commit c1892c6 · verified · 1 parent: c21a9b3

Update app.py

Files changed (1): app.py (+72 -12)
app.py CHANGED
@@ -60,6 +60,10 @@ from ltx_pipelines.utils.helpers import (
     encode_prompts,
     simple_denoising_func,
 )
+
+from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
+from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
+
 from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
 
 # Force-patch xformers attention into the LTX attention module.
@@ -271,12 +275,34 @@ print(f"Checkpoint: {checkpoint_path}")
 print(f"Spatial upsampler: {spatial_upsampler_path}")
 print(f"Gemma root: {gemma_root}")
 
+# Download the LoRAs we want to support and prepare a helper to create LoraPathStrengthAndSDOps
+LORA_REPO = "dagloop5/LoRA"
+pose_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="pose_enhancer.safetensors")
+general_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="general_enhancer.safetensors")
+motion_lora_path = hf_hub_download(repo_id=LORA_REPO, filename="motion_helper.safetensors")
+
+print(f"Downloaded LoRAs: {pose_lora_path}, {general_lora_path}, {motion_lora_path}")
+
+def build_loras_tuple(pose_strength: float, general_strength: float, motion_strength: float):
+    """
+    Return a tuple of LoraPathStrengthAndSDOps matching LTX loader expectations.
+    Uses the LTX renaming map for SD key remapping (helps with some LoRA formats).
+    """
+    return (
+        LoraPathStrengthAndSDOps(path=str(pose_lora_path), strength=float(pose_strength), sd_ops=LTXV_LORA_COMFY_RENAMING_MAP),
+        LoraPathStrengthAndSDOps(path=str(general_lora_path), strength=float(general_strength), sd_ops=LTXV_LORA_COMFY_RENAMING_MAP),
+        LoraPathStrengthAndSDOps(path=str(motion_lora_path), strength=float(motion_strength), sd_ops=LTXV_LORA_COMFY_RENAMING_MAP),
+    )
+
+# Initial strengths (the defaults can be changed)
+INITIAL_LORAS = build_loras_tuple(1.0, 1.0, 1.0)
+
 # Initialize pipeline WITH text encoder and optional audio support
 pipeline = LTX23DistilledA2VPipeline(
     distilled_checkpoint_path=checkpoint_path,
     spatial_upsampler_path=spatial_upsampler_path,
     gemma_root=gemma_root,
-    loras=[],
+    loras=[INITIAL_LORAS],
     quantization=QuantizationPolicy.fp8_cast(),
 )
 
@@ -293,15 +319,6 @@ _spatial_upsampler = ledger.spatial_upsampler()
 _text_encoder = ledger.text_encoder()
 _embeddings_processor = ledger.gemma_embeddings_processor()
 
-ledger.transformer = lambda: _transformer
-ledger.video_encoder = lambda: _video_encoder
-ledger.video_decoder = lambda: _video_decoder
-ledger.audio_encoder = lambda: _audio_encoder
-ledger.audio_decoder = lambda: _audio_decoder
-ledger.vocoder = lambda: _vocoder
-ledger.spatial_upsampler = lambda: _spatial_upsampler
-ledger.text_encoder = lambda: _text_encoder
-ledger.gemma_embeddings_processor = lambda: _embeddings_processor
 print("All models preloaded (including Gemma text encoder and audio encoder)!")
 
 print("=" * 80)
@@ -347,7 +364,7 @@ def on_highres_toggle(first_image, last_image, high_res):
     return gr.update(value=w), gr.update(value=h)
 
 
-@spaces.GPU(duration=75)
+@spaces.GPU(duration=80)
 @torch.inference_mode()
 def generate_video(
     first_image,
@@ -360,6 +377,9 @@ def generate_video(
     randomize_seed: bool = True,
     height: int = 1024,
     width: int = 1536,
+    pose_lora_strength: float = 1.0,
+    general_lora_strength: float = 1.0,
+    motion_lora_strength: float = 1.0,
     progress=gr.Progress(track_tqdm=True),
 ):
     try:
@@ -368,6 +388,42 @@ def generate_video(
 
         current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
 
+        # --- LoRA dynamic update: rebuild ledger models in-place when strengths change ---
+        try:
+            current_ledger = pipeline.model_ledger
+            # Helper to compare strengths quickly
+            def _get_current_strengths(ledger_obj):
+                return tuple(float(lora.strength) for lora in getattr(ledger_obj, "loras", ()))
+
+            requested_strengths = (float(pose_lora_strength), float(general_lora_strength), float(motion_lora_strength))
+            if _get_current_strengths(current_ledger) != requested_strengths:
+                # Build a new tuple and replace ledger.loras
+                current_ledger.loras = build_loras_tuple(*requested_strengths)
+                # Clear cached model instances so new models are constructed with the new LoRAs
+                # (ModelLedger builds models on first access using its configured `loras`)
+                try:
+                    current_ledger.clear_vram()
+                except Exception:
+                    # `clear_vram` should exist; if it doesn't, fall back to deleting cached attrs
+                    for k in list(vars(current_ledger).keys()):
+                        if k in ("_transformer", "_video_encoder", "_video_decoder", "_audio_encoder", "_audio_decoder", "_vocoder", "_spatial_upsampler", "_text_encoder", "_gemma_embeddings_processor"):
+                            vars(current_ledger).pop(k, None)
+                # Pre-load the models again (ensures they are on-device before the pipeline call)
+                _ = current_ledger.transformer()
+                _ = current_ledger.video_encoder()
+                _ = current_ledger.video_decoder()
+                _ = current_ledger.audio_encoder()
+                _ = current_ledger.audio_decoder()
+                _ = current_ledger.vocoder()
+                _ = current_ledger.spatial_upsampler()
+                _ = current_ledger.text_encoder()
+                _ = current_ledger.gemma_embeddings_processor()
+                torch.cuda.empty_cache()
+        except Exception as e:
+            # If this fails, proceed with the existing pipeline (safer to continue than to crash)
+            print(f"[LoRA rebuild warning] Could not update LoRA strengths in-place: {e}")
+        # --- end LoRA update ---
+
         frame_rate = DEFAULT_FRAME_RATE
         num_frames = int(duration * frame_rate) + 1
         num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1
@@ -464,9 +520,12 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
             with gr.Row():
                 enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
                 high_res = gr.Checkbox(label="High Resolution", value=True)
+                pose_lora_strength = gr.Slider(label="Pose LoRA Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                general_lora_strength = gr.Slider(label="General LoRA Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
+                motion_lora_strength = gr.Slider(label="Motion LoRA Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
 
         with gr.Column():
-            output_video = gr.Video(label="Generated Video", autoplay=True)
+            output_video = gr.Video(label="Generated Video", autoplay=False)
 
     gr.Examples(
         examples=[
@@ -517,6 +576,7 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
         inputs=[
             first_image, last_image, input_audio, prompt, duration, enhance_prompt,
             seed, randomize_seed, height, width,
+            pose_lora_strength, general_lora_strength, motion_lora_strength,
        ],
        outputs=[output_video, seed],
    )
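
The core of the change is the request-time LoRA swap inside generate_video: compare the requested slider strengths to the ledger's current loras; on mismatch, replace the tuple, drop the cached model instances, and re-resolve every model before running the pipeline. Below is a minimal, self-contained sketch of that rebuild-on-change pattern; FakeLora, FakeLedger, and maybe_rebuild are hypothetical stand-ins, not the real LTX ModelLedger API.

from dataclasses import dataclass

@dataclass
class FakeLora:
    path: str
    strength: float

class FakeLedger:
    """Hypothetical stand-in: builds models lazily from `loras` and caches them."""
    def __init__(self, loras):
        self.loras = tuple(loras)
        self._cache = {}

    def transformer(self):
        # Built on first access using the current `loras` config, then cached.
        if "transformer" not in self._cache:
            strengths = [l.strength for l in self.loras]
            self._cache["transformer"] = f"transformer(lora_strengths={strengths})"
        return self._cache["transformer"]

    def clear_vram(self):
        self._cache.clear()

def maybe_rebuild(ledger, requested_strengths):
    current = tuple(l.strength for l in ledger.loras)
    if current != tuple(requested_strengths):
        # Same three steps as the app: swap config, drop caches, re-resolve.
        ledger.loras = tuple(FakeLora(l.path, s) for l, s in zip(ledger.loras, requested_strengths))
        ledger.clear_vram()
    return ledger.transformer()

ledger = FakeLedger([FakeLora("pose.safetensors", 1.0)])
print(maybe_rebuild(ledger, (1.0,)))  # unchanged -> cached instance returned
print(maybe_rebuild(ledger, (0.5,)))  # changed   -> rebuilt with the new strength

Note the fallback branch in the diff (popping _transformer-style attributes out of vars(ledger)) assumes the ledger caches models in private attributes with those exact names; that is a guess encoded in the commit itself, which is why the whole block is wrapped in its own try/except and merely warns on failure.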
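Just past the new block, the unchanged context shows how generate_video sizes the clip: the frame count is rounded so num_frames has the form 8*k + 1. A quick worked check of that arithmetic; the frame rate of 24 is an assumption here, the app uses its own DEFAULT_FRAME_RATE constant.

def frames_for(duration: float, frame_rate: int = 24) -> int:
    # Mirrors the two lines in generate_video: count frames, then round
    # up to the next multiple of 8, plus the one extra leading frame.
    num_frames = int(duration * frame_rate) + 1
    return ((num_frames - 1 + 7) // 8) * 8 + 1

for d in (1.0, 2.5, 5.0):
    print(d, frames_for(d))  # 1.0 -> 25, 2.5 -> 65, 5.0 -> 121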
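The last two hunks must stay in lockstep: the three sliders are appended to generate_video's signature and, in the same order, to the inputs= list of the click binding, because Gradio passes component values to the handler positionally. A pared-down sketch of that wiring; the handler and component names here are illustrative.

import gradio as gr

def handler(prompt, pose_strength, general_strength, motion_strength):
    return f"{prompt}: pose={pose_strength} general={general_strength} motion={motion_strength}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    pose = gr.Slider(0.0, 2.0, value=1.0, step=0.01, label="Pose LoRA Strength")
    general = gr.Slider(0.0, 2.0, value=1.0, step=0.01, label="General LoRA Strength")
    motion = gr.Slider(0.0, 2.0, value=1.0, step=0.01, label="Motion LoRA Strength")
    out = gr.Textbox(label="Result")
    # Order in `inputs` maps positionally onto the handler's parameters;
    # reordering the list silently swaps which slider feeds which strength.
    gr.Button("Run").click(handler, inputs=[prompt, pose, general, motion], outputs=[out])

# demo.launch()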