# Author: multimodalart (HF Staff)
# Commit ef22a0b (verified): "Update app.py"
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import load_image, export_to_video
import os
import random
import numpy as np
from moviepy import ImageSequenceClip, AudioFileClip, VideoFileClip
from PIL import Image, ImageOps
# --- 1. Model Setup & Configuration ---
# Define the specific distilled sigmas (from LTX-2 documentation)
# Eight values, passed as `sigmas=` together with num_inference_steps=8 in generate().
DISTILLED_SIGMA_VALUES = [
    1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875
]
print("Loading LTX-2 Distilled Pipeline...")
# Loaded once at import time in bfloat16 and moved to CUDA; every generate()
# call reuses this module-level `pipe`. The custom community pipeline is what
# generate() calls with `image=` and `audio=` arguments.
pipe = DiffusionPipeline.from_pretrained(
    "rootonchair/LTX-2-19b-distilled",
    custom_pipeline="multimodalart/ltx2-audio-to-video",
    torch_dtype=torch.bfloat16
)
pipe.to("cuda")
print("Loading and Fusing Camera Control LoRA...")
# Load the static-camera LoRA, fuse it into the base weights at 0.8 strength,
# then unload the adapter. Per diffusers' fuse_lora -> unload_lora_weights
# workflow, the fused weight changes are expected to remain after unloading.
pipe.load_lora_weights("Lightricks/LTX-2-19b-LoRA-Camera-Control-Static", adapter_name="camera_control")
pipe.fuse_lora(lora_scale=0.8)
pipe.unload_lora_weights()
# --- 2. Helper Functions ---
def save_video_with_audio(video_frames, audio_path, fps=24):
    """
    Combine the generated video frames with the original input audio.

    Args:
        video_frames: Pipeline output in one of these forms:
            - a list of frames, or a list whose first element is itself a list
              of frames (only that first inner list is used);
            - a path (str) to an existing video file;
            - anything else, which is written out with
              ``diffusers.utils.export_to_video`` first.
        audio_path: Path of the audio to mux in. If it is longer than the
            video it is trimmed to the video's duration.
        fps: Frame rate used to build and encode the clip.

    Returns:
        Path (relative to the working directory) of the written MP4 file.
    """
    import tempfile
    import uuid

    # uuid4 rather than a small random int: concurrent requests must never
    # pick the same name and overwrite each other's result.
    output_filename = f"output_{uuid.uuid4().hex}.mp4"

    temp_path = None
    clip = None
    audio_clip = None
    final_clip = None
    try:
        # 1. Handle Diffusers output formats
        if isinstance(video_frames, list):
            if video_frames and isinstance(video_frames[0], list):
                frames_to_process = video_frames[0]
            else:
                frames_to_process = video_frames
            np_frames = [np.array(img) for img in frames_to_process]
            clip = ImageSequenceClip(np_frames, fps=fps)
        elif isinstance(video_frames, str):
            clip = VideoFileClip(video_frames)
        else:
            # Per-call temp file: a fixed shared filename would race between
            # concurrent requests. It is removed in the `finally` below.
            fd, temp_path = tempfile.mkstemp(suffix=".mp4")
            os.close(fd)
            export_to_video(video_frames, temp_path, fps=fps)
            clip = VideoFileClip(temp_path)

        # 2. Load audio, trimming it when it outlasts the video
        audio_clip = AudioFileClip(audio_path)
        if audio_clip.duration > clip.duration:
            audio_clip = audio_clip.subclipped(0, clip.duration)

        # 3. Combine and save
        final_clip = clip.with_audio(audio_clip)
        final_clip.write_videofile(
            output_filename,
            fps=fps,
            codec="libx264",
            audio_codec="aac",
            logger="bar",
        )
    finally:
        # Release the ffmpeg readers even when loading or encoding failed.
        for opened in (final_clip, audio_clip, clip):
            if opened is not None:
                opened.close()
        if temp_path is not None and os.path.exists(temp_path):
            os.remove(temp_path)
    return output_filename
