|
|
| import os |
| import json |
| import torch |
| import random |
|
|
| import gradio as gr |
| from glob import glob |
| from omegaconf import OmegaConf |
| from datetime import datetime |
| from safetensors import safe_open |
|
|
| from diffusers import AutoencoderKL |
| from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler |
| from diffusers.utils.import_utils import is_xformers_available |
| from transformers import CLIPTextModel, CLIPTokenizer |
|
|
| from animatediff.models.unet import UNet3DConditionModel |
| from animatediff.pipelines.pipeline_animation import AnimationPipeline |
| from animatediff.utils.util import save_videos_grid |
| from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint |
| from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora |
|
|
|
|
# Counter used by AnimateController.animate() to name saved videos
# ("<savedir>/sample/<sample_idx>.mp4").
sample_idx = 0
# Maps the UI "Sampling method" dropdown label to its diffusers scheduler class.
scheduler_dict = {
    "Euler": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
    "DDIM": DDIMScheduler,
}
|
|
# Inline CSS for the Gradio app: keeps the refresh / dice buttons compact.
# Fix: the original property was misspelled "margin-buttom" (and gave four
# values), which browsers silently ignore; "margin-bottom" takes one value.
css = """
.toolbutton {
    margin-bottom: 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""
|
|
class AnimateController:
    """Backend state for the AnimateDiff Gradio demo.

    Owns the currently selected Stable Diffusion components (tokenizer,
    text encoder, VAE, 3D UNet), any cached LoRA weights, and the on-disk
    model / output directories. The Gradio callbacks built in ui() mutate
    this shared state and finally call animate() to render a video.
    """

    def __init__(self):
        # Directory layout, all relative to the current working directory.
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        # Create the sample subdirectory up front (this also creates savedir)
        # so animate() can write into it unconditionally.
        os.makedirs(self.savedir_sample, exist_ok=True)

        # Dropdown choices, populated from disk by the refresh_* calls below.
        self.stable_diffusion_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_stable_diffusion()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # Model components; loaded lazily by the update_* callbacks.
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_model_state_dict = {}

        self.inference_config = OmegaConf.load("configs/inference/inference.yaml")

    def refresh_stable_diffusion(self):
        """Re-scan the StableDiffusion directory for model sub-folders."""
        self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))

    def refresh_motion_module(self):
        """Re-scan the Motion_Module directory for .ckpt files (basenames only)."""
        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [os.path.basename(p) for p in motion_module_list]

    def refresh_personalized_model(self):
        """Re-scan the DreamBooth_LoRA directory for .safetensors files (basenames only)."""
        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        """Load tokenizer / text encoder / VAE / 3D UNet from the selected SD folder onto CUDA."""
        self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
        self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
        self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
        return gr.Dropdown.update()

    def update_motion_module(self, motion_module_dropdown):
        """Load AnimateDiff motion-module weights into the current UNet (non-strict)."""
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        motion_module_path = os.path.join(self.motion_module_dir, motion_module_dropdown)
        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
        _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
        # Non-empty `unexpected` means the checkpoint does not match this UNet.
        # Raise explicitly instead of `assert`, which is stripped under -O.
        if unexpected:
            raise gr.Error(f"Motion module checkpoint has {len(unexpected)} unexpected keys.")
        return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        """Load a DreamBooth .safetensors checkpoint over the current VAE/UNet/text encoder."""
        if self.unet is None:
            gr.Info("Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        base_model_path = os.path.join(self.personalized_model_dir, base_model_dropdown)
        base_model_state_dict = {}
        with safe_open(base_model_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                base_model_state_dict[key] = f.get_tensor(key)

        converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
        self.vae.load_state_dict(converted_vae_checkpoint)

        converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
        self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

        # NOTE(review): the converted text encoder is left on CPU here; the
        # whole pipeline is moved to CUDA later in animate().
        self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
        return gr.Dropdown.update()

    def update_lora_model(self, lora_model_dropdown):
        """Cache LoRA weights from the selected .safetensors file ("none" clears them)."""
        self.lora_model_state_dict = {}
        # Fix: compare the raw dropdown value *before* joining it with the
        # model directory. The original compared the joined path to "none",
        # which never matched, so selecting "none" crashed in safe_open().
        if lora_model_dropdown != "none":
            lora_model_path = os.path.join(self.personalized_model_dir, lora_model_dropdown)
            with safe_open(lora_model_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    self.lora_model_state_dict[key] = f.get_tensor(key)
        return gr.Dropdown.update()

    def animate(
        self,
        stable_diffusion_dropdown,
        motion_module_dropdown,
        base_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        """Render a video from the current models and UI settings.

        Saves the sample to <savedir>/sample/<idx>.mp4, appends the
        generation settings to <savedir>/logs.json, and returns a
        gr.Video update pointing at the saved file.

        Raises:
            gr.Error: if no pretrained model, motion module, or base
                DreamBooth model has been selected.
        """
        global sample_idx

        if self.unet is None:
            raise gr.Error("Please select a pretrained model path.")
        if motion_module_dropdown == "":
            raise gr.Error("Please select a motion module.")
        if base_model_dropdown == "":
            raise gr.Error("Please select a base DreamBooth model.")

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to("cuda")

        if self.lora_model_state_dict != {}:
            pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
        pipeline.to("cuda")

        # Seed handling. The textbox yields a *string*; the original compared
        # it to the int -1, so the default "-1" was passed verbatim to
        # torch.manual_seed instead of requesting a random seed.
        if seed_textbox not in ("", None) and int(seed_textbox) != -1:
            torch.manual_seed(int(seed_textbox))
        else:
            torch.seed()
        seed = torch.initial_seed()

        sample = pipeline(
            prompt_textbox,
            negative_prompt = negative_prompt_textbox,
            num_inference_steps = sample_step_slider,
            guidance_scale = cfg_scale_slider,
            width = width_slider,
            height = height_slider,
            video_length = length_slider,
        ).videos

        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)
        # Fix: advance the global index so successive generations do not all
        # overwrite the same "0.mp4" file.
        sample_idx += 1

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": width_slider,
            "height": height_slider,
            "video_length": length_slider,
            "seed": seed,
        }
        # Append one pretty-printed JSON record per generation.
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json.dumps(sample_config, indent=4))
            f.write("\n\n")

        return gr.Video.update(value=save_sample_path)
| |
|
|
| controller = AnimateController() |
|
|
|
|
def ui():
    """Build the Gradio Blocks interface for the AnimateDiff demo.

    Wires the model-selection dropdowns and sampling controls to the shared
    module-level `controller`, and returns the (un-launched) gr.Blocks app.
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
            Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
            [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
            """
        )
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Model checkpoints (select pretrained model path first).
                """
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=controller.stable_diffusion_list,
                    interactive=True,
                )
                stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])

                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                def update_stable_diffusion():
                    # Re-scan disk and refresh the dropdown choices in place.
                    controller.refresh_stable_diffusion()
                    return gr.Dropdown.update(choices=controller.stable_diffusion_list)
                stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])

            with gr.Row():
                motion_module_dropdown = gr.Dropdown(
                    label="Select motion module",
                    choices=controller.motion_module_list,
                    interactive=True,
                )
                motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])

                motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                def update_motion_module():
                    controller.refresh_motion_module()
                    return gr.Dropdown.update(choices=controller.motion_module_list)
                motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])

                base_model_dropdown = gr.Dropdown(
                    label="Select base Dreambooth model (required)",
                    choices=controller.personalized_model_list,
                    interactive=True,
                )
                base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])

                lora_model_dropdown = gr.Dropdown(
                    label="Select LoRA model (optional)",
                    choices=["none"] + controller.personalized_model_list,
                    value="none",
                    interactive=True,
                )
                lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])

                lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)

                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                def update_personalized_model():
                    # One refresh feeds both the base-model and LoRA dropdowns.
                    controller.refresh_personalized_model()
                    return [
                        gr.Dropdown.update(choices=controller.personalized_model_list),
                        gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
                    ]
                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for AnimateDiff.
                """
            )

            prompt_textbox = gr.Textbox(label="Prompt", lines=2)
            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)

            with gr.Row().style(equal_height=False):
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)

                    width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                    height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                    length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
                    cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1)
                        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        # Fix: random.randint requires int bounds; 1e8 is a float
                        # (deprecated since Python 3.10, TypeError since 3.12).
                        seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])

                    generate_button = gr.Button(value="Generate", variant='primary')

                result_video = gr.Video(label="Generated Animation", interactive=False)

            generate_button.click(
                fn=controller.animate,
                inputs=[
                    stable_diffusion_dropdown,
                    motion_module_dropdown,
                    base_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    length_slider,
                    height_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video]
            )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Build the Blocks UI and serve it; share=True also creates a public
    # gradio.live tunnel URL in addition to the local server.
    demo = ui()
    demo.launch(share=True)
|
|