boka773's picture
app.py
bbe5cb2 verified
raw
history blame
3.1 kB
import gradio as gr
import torch
import numpy as np
from PIL import Image
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video
# Pick the GPU when PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Module-wide pipeline cache; stays None until load_model() fills it in.
pipe = None
# ----------------------------
# Load model (lazy loading)
# ----------------------------
def load_model():
    """Return the shared Stable Video Diffusion pipeline, loading it on first call.

    The pipeline is cached in the module-level ``pipe`` global, so the weights
    are only loaded once per process. On CUDA the weights are loaded in
    float16 and the pipeline is configured for low-VRAM use (model CPU
    offload, UNet forward chunking, attention slicing). On CPU, float32 is
    used and no memory optimizations are applied.

    Returns:
        The cached ``StableVideoDiffusionPipeline`` instance.
    """
    global pipe
    if pipe is not None:
        return pipe

    use_cuda = device == "cuda"
    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid-xt",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    )
    if use_cuda:
        # Model CPU offload keeps the weights on the CPU and moves each
        # sub-model to the GPU only while it runs. Do NOT call
        # ``pipeline.to("cuda")`` beforehand: that copies every weight to the
        # GPU up front (the very peak offloading is meant to avoid), and the
        # diffusers docs advise against it.
        pipeline.enable_model_cpu_offload()
        # Feed-forward chunking in the UNet lowers peak activation memory.
        pipeline.unet.enable_forward_chunking()
        pipeline.enable_attention_slicing()

    # Publish to the cache only once fully configured, so a failure during
    # setup does not leave a half-configured pipeline cached for later calls.
    pipe = pipeline
    return pipe
# ----------------------------
# Resize helper
# ----------------------------
def resize_image(image, size=(576, 1024)):
    """Return *image* rescaled to exactly ``size`` pixels.

    ``size`` follows PIL's ``(width, height)`` order, so the default produces
    a 576-wide by 1024-tall (portrait) image. The aspect ratio is not
    preserved; the image is stretched to fit.

    NOTE(review): Stable Video Diffusion's pipeline defaults are width=1024,
    height=576, so callers feeding SVD should pass ``size=(1024, 576)``.
    """
    resized = image.resize(size)
    return resized
# ----------------------------
# Interpolation function
# ----------------------------
def generate_video(
    start_image,
    end_image,
    num_frames,
    fps,
    motion_bucket_id,
    seed,
):
    """Render a clip from two keyframes with Stable Video Diffusion.

    Stable Video Diffusion img2vid conditions on a single image, so this is
    not true keyframe interpolation. Both keyframes are converted to RGB,
    resized to 1024x576 (the pipeline's default width/height), and
    alpha-blended 50/50. The blend is used as the only conditioning frame,
    so the output is not guaranteed to start on ``start_image`` or end on
    ``end_image``.

    Args:
        start_image: First keyframe as a PIL image, or None if not provided.
        end_image: Last keyframe as a PIL image, or None if not provided.
        num_frames: Number of frames to generate (converted with ``int``).
        fps: Frame rate of the exported video file (converted with ``int``).
        motion_bucket_id: SVD motion conditioning value (converted with ``int``).
        seed: Random seed (converted with ``int``); None is rejected.

    Returns:
        A ``(video_path, status_message)`` tuple. ``video_path`` is None when
        the inputs are incomplete.
    """
    if start_image is None or end_image is None:
        return None, "Please upload both start and end images."
    if seed is None:
        # A blank seed field may arrive as None; int(None) would raise.
        return None, "Please enter a seed."

    pipeline = load_model()

    # PIL sizes are (width, height). Resizing straight to the pipeline's
    # 1024x576 means the pipeline does not have to re-stretch the blend.
    width, height = 1024, 576
    # Image.blend requires both inputs to share mode and size; uploads may
    # carry alpha (RGBA) or be palette/greyscale, so normalise to RGB first.
    start = resize_image(start_image.convert("RGB"), size=(width, height))
    end = resize_image(end_image.convert("RGB"), size=(width, height))
    blend = Image.blend(start, end, alpha=0.5)

    # Dedicated generator: reproducible per request without reseeding torch's
    # process-global RNG (which torch.manual_seed would do).
    generator = torch.Generator(device="cpu").manual_seed(int(seed))

    frames = pipeline(
        blend,
        height=height,
        width=width,
        num_frames=int(num_frames),
        motion_bucket_id=int(motion_bucket_id),
        generator=generator,
        decode_chunk_size=1,  # decode one frame at a time to keep VRAM low
    ).frames[0]

    video_path = export_to_video(frames, fps=int(fps))
    return video_path, "✅ Done!"
# ----------------------------
# UI
# ----------------------------
# Build the Gradio UI: two keyframe uploads, generation controls, and the
# output video/status, all wired to generate_video.
with gr.Blocks(title="SVD Keyframe Interpolation") as demo:
    gr.Markdown(
        """
# 🎥 SVD Keyframe Interpolation
Generate smooth video between two images using Stable Video Diffusion.
Upload a start and end frame → generate motion between them.
"""
    )
    with gr.Row():
        start_image = gr.Image(label="Start Image", type="pil")
        end_image = gr.Image(label="End Image", type="pil")
    with gr.Row():
        num_frames = gr.Slider(8, 32, value=16, step=1, label="Number of Frames")
        fps = gr.Slider(4, 24, value=8, step=1, label="FPS")
    with gr.Row():
        motion_bucket_id = gr.Slider(1, 255, value=127, step=1, label="Motion Strength")
        # precision=0 rounds the field to an integer, matching int(seed) downstream.
        seed = gr.Number(value=42, precision=0, label="Seed")
    run_btn = gr.Button("🚀 Generate Video")
    with gr.Row():
        output_video = gr.Video(label="Output Video")
        status = gr.Textbox(label="Status")
    # Input order must match generate_video's positional parameters.
    run_btn.click(
        fn=generate_video,
        inputs=[
            start_image,
            end_image,
            num_frames,
            fps,
            motion_bucket_id,
            seed,
        ],
        outputs=[output_video, status],
    )
demo.queue().launch()