dagloop5 committed on
Commit
3347828
·
verified ·
1 Parent(s): 8cda1be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -2
app.py CHANGED
@@ -41,6 +41,7 @@ import spaces
41
  import gradio as gr
42
  import numpy as np
43
  from huggingface_hub import hf_hub_download, snapshot_download
 
44
 
45
  from ltx_core.components.diffusion_steps import EulerDiffusionStep
46
  from ltx_core.components.noisers import GaussianNoiser
@@ -74,6 +75,8 @@ except Exception as e:
74
 
75
  logging.getLogger().setLevel(logging.INFO)
76
 
 
 
77
  MAX_SEED = np.iinfo(np.int32).max
78
  DEFAULT_PROMPT = (
79
  "An astronaut hatches from a fragile egg on the surface of the Moon, "
@@ -267,6 +270,11 @@ checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-
267
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
268
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
269
 
 
 
 
 
 
270
  print(f"Checkpoint: {checkpoint_path}")
271
  print(f"Spatial upsampler: {spatial_upsampler_path}")
272
  print(f"Gemma root: {gemma_root}")
@@ -276,7 +284,13 @@ pipeline = LTX23DistilledA2VPipeline(
276
  distilled_checkpoint_path=checkpoint_path,
277
  spatial_upsampler_path=spatial_upsampler_path,
278
  gemma_root=gemma_root,
279
- loras=[],
 
 
 
 
 
 
280
  quantization=QuantizationPolicy.fp8_cast(),
281
  )
282
 
@@ -284,6 +298,20 @@ pipeline = LTX23DistilledA2VPipeline(
284
  print("Preloading all models (including Gemma and audio components)...")
285
  ledger = pipeline.model_ledger
286
  _transformer = ledger.transformer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  _video_encoder = ledger.video_encoder()
288
  _video_decoder = ledger.video_decoder()
289
  _audio_encoder = ledger.audio_encoder()
@@ -355,6 +383,7 @@ def generate_video(
355
  input_audio,
356
  prompt: str,
357
  duration: float,
 
358
  enhance_prompt: bool = True,
359
  seed: int = 42,
360
  randomize_seed: bool = True,
@@ -367,6 +396,8 @@ def generate_video(
367
  log_memory("start")
368
 
369
  current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
 
 
370
 
371
  frame_rate = DEFAULT_FRAME_RATE
372
  num_frames = int(duration * frame_rate) + 1
@@ -451,6 +482,13 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
451
  placeholder="Describe the motion and animation you want...",
452
  )
453
  duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
 
 
 
 
 
 
 
454
 
455
 
456
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
@@ -515,7 +553,7 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
515
  generate_btn.click(
516
  fn=generate_video,
517
  inputs=[
518
- first_image, last_image, input_audio, prompt, duration, enhance_prompt,
519
  seed, randomize_seed, height, width,
520
  ],
521
  outputs=[output_video, seed],
 
41
  import gradio as gr
42
  import numpy as np
43
  from huggingface_hub import hf_hub_download, snapshot_download
44
+ from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
45
 
46
  from ltx_core.components.diffusion_steps import EulerDiffusionStep
47
  from ltx_core.components.noisers import GaussianNoiser
 
75
 
76
  logging.getLogger().setLevel(logging.INFO)
77
 
78
+ LORA_RUNTIME_SCALE = 1.0
79
+
80
  MAX_SEED = np.iinfo(np.int32).max
81
  DEFAULT_PROMPT = (
82
  "An astronaut hatches from a fragile egg on the surface of the Moon, "
 
270
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
271
  gemma_root = snapshot_download(repo_id=GEMMA_REPO)
272
 
273
+ lora_path = hf_hub_download(
274
+ repo_id="dagloop5/LoRA",
275
+ filename="LoRA2.safetensors"
276
+ )
277
+
278
  print(f"Checkpoint: {checkpoint_path}")
279
  print(f"Spatial upsampler: {spatial_upsampler_path}")
280
  print(f"Gemma root: {gemma_root}")
 
284
  distilled_checkpoint_path=checkpoint_path,
285
  spatial_upsampler_path=spatial_upsampler_path,
286
  gemma_root=gemma_root,
287
+ loras=[
288
+ LoraPathStrengthAndSDOps(
289
+ lora_path,
290
+ 1.0, # fixed internal strength
291
+ LTXV_LORA_COMFY_RENAMING_MAP
292
+ )
293
+ ],
294
  quantization=QuantizationPolicy.fp8_cast(),
295
  )
296
 
 
298
  print("Preloading all models (including Gemma and audio components)...")
299
  ledger = pipeline.model_ledger
300
  _transformer = ledger.transformer()
301
+ _original_forward = _transformer.forward
302
+
303
+ def _lora_scaled_forward(*args, **kwargs):
304
+ out = _original_forward(*args, **kwargs)
305
+
306
+ # Scale the transformer output by the runtime LoRA strength slider.
307
+ # NOTE: this multiplies the ENTIRE output tensor (not just the LoRA delta, which is merged into the weights at load time).
308
+ if isinstance(out, tuple):
309
+ return tuple(o * LORA_RUNTIME_SCALE if torch.is_tensor(o) else o for o in out)
310
+ elif torch.is_tensor(out):
311
+ return out * LORA_RUNTIME_SCALE
312
+ return out
313
+
314
+ _transformer.forward = _lora_scaled_forward
315
  _video_encoder = ledger.video_encoder()
316
  _video_decoder = ledger.video_decoder()
317
  _audio_encoder = ledger.audio_encoder()
 
383
  input_audio,
384
  prompt: str,
385
  duration: float,
386
+ lora_strength: float,
387
  enhance_prompt: bool = True,
388
  seed: int = 42,
389
  randomize_seed: bool = True,
 
396
  log_memory("start")
397
 
398
  current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
399
+ global LORA_RUNTIME_SCALE
400
+ LORA_RUNTIME_SCALE = lora_strength
401
 
402
  frame_rate = DEFAULT_FRAME_RATE
403
  num_frames = int(duration * frame_rate) + 1
 
482
  placeholder="Describe the motion and animation you want...",
483
  )
484
  duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
485
+ lora_strength = gr.Slider(
486
+ label="LoRA Strength",
487
+ minimum=0.0,
488
+ maximum=1.5,
489
+ value=1.0,
490
+ step=0.05,
491
+ )
492
 
493
 
494
  generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
 
553
  generate_btn.click(
554
  fn=generate_video,
555
  inputs=[
556
+ first_image, last_image, input_audio, prompt, duration, lora_strength, enhance_prompt,
557
  seed, randomize_seed, height, width,
558
  ],
559
  outputs=[output_video, seed],