Update app.py
Browse files
app.py
CHANGED
|
@@ -41,6 +41,7 @@ import spaces
|
|
| 41 |
import gradio as gr
|
| 42 |
import numpy as np
|
| 43 |
from huggingface_hub import hf_hub_download, snapshot_download
|
|
|
|
| 44 |
|
| 45 |
from ltx_core.components.diffusion_steps import EulerDiffusionStep
|
| 46 |
from ltx_core.components.noisers import GaussianNoiser
|
|
@@ -74,6 +75,8 @@ except Exception as e:
|
|
| 74 |
|
| 75 |
logging.getLogger().setLevel(logging.INFO)
|
| 76 |
|
|
|
|
|
|
|
| 77 |
MAX_SEED = np.iinfo(np.int32).max
|
| 78 |
DEFAULT_PROMPT = (
|
| 79 |
"An astronaut hatches from a fragile egg on the surface of the Moon, "
|
|
@@ -267,6 +270,11 @@ checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-22b-
|
|
| 267 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
| 268 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
| 269 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
print(f"Checkpoint: {checkpoint_path}")
|
| 271 |
print(f"Spatial upsampler: {spatial_upsampler_path}")
|
| 272 |
print(f"Gemma root: {gemma_root}")
|
|
@@ -276,7 +284,13 @@ pipeline = LTX23DistilledA2VPipeline(
|
|
| 276 |
distilled_checkpoint_path=checkpoint_path,
|
| 277 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 278 |
gemma_root=gemma_root,
|
| 279 |
-
loras=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
quantization=QuantizationPolicy.fp8_cast(),
|
| 281 |
)
|
| 282 |
|
|
@@ -284,6 +298,20 @@ pipeline = LTX23DistilledA2VPipeline(
|
|
| 284 |
print("Preloading all models (including Gemma and audio components)...")
|
| 285 |
ledger = pipeline.model_ledger
|
| 286 |
_transformer = ledger.transformer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
_video_encoder = ledger.video_encoder()
|
| 288 |
_video_decoder = ledger.video_decoder()
|
| 289 |
_audio_encoder = ledger.audio_encoder()
|
|
@@ -355,6 +383,7 @@ def generate_video(
|
|
| 355 |
input_audio,
|
| 356 |
prompt: str,
|
| 357 |
duration: float,
|
|
|
|
| 358 |
enhance_prompt: bool = True,
|
| 359 |
seed: int = 42,
|
| 360 |
randomize_seed: bool = True,
|
|
@@ -367,6 +396,8 @@ def generate_video(
|
|
| 367 |
log_memory("start")
|
| 368 |
|
| 369 |
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
|
|
|
|
|
|
|
| 370 |
|
| 371 |
frame_rate = DEFAULT_FRAME_RATE
|
| 372 |
num_frames = int(duration * frame_rate) + 1
|
|
@@ -451,6 +482,13 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
|
|
| 451 |
placeholder="Describe the motion and animation you want...",
|
| 452 |
)
|
| 453 |
duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
|
| 455 |
|
| 456 |
generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
|
|
@@ -515,7 +553,7 @@ with gr.Blocks(title="LTX-2.3 Heretic Distilled") as demo:
|
|
| 515 |
generate_btn.click(
|
| 516 |
fn=generate_video,
|
| 517 |
inputs=[
|
| 518 |
-
first_image, last_image, input_audio, prompt, duration, enhance_prompt,
|
| 519 |
seed, randomize_seed, height, width,
|
| 520 |
],
|
| 521 |
outputs=[output_video, seed],
|
|
|
|
| 41 |
import gradio as gr
|
| 42 |
import numpy as np
|
| 43 |
from huggingface_hub import hf_hub_download, snapshot_download
|
| 44 |
+
from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
|
| 45 |
|
| 46 |
from ltx_core.components.diffusion_steps import EulerDiffusionStep
|
| 47 |
from ltx_core.components.noisers import GaussianNoiser
|
|
|
|
| 75 |
|
| 76 |
logging.getLogger().setLevel(logging.INFO)
|
| 77 |
|
| 78 |
+
LORA_RUNTIME_SCALE = 1.0
|
| 79 |
+
|
| 80 |
MAX_SEED = np.iinfo(np.int32).max
|
| 81 |
DEFAULT_PROMPT = (
|
| 82 |
"An astronaut hatches from a fragile egg on the surface of the Moon, "
|
|
|
|
| 270 |
spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
| 271 |
gemma_root = snapshot_download(repo_id=GEMMA_REPO)
|
| 272 |
|
| 273 |
+
# Fetch the LoRA weights; one-line call to match the other hf_hub_download uses in this file.
lora_path = hf_hub_download(repo_id="dagloop5/LoRA", filename="LoRA2.safetensors")
|
| 277 |
+
|
| 278 |
print(f"Checkpoint: {checkpoint_path}")
|
| 279 |
print(f"Spatial upsampler: {spatial_upsampler_path}")
|
| 280 |
print(f"Gemma root: {gemma_root}")
|
|
|
|
| 284 |
distilled_checkpoint_path=checkpoint_path,
|
| 285 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 286 |
gemma_root=gemma_root,
|
| 287 |
+
loras=[
|
| 288 |
+
LoraPathStrengthAndSDOps(
|
| 289 |
+
lora_path,
|
| 290 |
+
1.0, # fixed internal strength
|
| 291 |
+
LTXV_LORA_COMFY_RENAMING_MAP
|
| 292 |
+
)
|
| 293 |
+
],
|
| 294 |
quantization=QuantizationPolicy.fp8_cast(),
|
| 295 |
)
|
| 296 |
|
|
|
|
| 298 |
print("Preloading all models (including Gemma and audio components)...")
|
| 299 |
ledger = pipeline.model_ledger
|
| 300 |
_transformer = ledger.transformer()
|
| 301 |
+
# Wrap the transformer forward pass so the LoRA influence can be rescaled at
# request time through the module-level LORA_RUNTIME_SCALE knob (set per call
# in generate_video).
_original_forward = _transformer.forward


def _lora_scaled_forward(*args, **kwargs):
    """Run the original transformer forward and scale its tensor outputs.

    NOTE(review): this multiplies the ENTIRE transformer output by
    LORA_RUNTIME_SCALE, not just the LoRA delta — the LoRA weights are
    merged into the checkpoint at load time, so the delta is not
    separable here. At any scale other than 1.0 this also attenuates the
    base model's prediction; confirm this is the intended effect before
    exposing the slider range 0.0–1.5.
    """
    out = _original_forward(*args, **kwargs)

    # Fast path: the default strength of 1.0 is a no-op, so skip the
    # tuple traversal and tensor multiply entirely on every forward call.
    if LORA_RUNTIME_SCALE == 1.0:
        return out
    if isinstance(out, tuple):
        # Scale only tensor members; pass through aux values unchanged.
        return tuple(o * LORA_RUNTIME_SCALE if torch.is_tensor(o) else o for o in out)
    if torch.is_tensor(out):
        return out * LORA_RUNTIME_SCALE
    return out


_transformer.forward = _lora_scaled_forward
|
| 315 |
_video_encoder = ledger.video_encoder()
|
| 316 |
_video_decoder = ledger.video_decoder()
|
| 317 |
_audio_encoder = ledger.audio_encoder()
|
|
|
|
| 383 |
input_audio,
|
| 384 |
prompt: str,
|
| 385 |
duration: float,
|
| 386 |
+
lora_strength: float,
|
| 387 |
enhance_prompt: bool = True,
|
| 388 |
seed: int = 42,
|
| 389 |
randomize_seed: bool = True,
|
|
|
|
| 396 |
log_memory("start")
|
| 397 |
|
| 398 |
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
|
| 399 |
+
global LORA_RUNTIME_SCALE
|
| 400 |
+
LORA_RUNTIME_SCALE = lora_strength
|
| 401 |
|
| 402 |
frame_rate = DEFAULT_FRAME_RATE
|
| 403 |
num_frames = int(duration * frame_rate) + 1
|
|
|
|
| 482 |
placeholder="Describe the motion and animation you want...",
|
| 483 |
)
|
| 484 |
duration = gr.Slider(label="Duration (seconds)", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
|
| 485 |
+
# LoRA strength control, one-line form to match the duration slider above.
lora_strength = gr.Slider(label="LoRA Strength", minimum=0.0, maximum=1.5, value=1.0, step=0.05)
|
| 492 |
|
| 493 |
|
| 494 |
generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
|
|
|
|
| 553 |
generate_btn.click(
|
| 554 |
fn=generate_video,
|
| 555 |
inputs=[
|
| 556 |
+
first_image, last_image, input_audio, prompt, duration, lora_strength, enhance_prompt,
|
| 557 |
seed, randomize_seed, height, width,
|
| 558 |
],
|
| 559 |
outputs=[output_video, seed],
|