def infer_aspect_ratio(image):
    """
    Pick the supported output resolution whose shape best matches ``image``.

    Closeness is measured against the actual target resolutions (the sizes
    the image is later cropped to). It uses a multiplicative distance (the
    larger ratio divided by the smaller), so a landscape image and its
    portrait mirror always land in mirrored buckets.

    Args:
        image: Object exposing ``.size`` as ``(width, height)``, e.g. a PIL image.

    Returns:
        ``(label, (width, height))``, where ``label`` is ``"1:1"``, ``"16:9"``
        or ``"9:16"``. The "16:9" and "9:16" labels map to 768x512 and
        512x768, which are really 3:2 and 2:3.

    Raises:
        ValueError: If the image width or height is not positive.
    """
    # Supported output resolutions (W, H). The labels are kept unchanged for
    # callers that log or display them.
    resolutions = {
        "1:1": (512, 512),
        "16:9": (768, 512),
        "9:16": (512, 768),
    }
    width, height = image.size
    if width <= 0 or height <= 0:
        raise ValueError(f"Invalid image size: {width}x{height}")
    image_ratio = width / height

    def _distance(label):
        # A linear |a - b| on ratios is biased toward wide buckets
        # (4:3 -> 1:1 but 3:4 -> 9:16). The max of the two quotients is
        # symmetric in log space and ranks candidates by how much gets cropped.
        target_w, target_h = resolutions[label]
        target_ratio = target_w / target_h
        return max(image_ratio / target_ratio, target_ratio / image_ratio)

    closest_ratio = min(resolutions, key=_distance)
    return closest_ratio, resolutions[closest_ratio]
def process_image_for_aspect_ratio(image):
    """
    Center-crop and resize ``image`` to its closest supported resolution.

    The target size comes from ``infer_aspect_ratio``. ``ImageOps.fit`` then
    crops around the image center and resizes with Lanczos resampling, so the
    content keeps its proportions while exactly filling the target size.

    Returns:
        ``(processed_image, width, height, aspect_ratio_label)``.
    """
    label, target_size = infer_aspect_ratio(image)
    out_w, out_h = target_size
    fitted = ImageOps.fit(
        image,
        (out_w, out_h),
        method=Image.LANCZOS,
        centering=(0.5, 0.5),
    )
    return fitted, out_w, out_h, label
def get_audio_duration(audio_path):
    """
    Suggest a video duration (in seconds) that matches an uploaded audio file.

    The audio length is rounded to the nearest 0.5 s and clamped to the
    duration slider's range [1.0, 12.0]. Even a very short clip therefore
    yields a value the slider accepts.

    Args:
        audio_path: Path to the audio file, or None when the input was cleared.

    Returns:
        ``gr.update(value=...)`` with the suggested duration, or a no-op
        ``gr.update()`` when there is no audio or it cannot be read.
    """
    # Must match the video_duration slider's minimum/maximum.
    min_duration, max_duration = 1.0, 12.0
    if audio_path is None:
        return gr.update()
    try:
        audio_clip = AudioFileClip(audio_path)
        try:
            duration = audio_clip.duration
        finally:
            # Close the reader even if reading the duration fails.
            audio_clip.close()
        rounded_duration = round(duration * 2) / 2  # nearest 0.5 s (slider step)
        clamped_duration = min(max(rounded_duration, min_duration), max_duration)
        return gr.update(value=clamped_duration)
    except Exception as e:
        # Best-effort UI helper: leave the slider unchanged instead of failing the event.
        print(f"Error getting audio duration: {e}")
        return gr.update()
# --- 3. Inference Function ---
def _compute_num_frames(video_duration, fps):
    """Convert a duration in seconds to a frame count with (num_frames - 1) % 8 == 0 (LTX-2 constraint)."""
    total_frames = int(video_duration * fps)
    # Round to the nearest block of 8 (Python round: ties go to even), plus 1.
    # Example: 4 seconds * 24 = 96 frames; 96 is divisible by 8, so 97 frames.
    num_frames = round(total_frames / 8) * 8 + 1
    # Ensure a sane minimum of one block.
    return max(num_frames, 9)


