import copy
import os
import random
import shutil
import time
from glob import glob

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKL
from diffusers import DDIMScheduler
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from safetensors import safe_open
from transformers import CLIPTextModel, CLIPTokenizer

from utils.pipeline_magictime import MagicTimePipeline
from utils.unet import UNet3DConditionModel
from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model
|
|
# Root directory that holds every downloaded checkpoint.
model_path = "ckpts"

# The snapshot is considered present only when the root and both auxiliary
# model directories exist; otherwise the whole repository is fetched again.
_required_dirs = (
    model_path,
    f"{model_path}/model_real_esran",
    f"{model_path}/model_rife",
)
if all(os.path.exists(directory) for directory in _required_dirs):
    print(f"Model already exists in {model_path}, skipping download.")
else:
    print("Model not found, downloading from Hugging Face...")
    snapshot_download(repo_id="BestWishYsh/MagicTime", local_dir=model_path)

# Locations of the base model, the inference config and the Magic adapters.
_magic_weights_dir = f"{model_path}/Magic_Weights"
pretrained_model_path = f"{model_path}/Base_Model/stable-diffusion-v1-5"
inference_config_path = "sample_configs/RealisticVision.yaml"
magic_adapter_s_path = f"{_magic_weights_dir}/magic_adapter_s/magic_adapter_s.ckpt"
magic_adapter_t_path = f"{_magic_weights_dir}/magic_adapter_t"
magic_text_encoder_path = f"{_magic_weights_dir}/magic_text_encoder"
|
|
# Custom stylesheet passed to gr.Blocks(css=...). It sizes components tagged
# with the "toolbutton" class (the seed-randomizer button in ui()).
# NOTE(review): "margin-buttom" is not a valid CSS property name; it looks like
# a typo for "margin-bottom" or "margin", so browsers presumably ignore that
# declaration. Confirm the intended layout before changing it.
css = """
.toolbutton {
    margin-buttom: 0em 0em 0em 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
|
|
# Example rows for gr.Examples. Each row supplies one value per component in the
# `inputs` list built in ui(), in this order:
#   DreamBooth checkpoint file, motion module file, prompt, negative prompt,
#   width, height, seed (kept as a string because the seed field is a Textbox).
examples = [

    [
        "RealisticVisionV60B1_v51VAE.safetensors",
        "motion_module.ckpt",
        "Cherry blossoms transitioning from tightly closed buds to a peak state of bloom. The progression moves through stages of bud swelling, petal exposure, and gradual opening, culminating in a full and vibrant display of open blossoms.",
        "worst quality, low quality, letterboxed",
        512, 512, "1534851746"
    ],

    [
        "RcnzCartoon.safetensors",
        "motion_module.ckpt",
        "Time-lapse of a simple modern house's construction in a Minecraft virtual environment: beginning with an avatar laying a white foundation, progressing through wall erection and interior furnishing, to adding roof and exterior details, and completed with landscaping and a tall chimney.",
        "worst quality, low quality, letterboxed",
        512, 512, "3480796026"
    ],

    [
        "ToonYou_beta6.safetensors",
        "motion_module.ckpt",
        "Bean sprouts grow and mature from seeds.",
        "worst quality, low quality, letterboxed",
        512, 512, "1496541313"
    ]
]
|
|
| |
# Drop example outputs cached by a previous run so gr.Examples regenerates
# them. shutil.rmtree avoids spawning a shell and also works where `rm` is not
# available; ignore_errors=True keeps the `rm -rf` behaviour of not failing
# when the directory does not exist.
print("### Cleaning cached examples ...")
shutil.rmtree("gradio_cached_examples", ignore_errors=True)


# Torch device used for every model and for the sampling generator.
device = "cuda"
|
|
def random_seed():
    """Return a random integer seed drawn uniformly from [1, 10**16]."""
    upper_bound = 10 ** 16
    return random.randint(1, upper_bound)
|
|
class MagicTimeController:
    """Hold the MagicTime model components and run text-to-video generation.

    Two sets of modules are kept:

    * ``self.vae``, ``self.unet`` and ``self.text_encoder`` are the live modules
      handed to the pipeline.
    * ``self.unet_model`` and ``self.text_model`` are templates that are never
      moved to ``device``. ``update_dreambooth`` deep-copies them and applies the
      DreamBooth weights and the Magic adapters to the copies.
    """

    def __init__(self):
        """Discover the checkpoints on disk and load the default models.

        Raises:
            FileNotFoundError: If no motion module (``*.ckpt``) or no DreamBooth
                checkpoint (``*.safetensors``) is found under ``model_path``.
        """
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, model_path, "Base_Model")
        self.motion_module_dir = os.path.join(self.basedir, model_path, "Base_Model", "motion_module")
        self.personalized_model_dir = os.path.join(self.basedir, model_path, "DreamBooth")
        self.savedir = os.path.join(self.basedir, "outputs")
        os.makedirs(self.savedir, exist_ok=True)

        self.dreambooth_list = []
        self.motion_module_list = []

        self.selected_dreambooth = None
        self.selected_motion_module = None

        self.refresh_motion_module()
        self.refresh_personalized_model()

        # Fail before the expensive model loads, and with a message that names
        # the directory, instead of an IndexError on the `[0]` lookups below.
        if not self.motion_module_list:
            raise FileNotFoundError(
                f"No motion module (*.ckpt) found in {self.motion_module_dir}"
            )
        if not self.dreambooth_list:
            raise FileNotFoundError(
                f"No DreamBooth checkpoint (*.safetensors) found in {self.personalized_model_dir}"
            )

        self.inference_config = OmegaConf.load(inference_config_path)[1]

        # Live modules, placed on `device`.
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
        self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
        self.unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_model_path,
            subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs),
        ).to(device)

        # Templates that update_dreambooth() deep-copies; not moved to `device`.
        self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
        self.unet_model = UNet3DConditionModel.from_pretrained_2d(
            pretrained_model_path,
            subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs),
        )

        self.update_motion_module(self.motion_module_list[0])
        self.update_motion_module_2(self.motion_module_list[0])
        self.update_dreambooth(self.dreambooth_list[0])

    def refresh_motion_module(self):
        """Rescan ``self.motion_module_dir`` for ``*.ckpt`` files (base names only)."""
        motion_module_paths = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [os.path.basename(p) for p in motion_module_paths]

    def refresh_personalized_model(self):
        """Rescan ``self.personalized_model_dir`` for ``*.safetensors`` files (base names only)."""
        dreambooth_paths = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
        self.dreambooth_list = [os.path.basename(p) for p in dreambooth_paths]

    def update_dreambooth(self, dreambooth_dropdown, motion_module_dropdown=None):
        """Apply a DreamBooth checkpoint and re-attach the Magic adapters.

        The VAE weights are loaded in place. The UNet and the text encoder are
        rebuilt from deep copies of the templates, after which the spatial LoRA
        adapter and the Swift adapters are applied.

        Args:
            dreambooth_dropdown: File name of a ``.safetensors`` checkpoint inside
                ``self.personalized_model_dir``.
            motion_module_dropdown: Not used by this method; kept so existing
                call sites that pass it continue to work.

        Returns:
            An empty ``gr.Dropdown()``.
        """
        self.selected_dreambooth = dreambooth_dropdown

        checkpoint_path = os.path.join(self.personalized_model_dir, dreambooth_dropdown)
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            dreambooth_state_dict = {key: f.get_tensor(key) for key in f.keys()}

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config)
        self.vae.load_state_dict(converted_vae_checkpoint)

        # Release the previous UNet before building the replacement.
        self.unet = None
        torch.cuda.empty_cache()
        time.sleep(1)
        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.unet_model.config)
        self.unet = copy.deepcopy(self.unet_model)
        self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        # Same for the text encoder.
        self.text_encoder = None
        torch.cuda.empty_cache()
        time.sleep(1)
        text_model = copy.deepcopy(self.text_model)
        self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)

        from swift import Swift
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu")
        self.unet = load_diffusers_lora_unet(self.unet, magic_adapter_s_state_dict, alpha=1.0)
        self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path)
        self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path)

        return gr.Dropdown()

    def _load_motion_module_into(self, model, motion_module_name):
        """Load the named motion-module checkpoint into ``model`` (non-strict).

        Raises:
            RuntimeError: If the checkpoint contains keys ``model`` does not have.
        """
        checkpoint_path = os.path.join(self.motion_module_dir, motion_module_name)
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        motion_module_state_dict = torch.load(checkpoint_path, map_location="cpu")
        _, unexpected = model.load_state_dict(motion_module_state_dict, strict=False)
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        if len(unexpected) != 0:
            raise RuntimeError(
                f"Motion module {motion_module_name!r} has unexpected keys: {list(unexpected)}"
            )

    def update_motion_module(self, motion_module_dropdown):
        """Load a motion module into the live ``self.unet`` and record the selection.

        Args:
            motion_module_dropdown: File name of a ``.ckpt`` inside
                ``self.motion_module_dir``.

        Returns:
            An empty ``gr.Dropdown()``.
        """
        self._load_motion_module_into(self.unet, motion_module_dropdown)
        self.selected_motion_module = motion_module_dropdown
        return gr.Dropdown()

    def update_motion_module_2(self, motion_module_dropdown):
        """Load a motion module into the ``self.unet_model`` template and record the selection.

        Args:
            motion_module_dropdown: File name of a ``.ckpt`` inside
                ``self.motion_module_dir``.

        Returns:
            An empty ``gr.Dropdown()``.
        """
        self._load_motion_module_into(self.unet_model, motion_module_dropdown)
        self.selected_motion_module = motion_module_dropdown
        return gr.Dropdown()

    @spaces.GPU(duration=300)
    def magictime(
        self,
        dreambooth_dropdown,
        motion_module_dropdown,
        prompt_textbox,
        negative_prompt_textbox,
        width_slider,
        height_slider,
        seed_textbox,
    ):
        """Generate one video for the given prompt and save it as ``sample.mp4``.

        Args:
            dreambooth_dropdown: DreamBooth checkpoint file name to use.
            motion_module_dropdown: Motion module file name to use.
            prompt_textbox: Positive prompt.
            negative_prompt_textbox: Negative prompt.
            width_slider: Output width passed to the pipeline.
            height_slider: Output height passed to the pipeline.
            seed_textbox: Seed as text or number. Values that are not a positive
                integer (including blank or non-numeric text) select a random seed.

        Returns:
            A ``(gr.Video, gr.Json)`` pair: the saved video and the settings
            (including the seed actually used) that produced it.
        """
        torch.cuda.empty_cache()
        time.sleep(1)

        if self.selected_motion_module != motion_module_dropdown:
            # Both the live UNet and the template must receive the new weights.
            # update_motion_module() records the selection, so a second
            # "has it changed?" check would always be false, the template would
            # keep the old motion module, and the next DreamBooth switch (which
            # deep-copies the template) would silently revert it.
            self.update_motion_module(motion_module_dropdown)
            self.update_motion_module_2(motion_module_dropdown)
        if self.selected_dreambooth != dreambooth_dropdown:
            self.update_dreambooth(dreambooth_dropdown)

        while self.text_encoder is None or self.unet is None:
            self.update_dreambooth(dreambooth_dropdown, motion_module_dropdown)

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = MagicTimePipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to(device)

        # A blank or non-numeric seed field means "pick one for me" rather than
        # an unhandled ValueError surfacing in the UI.
        try:
            requested_seed = int(seed_textbox)
        except (TypeError, ValueError):
            requested_seed = -1
        seed = requested_seed if requested_seed > 0 else random_seed()
        torch.manual_seed(int(seed))

        assert seed == torch.initial_seed()
        print(f"### seed: {seed}")

        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        sample = pipeline(
            prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            num_inference_steps=25,
            guidance_scale=8.,
            width=width_slider,
            height=height_slider,
            video_length=16,
            generator=generator,
        ).videos

        save_sample_path = os.path.join(self.savedir, "sample.mp4")
        save_videos_grid(sample, save_sample_path)

        json_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "width": width_slider,
            "height": height_slider,
            "seed": seed,
            "dreambooth": dreambooth_dropdown,
        }

        torch.cuda.empty_cache()
        time.sleep(1)
        return gr.Video(value=save_sample_path), gr.Json(value=json_config)
|
|
# Module-level controller used by ui(). Constructing it loads the tokenizer,
# text encoder, VAE and UNet, so that work happens once when this module runs.
controller = MagicTimeController()
|
|
def ui():
    """Build the MagicTime Gradio interface and return it without launching.

    The left column holds the model selectors, the prompts, the advanced
    size/seed controls and the Generate button; the right column shows the
    generated video and the settings used. Cached examples are listed below.

    Returns:
        The assembled ``gr.Blocks`` instance.
    """
    model_choices = controller.dreambooth_list
    motion_choices = controller.motion_module_list

    with gr.Blocks(css=css) as demo:
        # Header: logo, title and project links.
        gr.HTML("""
            <div style='display: flex; align-items: center; justify-content: center; text-align: center;'>
                <img src='https://www.pnglog.com/48rWnj.png' style='width: 300px; height: auto; margin-right: 10px;' />
            </div>
            """)
        gr.Markdown(
            """
            <h2 align="center"> <a href="https://github.com/PKU-YuanGroup/MagicTime">[TPAMI 2025] MagicTime: Time-lapse Video Generation Models as Metamorphic Simulators</a></h2>
            <h5 style="text-align:left;">If you like our project, please give us a star ⭐ on GitHub for the latest update.</h5>

            [GitHub](https://github.com/PKU-YuanGroup/MagicTime) | [arXiv](https://arxiv.org/abs/2404.05014) | [Home Page](https://pku-yuangroup.github.io/MagicTime/) | [Dataset](https://huggingface.co/datasets/BestWishYsh/ChronoMagic)
            """
        )

        with gr.Row():
            # Left column: every control that feeds the generation call.
            with gr.Column():
                model_selector = gr.Dropdown(
                    label="DreamBooth Model",
                    choices=model_choices,
                    value=model_choices[0],
                    interactive=True,
                )
                motion_selector = gr.Dropdown(
                    label="Motion Module",
                    choices=motion_choices,
                    value=motion_choices[0],
                    interactive=True,
                )

                prompt_box = gr.Textbox(label="Prompt", lines=3)
                negative_box = gr.Textbox(
                    label="Negative Prompt",
                    lines=3,
                    value="worst quality, low quality, nsfw, logo",
                )

                with gr.Accordion("Advance", open=False):
                    with gr.Row():
                        width_input = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                        height_input = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                    with gr.Row():
                        seed_input = gr.Textbox(label="Seed (-1 means random)", value="-1")
                        dice_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        # The dice button writes a fresh random seed into the seed field.
                        dice_button.click(fn=random_seed, inputs=[], outputs=[seed_input])

                run_button = gr.Button(value="Generate", variant='primary')

            # Right column: the outputs.
            with gr.Column():
                video_output = gr.Video(label="Generated Animation", interactive=False)
                config_output = gr.Json(label="Config", value={})

            # Component order mirrors the parameter order of controller.magictime.
            generation_inputs = [
                model_selector,
                motion_selector,
                prompt_box,
                negative_box,
                width_input,
                height_input,
                seed_input,
            ]
            generation_outputs = [video_output, config_output]

            run_button.click(
                fn=controller.magictime,
                inputs=generation_inputs,
                outputs=generation_outputs,
            )

        gr.Markdown("""
            <h5 style="text-align:left;">⚠ Warning: Even if you use the same seed and prompt, changing machines may produce different results.
            If you find a better seed and prompt, please submit an issue on GitHub.</h5>
            """)

        gr.Examples(
            fn=controller.magictime,
            examples=examples,
            inputs=generation_inputs,
            outputs=generation_outputs,
            cache_examples=True,
        )

    return demo
|
|
if __name__ == "__main__":
    # Script entry point: build the interface, turn on request queuing with
    # max_size=20, then start the Gradio server.
    demo = ui()
    demo.queue(max_size=20)
    demo.launch()