boka773 committed
Commit bbe5cb2 · verified
1 Parent(s): 61d1d42

Files changed (1):
  gradio_app.py +116 -169
gradio_app.py CHANGED
@@ -1,176 +1,123 @@
-import os
 import gradio as gr
-import spaces
 import torch
-import gc
-from huggingface_hub import snapshot_download
-
-# import argparse
-
-snapshot_download(repo_id="fffiloni/svd_keyframe_interpolation", local_dir="checkpoints")
-checkpoint_dir = "checkpoints/svd_reverse_motion_with_attnflip"
-
-from diffusers.utils import load_image, export_to_video
-from diffusers import UNetSpatioTemporalConditionModel
-from custom_diffusers.pipelines.pipeline_frame_interpolation_with_noise_injection import FrameInterpolationWithNoiseInjectionPipeline
-from custom_diffusers.schedulers.scheduling_euler_discrete import EulerDiscreteScheduler
-from attn_ctrl.attention_control import (AttentionStore,
-                                         register_temporal_self_attention_control,
-                                         register_temporal_self_attention_flip_control,
-                                         )
-
-
-pretrained_model_name_or_path = "stabilityai/stable-video-diffusion-img2vid-xt"
-noise_scheduler = EulerDiscreteScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
-
-pipe = FrameInterpolationWithNoiseInjectionPipeline.from_pretrained(
-    pretrained_model_name_or_path,
-    scheduler=noise_scheduler,
-    variant="fp16",
-    torch_dtype=torch.float16,
-)
-ref_unet = pipe.ori_unet
-
-state_dict = pipe.unet.state_dict()
-# computing delta w
-finetuned_unet = UNetSpatioTemporalConditionModel.from_pretrained(
-    checkpoint_dir,
-    subfolder="unet",
-    torch_dtype=torch.float16,
-)
-assert finetuned_unet.config.num_frames==14
-ori_unet = UNetSpatioTemporalConditionModel.from_pretrained(
-    "stabilityai/stable-video-diffusion-img2vid",
-    subfolder="unet",
-    variant='fp16',
-    torch_dtype=torch.float16,
-)
-
-finetuned_state_dict = finetuned_unet.state_dict()
-ori_state_dict = ori_unet.state_dict()
-for name, param in finetuned_state_dict.items():
-    if 'temporal_transformer_blocks.0.attn1.to_v' in name or "temporal_transformer_blocks.0.attn1.to_out.0" in name:
-        delta_w = param - ori_state_dict[name]
-        state_dict[name] = state_dict[name] + delta_w
-pipe.unet.load_state_dict(state_dict)
-
-controller_ref= AttentionStore()
-register_temporal_self_attention_control(ref_unet, controller_ref)
-
-controller = AttentionStore()
-register_temporal_self_attention_flip_control(pipe.unet, controller, controller_ref)
-
-device = "cuda"
-pipe = pipe.to(device)
-
-def check_outputs_folder(folder_path):
-    # Check if the folder exists
-    if os.path.exists(folder_path) and os.path.isdir(folder_path):
-        # Delete all contents inside the folder
-        for filename in os.listdir(folder_path):
-            file_path = os.path.join(folder_path, filename)
-            try:
-                if os.path.isfile(file_path) or os.path.islink(file_path):
-                    os.unlink(file_path) # Remove file or link
-                elif os.path.isdir(file_path):
-                    shutil.rmtree(file_path) # Remove directory
-            except Exception as e:
-                print(f'Failed to delete {file_path}. Reason: {e}')
-    else:
-        print(f'The folder {folder_path} does not exist.')
-
-# Custom CUDA memory management function
-def cuda_memory_cleanup():
-    torch.cuda.empty_cache()
-    torch.cuda.ipc_collect()
-    gc.collect()
-
-@spaces.GPU(duration=90)
-def infer(frame1_path, frame2_path, progress=gr.Progress(track_tqdm=True)):
-
-    seed = 42
-    num_inference_steps = 10
-    noise_injection_steps = 0
-    noise_injection_ratio = 0.5
-    weighted_average = False
-
-    generator = torch.Generator(device)
-    if seed is not None:
-        generator = generator.manual_seed(seed)
-
-
-    frame1 = load_image(frame1_path)
-    frame1 = frame1.resize((512, 288))
-
-    frame2 = load_image(frame2_path)
-    frame2 = frame2.resize((512, 288))
-
-    cuda_memory_cleanup()
-
-    frames = pipe(image1=frame1, image2=frame2,
-                  num_inference_steps=num_inference_steps, # 50
                   generator=generator,
-                  weighted_average=weighted_average, # True
-                  noise_injection_steps=noise_injection_steps, # 0
-                  noise_injection_ratio= noise_injection_ratio, # 0.5
-                  decode_chunk_size=18
                   ).frames[0]

-    # cuda_memory_cleanup()
-
-    print(f"FRAMES: {frames}")
-
-    out_dir = "result"
-
-    check_outputs_folder(out_dir)
-    os.makedirs(out_dir, exist_ok=True)
-    out_path = "result/video_result.mp4"
-
-
-    if out_path.endswith('.gif'):
-        frames[0].save(out_path, save_all=True, append_images=frames[1:], duration=142, loop=0)
-    else:
-        export_to_video(frames, out_path, fps=7)
-
-    return out_path
-
-with gr.Blocks() as demo:
-
-    with gr.Column():
-        gr.Markdown("# Keyframe Interpolation with Stable Video Diffusion")
-        gr.Markdown("## Generative Inbetweening: Adapting Image-to-Video Models for Keyframe Interpolation")
-        gr.HTML("""
-        <div style="display:flex;column-gap:4px;">
-            <a href='https://svd-keyframe-interpolation.github.io/'>
-                <img src='https://img.shields.io/badge/Project-Page-Green'>
-            </a>
-            <a href='https://arxiv.org/abs/2408.15239'>
-                <img src='https://img.shields.io/badge/Paper-Arxiv-red'>
-            </a>
-        </div>
-        """)
-        with gr.Row():
-            with gr.Column():
-                image_input1 = gr.Image(label="FRAME 1", type="filepath")
-                image_input2 = gr.Image(label="FRAME 2", type="filepath")
-                submit_btn = gr.Button("Submit")
-            with gr.Column():
-                output = gr.Video(label="Interpolated result")
-                gr.Examples(
-                    examples = [
-                        ["examples/example_001/frame1.png", "examples/example_001/frame2.png"],
-                        ["examples/example_002/frame1.png", "examples/example_002/frame2.png"],
-                        ["examples/example_003/frame1.png", "examples/example_003/frame2.png"],
-                        ["examples/example_004/frame1.png", "examples/example_004/frame2.png"]
-                    ],
-                    inputs = [image_input1, image_input2]
-                )
-
-    submit_btn.click(
-        fn = infer,
-        inputs = [image_input1, image_input2],
-        outputs = [output],
-        show_api = False
     )

-demo.queue().launch(show_api=False, show_error=True)
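For context on the removed setup: it loaded a finetuned reverse-motion UNet, computed per-parameter deltas against the stock SVD UNet (delta_w = finetuned - original), and added those deltas onto the pipeline UNet's first temporal self-attention `to_v` and `to_out.0` projections before registering the attention-flip controllers. That delta-patching step, isolated from the pipeline specifics, looks roughly like the sketch below; `apply_weight_delta` and the dummy state dicts are illustrative, not part of the original file.

```python
import torch

def apply_weight_delta(base_sd, finetuned_sd, original_sd, key_filter):
    # Copy the base weights, then add the finetuned-minus-original delta
    # for every parameter whose name passes key_filter.
    patched = dict(base_sd)
    for name, finetuned_param in finetuned_sd.items():
        if key_filter(name):
            delta_w = finetuned_param - original_sd[name]
            patched[name] = patched[name] + delta_w
    return patched

# Tiny self-contained check with dummy one-parameter "state dicts":
base = {"attn1.to_v.weight": torch.ones(2, 2)}
orig = {"attn1.to_v.weight": torch.zeros(2, 2)}
tuned = {"attn1.to_v.weight": torch.full((2, 2), 0.5)}
out = apply_weight_delta(base, tuned, orig, lambda n: "attn1.to_v" in n)
assert torch.allclose(out["attn1.to_v.weight"], torch.full((2, 2), 1.5))
```

The replacement gradio_app.py, which drops this machinery in favor of the stock StableVideoDiffusionPipeline, follows.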
 
 
1
  import gradio as gr
 
2
  import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ from diffusers import StableVideoDiffusionPipeline
6
+ from diffusers.utils import export_to_video
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ pipe = None
11
+
12
+
13
+ # ----------------------------
14
+ # Load model (lazy loading)
15
+ # ----------------------------
16
+ def load_model():
17
+ global pipe
18
+ if pipe is None:
19
+ pipe = StableVideoDiffusionPipeline.from_pretrained(
20
+ "stabilityai/stable-video-diffusion-img2vid-xt",
21
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
22
+ )
23
+
24
+ if device == "cuda":
25
+ pipe.to(device)
26
+ pipe.enable_model_cpu_offload()
27
+ pipe.unet.enable_forward_chunking()
28
+
29
+ pipe.enable_attention_slicing()
30
+
31
+ return pipe
32
+
33
+
34
+ # ----------------------------
35
+ # Resize helper
36
+ # ----------------------------
37
+ def resize_image(image, size=(576, 1024)):
38
+ return image.resize(size)
39
+
40
+
41
+ # ----------------------------
42
+ # Interpolation function
43
+ # ----------------------------
44
+ def generate_video(
45
+ start_image,
46
+ end_image,
47
+ num_frames,
48
+ fps,
49
+ motion_bucket_id,
50
+ seed,
51
+ ):
52
+ if start_image is None or end_image is None:
53
+ return None, "Please upload both start and end images."
54
+
55
+ pipe = load_model()
56
+
57
+ generator = torch.manual_seed(int(seed))
58
+
59
+ start = resize_image(start_image)
60
+ end = resize_image(end_image)
61
+
62
+ # simple blending (basic interpolation conditioning)
63
+ blend = Image.blend(start, end, alpha=0.5)
64
+
65
+ frames = pipe(
66
+ blend,
67
+ num_frames=int(num_frames),
68
+ motion_bucket_id=int(motion_bucket_id),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  generator=generator,
70
+ decode_chunk_size=1, # low VRAM
 
 
 
71
  ).frames[0]
72
 
73
+ video_path = export_to_video(frames, fps=int(fps))
74
+
75
+ return video_path, " Done!"
76
+
77
+
78
+ # ----------------------------
79
+ # UI
80
+ # ----------------------------
81
+ with gr.Blocks(title="SVD Keyframe Interpolation") as demo:
82
+
83
+ gr.Markdown(
84
+ """
85
+ # 🎥 SVD Keyframe Interpolation
86
+ Generate smooth video between two images using Stable Video Diffusion.
87
+
88
+ Upload a start and end frame → generate motion between them.
89
+ """
90
+ )
91
+
92
+ with gr.Row():
93
+ start_image = gr.Image(label="Start Image", type="pil")
94
+ end_image = gr.Image(label="End Image", type="pil")
95
+
96
+ with gr.Row():
97
+ num_frames = gr.Slider(8, 32, value=16, step=1, label="Number of Frames")
98
+ fps = gr.Slider(4, 24, value=8, step=1, label="FPS")
99
+
100
+ with gr.Row():
101
+ motion_bucket_id = gr.Slider(1, 255, value=127, step=1, label="Motion Strength")
102
+ seed = gr.Number(value=42, label="Seed")
103
+
104
+ run_btn = gr.Button("🚀 Generate Video")
105
+
106
+ with gr.Row():
107
+ output_video = gr.Video(label="Output Video")
108
+ status = gr.Textbox(label="Status")
109
+
110
+ run_btn.click(
111
+ fn=generate_video,
112
+ inputs=[
113
+ start_image,
114
+ end_image,
115
+ num_frames,
116
+ fps,
117
+ motion_bucket_id,
118
+ seed,
119
+ ],
120
+ outputs=[output_video, status],
 
 
 
 
 
 
 
121
  )
122
 
123
+ demo.queue().launch()