@spaces.GPU(duration=85, size='xlarge')
def generate(
    image_path,
    audio_path,
    prompt,
    negative_prompt,
    video_duration,
    seed,
    progress=gr.Progress(track_tqdm=True),
):
    """
    Generate an audio-driven video from a still image.

    Args:
        image_path: Path to the input image.
        audio_path: Path to the driving audio file.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        video_duration: Requested length in seconds, converted to a frame
            count valid for LTX-2.
        seed: Integer seed. -1, or None (an emptied seed field), picks a
            random seed.
        progress: Gradio progress tracker (mirrors tqdm bars).

    Returns:
        ``(output_video_path, seed)``: the MP4 with the input audio muxed in,
        and the seed actually used.

    Raises:
        gr.Error: If the image or the audio input is missing.
    """
    if not image_path or not audio_path:
        raise gr.Error("Please provide both an image and an audio file.")
    # Set reproducibility. An emptied gr.Number arrives as None; treat it like
    # -1 rather than crashing in manual_seed.
    if seed is None or seed == -1:
        seed = random.randint(0, 1000000)
    seed = int(seed)
    generator = torch.Generator(device="cuda").manual_seed(seed)

    # 1. Load and preprocess the image (auto-detect aspect ratio)
    original_image = load_image(image_path)
    image, width, height, detected_ratio = process_image_for_aspect_ratio(original_image)
    print(f"Generating with seed: {seed}, Resolution: {width}x{height} ({detected_ratio}), Duration: {video_duration}s")

    # 2. Calculate frames
    fps = 24.0
    num_frames = _compute_num_frames(video_duration, fps)
    print(f"Calculated frames: {num_frames}")

    # 3. Run inference
    video_output, _ = pipe(
        image=image,
        audio=audio_path,
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        num_frames=num_frames,
        frame_rate=fps,
        num_inference_steps=8,  # Distilled uses 8 steps
        sigmas=DISTILLED_SIGMA_VALUES,
        guidance_scale=1.0,
        generator=generator,
        return_dict=False,
    )

    # 4. Post-process: add the original audio
    output_video_path = save_video_with_audio(video_output, audio_path, fps=fps)
    return output_video_path, seed
# --- 4. Gradio Interface Definition ---
# NOTE(review): this file's indentation was lost; the nesting of the layout
# containers below (Row/Column/Accordion) is a reconstruction. Confirm it
# against the intended UI.
css = """
#col-container { max-width: 800px; margin: 0 auto; }
"""
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# ⚡ LTX-2 Distilled Audio-to-Video")
        gr.Markdown("Generate lip-synced or audio-reactive video from a single image using the distilled 8-step LTX-2 model.")
        with gr.Row():
            with gr.Column():
                # type="filepath": generate() receives file paths, not decoded media.
                input_image = gr.Image(label="Input Image", type="filepath", height=300)
                input_audio = gr.Audio(label="Input Audio", type="filepath")
            with gr.Column():
                result_video = gr.Video(label="Generated Video")
        prompt = gr.Textbox(
            label="Prompt",
            value="A person speaking, lips moving in sync with the words, talking head",
            lines=2
        )
        with gr.Row():
            # The range is also assumed by get_audio_duration(), which caps its
            # suggestion at 12 s and auto-fills this slider on audio upload.
            video_duration = gr.Slider(
                label="Video Duration (Seconds)",
                minimum=1.0,
                maximum=12.0,
                step=0.5,
                value=4.0,
            )
        with gr.Accordion("Advanced Settings", open=False):
            # generate() runs with guidance_scale=1.0, so this prompt likely has
            # little or no effect (as the placeholder says). Since a default
            # value is set, the placeholder only shows if the field is cleared.
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="low quality, worst quality, deformed, distorted",
                placeholder="Usually ignored by distilled models with guidance 1.0"
            )
            # -1 asks generate() to draw a random seed.
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
        run_btn = gr.Button("Generate Video", variant="primary")
        # Output info
        # Hidden field receiving the seed generate() actually used (its second return value).
        used_seed = gr.Number(label="Used Seed", visible=False)
    # Event Logic
    # Auto-update video duration when audio is uploaded
    input_audio.change(
        fn=get_audio_duration,
        inputs=[input_audio],
        outputs=[video_duration]
    )
    # The input order must match generate()'s positional parameters.
    run_btn.click(
        fn=generate,
        inputs=[
            input_image,
            input_audio,
            prompt,
            negative_prompt,
            video_duration,
            seed
        ],
        outputs=[result_video, used_seed]
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()