Update app.py
Browse files
app.py
CHANGED
|
@@ -60,6 +60,10 @@ from ltx_pipelines.utils.helpers import (
|
|
| 60 |
encode_prompts,
|
| 61 |
simple_denoising_func,
|
| 62 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
|
| 64 |
|
| 65 |
# Force-patch xformers attention into the LTX attention module.
|
|
@@ -271,12 +275,34 @@ print(f"Checkpoint: {checkpoint_path}")
|
|
| 271 |
print(f"Spatial upsampler: {spatial_upsampler_path}")
|
| 272 |
print(f"Gemma root: {gemma_root}")
|
| 273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
# Initialize pipeline WITH text encoder and optional audio support
|
| 275 |
pipeline = LTX23DistilledA2VPipeline(
|
| 276 |
distilled_checkpoint_path=checkpoint_path,
|
| 277 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 278 |
gemma_root=gemma_root,
|
| 279 |
-
loras=[],
|
| 280 |
quantization=QuantizationPolicy.fp8_cast(),
|
| 281 |
)
|
| 282 |
|
|
@@ -293,15 +319,6 @@ _spatial_upsampler = ledger.spatial_upsampler()
|
|
| 293 |
_text_encoder = ledger.text_encoder()
|
| 294 |
_embeddings_processor = ledger.gemma_embeddings_processor()
|
| 295 |
|
| 296 |
-
ledger.transformer = lambda: _transformer
|
| 297 |
-
ledger.video_encoder = lambda: _video_encoder
|
| 298 |
-
ledger.video_decoder = lambda: _video_decoder
|
| 299 |
-
ledger.audio_encoder = lambda: _audio_encoder
|
| 300 |
-
ledger.audio_decoder = lambda: _audio_decoder
|
| 301 |
-
ledger.vocoder = lambda: _vocoder
|
| 302 |
-
ledger.spatial_upsampler = lambda: _spatial_upsampler
|
| 303 |
-
ledger.text_encoder = lambda: _text_encoder
|
| 304 |
-
ledger.gemma_embeddings_processor = lambda: _embeddings_processor
|
| 305 |
print("All models preloaded (including Gemma text encoder and audio encoder)!")
|
| 306 |
|
| 307 |
print("=" * 80)
|
|
@@ -347,7 +364,7 @@ def on_highres_toggle(first_image, last_image, high_res):
|
|
| 347 |
return gr.update(value=w), gr.update(value=h)
|
| 348 |
|
| 349 |
|
| 350 |
-
@spaces.GPU(duration=
|
| 351 |
@torch.inference_mode()
|
| 352 |
def generate_video(
|
| 353 |
first_image,
|
|
@@ -360,6 +377,9 @@ def generate_video(
|
|
| 360 |
randomize_seed: bool = True,
|
| 361 |
height: int = 1024,
|
| 362 |
width: int = 1536,
|
|
|
|
|
|
|
|
|
|
| 363 |
progress=gr.Progress(track_tqdm=True),
|
| 364 |
):
|
| 365 |
try:
|
|
@@ -368,6 +388,42 @@ def generate_video(
|
|
| 368 |
|
| 369 |
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
|
| 370 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
frame_rate = DEFAULT_FRAME_RATE
|
| 372 |
num_frames = int(duration * frame_rate) + 1
|
| 373 |
num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1
|
|
@@ -464,9 +520,12 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
|
|
| 464 |
with gr.Row():
|
| 465 |
enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
|
| 466 |
high_res = gr.Checkbox(label="High Resolution", value=True)
|
|
|
|
|
|
|
|
|
|
| 467 |
|
| 468 |
with gr.Column():
|
| 469 |
-
output_video = gr.Video(label="Generated Video", autoplay=
|
| 470 |
|
| 471 |
gr.Examples(
|
| 472 |
examples=[
|
|
@@ -517,6 +576,7 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
|
|
| 517 |
inputs=[
|
| 518 |
first_image, last_image, input_audio, prompt, duration, enhance_prompt,
|
| 519 |
seed, randomize_seed, height, width,
|
|
|
|
| 520 |
],
|
| 521 |
outputs=[output_video, seed],
|
| 522 |
)
|
|
|
|
| 60 |
encode_prompts,
|
| 61 |
simple_denoising_func,
|
| 62 |
)
|
| 63 |
+
|
| 64 |
+
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
| 65 |
+
from ltx_core.loader.sd_ops import LTXV_LORA_COMFY_RENAMING_MAP
|
| 66 |
+
|
| 67 |
from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
|
| 68 |
|
| 69 |
# Force-patch xformers attention into the LTX attention module.
|
|
|
|
| 275 |
print(f"Spatial upsampler: {spatial_upsampler_path}")
|
| 276 |
print(f"Gemma root: {gemma_root}")
|
| 277 |
|
| 278 |
+
# Fetch the three LoRA weight files from the Hub. The resulting local paths
# are read later by build_loras_tuple().
LORA_REPO = "dagloop5/LoRA"
pose_lora_path, general_lora_path, motion_lora_path = [
    hf_hub_download(repo_id=LORA_REPO, filename=lora_filename)
    for lora_filename in (
        "pose_enhancer.safetensors",
        "general_enhancer.safetensors",
        "motion_helper.safetensors",
    )
]

print(f"Downloaded LoRAs: {pose_lora_path}, {general_lora_path}, {motion_lora_path}")
|
| 285 |
+
|
| 286 |
+
def build_loras_tuple(pose_strength: float, general_strength: float, motion_strength: float):
    """
    Build the LoRA descriptors for the pose, general and motion LoRAs.

    Each downloaded LoRA path is paired with its requested strength (coerced
    to ``float``) and with ``LTXV_LORA_COMFY_RENAMING_MAP`` as ``sd_ops``.

    Returns:
        A 3-tuple of ``LoraPathStrengthAndSDOps`` ordered pose, general, motion.
    """
    # Order matters: callers compare strengths positionally against this tuple.
    paths_with_strengths = (
        (pose_lora_path, pose_strength),
        (general_lora_path, general_strength),
        (motion_lora_path, motion_strength),
    )
    return tuple(
        LoraPathStrengthAndSDOps(
            path=str(lora_path),
            strength=float(lora_strength),
            sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
        )
        for lora_path, lora_strength in paths_with_strengths
    )
|
| 296 |
+
|
| 297 |
+
# Default strengths for the LoRAs attached when the pipeline is first built
# (change these to alter the startup defaults).
INITIAL_LORAS = build_loras_tuple(1.0, 1.0, 1.0)

# Initialize pipeline WITH text encoder and optional audio support
pipeline = LTX23DistilledA2VPipeline(
    distilled_checkpoint_path=checkpoint_path,
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=gemma_root,
    # INITIAL_LORAS is already a tuple of LoRA descriptors. Pass it as a flat
    # list; `[INITIAL_LORAS]` would hand the pipeline a single element that is
    # itself a tuple instead of three descriptors.
    loras=list(INITIAL_LORAS),
    quantization=QuantizationPolicy.fp8_cast(),
)
|
| 308 |
|
|
|
|
| 319 |
_text_encoder = ledger.text_encoder()
|
| 320 |
_embeddings_processor = ledger.gemma_embeddings_processor()
|
| 321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
print("All models preloaded (including Gemma text encoder and audio encoder)!")
|
| 323 |
|
| 324 |
print("=" * 80)
|
|
|
|
| 364 |
return gr.update(value=w), gr.update(value=h)
|
| 365 |
|
| 366 |
|
| 367 |
+
@spaces.GPU(duration=80)
|
| 368 |
@torch.inference_mode()
|
| 369 |
def generate_video(
|
| 370 |
first_image,
|
|
|
|
| 377 |
randomize_seed: bool = True,
|
| 378 |
height: int = 1024,
|
| 379 |
width: int = 1536,
|
| 380 |
+
pose_lora_strength: float = 1.0,
|
| 381 |
+
general_lora_strength: float = 1.0,
|
| 382 |
+
motion_lora_strength: float = 1.0,
|
| 383 |
progress=gr.Progress(track_tqdm=True),
|
| 384 |
):
|
| 385 |
try:
|
|
|
|
| 388 |
|
| 389 |
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
|
| 390 |
|
| 391 |
+
# --- LoRA dynamic update: rebuild ledger models in-place when strengths change ---
|
| 392 |
+
try:
|
| 393 |
+
current_ledger = pipeline.model_ledger
|
| 394 |
+
# helper to compare strengths quickly
|
| 395 |
+
def _get_current_strengths(ledger_obj):
|
| 396 |
+
return tuple(float(lora.strength) for lora in getattr(ledger_obj, "loras", ()))
|
| 397 |
+
|
| 398 |
+
requested_strengths = (float(pose_lora_strength), float(general_lora_strength), float(motion_lora_strength))
|
| 399 |
+
if _get_current_strengths(current_ledger) != requested_strengths:
|
| 400 |
+
# build new tuple and replace ledger.loras
|
| 401 |
+
current_ledger.loras = build_loras_tuple(*requested_strengths)
|
| 402 |
+
# clear cached model instances so new models are constructed with the new LoRAs
|
| 403 |
+
# (ModelLedger builds models on first access using its configured `loras`)
|
| 404 |
+
try:
|
| 405 |
+
current_ledger.clear_vram()
|
| 406 |
+
except Exception:
|
| 407 |
+
# `clear_vram` should exist; if it doesn't, fall back to deleting cached attrs
|
| 408 |
+
for k in list(vars(current_ledger).keys()):
|
| 409 |
+
if k in ("_transformer", "_video_encoder", "_video_decoder", "_audio_encoder", "_audio_decoder", "_vocoder", "_spatial_upsampler", "_text_encoder", "_gemma_embeddings_processor"):
|
| 410 |
+
vars(current_ledger).pop(k, None)
|
| 411 |
+
# Now pre-load the models again (ensures they are on-device before pipeline call)
|
| 412 |
+
_ = current_ledger.transformer()
|
| 413 |
+
_ = current_ledger.video_encoder()
|
| 414 |
+
_ = current_ledger.video_decoder()
|
| 415 |
+
_ = current_ledger.audio_encoder()
|
| 416 |
+
_ = current_ledger.audio_decoder()
|
| 417 |
+
_ = current_ledger.vocoder()
|
| 418 |
+
_ = current_ledger.spatial_upsampler()
|
| 419 |
+
_ = current_ledger.text_encoder()
|
| 420 |
+
_ = current_ledger.gemma_embeddings_processor()
|
| 421 |
+
torch.cuda.empty_cache()
|
| 422 |
+
except Exception as e:
|
| 423 |
+
# if this fails, we still proceed with the existing pipeline (safer to continue than to crash)
|
| 424 |
+
print(f"[LoRA rebuild warning] Could not update LoRA strengths in-place: {e}")
|
| 425 |
+
# --- end LoRA update ---
|
| 426 |
+
|
| 427 |
frame_rate = DEFAULT_FRAME_RATE
|
| 428 |
num_frames = int(duration * frame_rate) + 1
|
| 429 |
num_frames = ((num_frames - 1 + 7) // 8) * 8 + 1
|
|
|
|
| 520 |
with gr.Row():
|
| 521 |
enhance_prompt = gr.Checkbox(label="Enhance Prompt", value=False)
|
| 522 |
high_res = gr.Checkbox(label="High Resolution", value=True)
|
| 523 |
+
pose_lora_strength = gr.Slider(label="Pose LoRA Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 524 |
+
general_lora_strength = gr.Slider(label="General LoRA Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 525 |
+
motion_lora_strength = gr.Slider(label="Motion LoRA Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01)
|
| 526 |
|
| 527 |
with gr.Column():
|
| 528 |
+
output_video = gr.Video(label="Generated Video", autoplay=False)
|
| 529 |
|
| 530 |
gr.Examples(
|
| 531 |
examples=[
|
|
|
|
| 576 |
inputs=[
|
| 577 |
first_image, last_image, input_audio, prompt, duration, enhance_prompt,
|
| 578 |
seed, randomize_seed, height, width,
|
| 579 |
+
pose_lora_strength, general_lora_strength, motion_lora_strength,
|
| 580 |
],
|
| 581 |
outputs=[output_video, seed],
|
| 582 |
)
|