| import os |
| import json |
| import time |
| import torch |
| import random |
| import inspect |
| import argparse |
| import numpy as np |
| import pandas as pd |
| from pathlib import Path |
| from omegaconf import OmegaConf |
| from transformers import CLIPTextModel, CLIPTokenizer |
| from diffusers import AutoencoderKL, DDIMScheduler |
| from diffusers.utils.import_utils import is_xformers_available |
|
|
| from utils.unet import UNet3DConditionModel |
| from utils.pipeline_magictime import MagicTimePipeline |
| from utils.util import save_videos_grid |
| from utils.util import load_weights |
|
|
def _create_unique_savedir(args):
    """Create and return a unique output directory for this run.

    The directory is ``outputs[/<args.save_path>]/<config-stem>-<n>`` where
    ``n`` comes from a module-level counter that persists across calls to
    ``main`` within the same process, so already-used ids are not rescanned.
    Existing directories are skipped until a free id is found.
    """
    base = Path(args.config).stem
    root = os.path.join("outputs", args.save_path) if args.save_path else "outputs"
    while True:
        unique_id = globals().setdefault("counter", 0)
        globals()["counter"] = unique_id + 1
        savedir = os.path.join(root, f"{base}-{unique_id}")
        if not os.path.exists(savedir):
            break
    os.makedirs(savedir)
    return savedir


def _build_pipeline(args, model_config, inference_config):
    """Load tokenizer/text-encoder/VAE/UNet, assemble the MagicTime pipeline
    on CUDA, and apply the motion-module / adapter / text-encoder weights."""
    tokenizer = CLIPTokenizer.from_pretrained(model_config.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_config.pretrained_model_path, subfolder="text_encoder").cuda()
    vae = AutoencoderKL.from_pretrained(model_config.pretrained_model_path, subfolder="vae").cuda()
    unet = UNet3DConditionModel.from_pretrained_2d(
        model_config.pretrained_model_path, subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
    ).cuda()

    # Memory-efficient attention is opt-out via --without-xformers.
    if is_xformers_available() and not args.without_xformers:
        unet.enable_xformers_memory_efficient_attention()

    pipeline = MagicTimePipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
        scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
    ).to("cuda")

    return load_weights(
        pipeline,
        motion_module_path=model_config.get("motion_module", ""),
        dreambooth_model_path=model_config.get("dreambooth_path", ""),
        magic_adapter_s_path=model_config.get("magic_adapter_s_path", ""),
        magic_adapter_t_path=model_config.get("magic_adapter_t_path", ""),
        magic_text_encoder_path=model_config.get("magic_text_encoder_path", ""),
    ).to("cuda")


def _sample(pipeline, model_config, prompt, negative_prompt=None):
    """Run one text-to-video generation and return the videos tensor.

    ``negative_prompt`` is only forwarded when given, so branches that never
    supplied one keep their original pipeline call unchanged.
    """
    kwargs = dict(
        num_inference_steps=model_config.steps,
        guidance_scale=model_config.guidance_scale,
        width=model_config.W,
        height=model_config.H,
        video_length=model_config.L,
    )
    if negative_prompt is not None:
        kwargs["negative_prompt"] = negative_prompt
    return pipeline(prompt, **kwargs).videos


def _resolve_n_prompts(model_config, num_prompts):
    """Broadcast a single negative prompt to every prompt, else use as-is."""
    n_prompts = model_config.n_prompt
    return list(n_prompts) * num_prompts if len(n_prompts) == 1 else n_prompts


def _resolve_seeds(model_config, num_prompts):
    """Normalize config ``seed`` (int, list, or absent) to one seed per prompt;
    -1 means "draw a fresh random seed"."""
    seeds = model_config.get("seed", [-1])
    seeds = [seeds] if isinstance(seeds, int) else list(seeds)
    return seeds * num_prompts if len(seeds) == 1 else seeds


def _run_interactive(pipeline, model_config, savedir):
    """Interactive loop: read prompts from stdin until 'exit'."""
    sample_idx = 0
    while True:
        user_prompt = input("Enter your prompt (or type 'exit' to quit): ")
        if user_prompt.lower() == "exit":
            break

        random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
        torch.manual_seed(random_seed)
        print(f"current seed: {random_seed}")
        print(f"sampling {user_prompt} ...")

        sample = _sample(pipeline, model_config, user_prompt)

        prompt_for_filename = "-".join(user_prompt.replace("/", "").split(" ")[:10])
        path = f"{savedir}/sample/{sample_idx}-{random_seed}-{prompt_for_filename}.gif"
        save_videos_grid(sample, path)
        print(f"save to {path}")
        sample_idx += 1


def _run_csv(pipeline, model_config, savedir, file_path):
    """Batch mode: generate one video per CSV row (columns: name, videoid)."""
    print("run_csv")
    data = pd.read_csv(file_path)
    for _, row in data.iterrows():
        user_prompt = row['name']
        videoid = row['videoid']

        random_seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
        torch.manual_seed(random_seed)
        print(f"current seed: {random_seed}")
        print(f"sampling {user_prompt} ...")

        sample = _sample(pipeline, model_config, user_prompt)
        path = f"{savedir}/sample/{videoid}.gif"
        save_videos_grid(sample, path)
        print(f"save to {path}")


def _run_json(pipeline, model_config, file_path):
    """Batch mode: generate one video per JSON entry (caption/video_id/sen_id).

    NOTE(review): outputs go under the fixed ``MSRVTT/sample/`` directory (not
    savedir), presumably so interrupted runs can be resumed via the existing-file
    skip below — confirm this is intentional.
    """
    print("run_json")
    with open(file_path, 'r') as file:
        data = json.load(file)

    prompts = [item['caption'] for item in data]
    videoids = [item['video_id'] for item in data]
    senids = [item['sen_id'] for item in data]

    n_prompts = _resolve_n_prompts(model_config, len(prompts))
    random_seeds = _resolve_seeds(model_config, len(prompts))

    model_config.random_seed = []
    for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
        filename = f"MSRVTT/sample/{videoids[prompt_idx]}-{senids[prompt_idx]}.gif"

        # Skip already-generated clips so interrupted runs can resume.
        if os.path.exists(filename):
            print(f"File {filename} already exists, skipping...")
            continue

        if random_seed != -1:
            torch.manual_seed(random_seed)
        else:
            torch.seed()
        model_config.random_seed.append(torch.initial_seed())

        print(f"current seed: {torch.initial_seed()}")
        print(f"sampling {prompt} ...")

        # BUG FIX: n_prompt was computed and zipped but never forwarded to the
        # pipeline; pass it as the negative prompt as clearly intended.
        sample = _sample(pipeline, model_config, prompt, negative_prompt=n_prompt)
        save_videos_grid(sample, filename)
        print(f"save to {filename}")


def _run_config_prompts(pipeline, model_config, savedir):
    """Default mode: generate the prompts listed in the config, then merge
    all samples into a single grid GIF."""
    prompts = model_config.prompt
    n_prompts = _resolve_n_prompts(model_config, len(prompts))
    random_seeds = _resolve_seeds(model_config, len(prompts))

    model_config.random_seed = []
    samples = []
    for sample_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
        if random_seed != -1:
            # Seed every RNG so runs are fully reproducible in this mode.
            torch.manual_seed(random_seed)
            np.random.seed(random_seed)
            random.seed(random_seed)
        else:
            torch.seed()
        model_config.random_seed.append(torch.initial_seed())

        print(f"current seed: {torch.initial_seed()}")
        print(f"sampling {prompt} ...")

        sample = _sample(pipeline, model_config, prompt, negative_prompt=n_prompt)
        samples.append(sample)

        slug = "-".join(prompt.replace("/", "").split(" ")[:10])
        path = f"{savedir}/sample/{sample_idx}-{random_seed}-{slug}.gif"
        save_videos_grid(sample, path)
        # BUG FIX: the printed path previously omitted sample_idx and did not
        # match the file actually written.
        print(f"save to {path}")

    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/merge_all.gif", n_rows=4)


@torch.no_grad()
def main(args):
    """Entry point: build the MagicTime pipeline from the config file and run
    one of four generation modes (interactive, CSV batch, JSON batch, or the
    prompt list from the config), saving GIFs under a unique output directory.

    Args:
        args: parsed CLI namespace with ``config``, ``without_xformers``,
            ``human``, ``run_csv``, ``run_json`` and ``save_path`` attributes.
    """
    savedir = _create_unique_savedir(args)
    print(f"The results will be saved to {savedir}")

    # The config file holds two documents: model paths and inference kwargs.
    model_config = OmegaConf.load(args.config)[0]
    inference_config = OmegaConf.load(args.config)[1]

    if model_config.magic_adapter_s_path:
        print("Use MagicAdapter-S")
    if model_config.magic_adapter_t_path:
        print("Use MagicAdapter-T")
    if model_config.magic_text_encoder_path:
        print("Use Magic_Text_Encoder")

    pipeline = _build_pipeline(args, model_config, inference_config)

    if args.human:
        _run_interactive(pipeline, model_config, savedir)
    elif args.run_csv:
        _run_csv(pipeline, model_config, savedir, args.run_csv)
    elif args.run_json:
        _run_json(pipeline, model_config, args.run_json)
    else:
        _run_config_prompts(pipeline, model_config, savedir)

    # Persist the effective config (including the seeds actually used).
    OmegaConf.save(model_config, f"{savedir}/model_config.yaml")
|
|
if __name__ == "__main__":
    # CLI entry point: declare the flags, parse, and hand off to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True)
    cli.add_argument("--without-xformers", action="store_true")
    cli.add_argument("--human", action="store_true", help="Enable human mode for interactive video generation")
    cli.add_argument("--run-csv", type=str, default=None)
    cli.add_argument("--run-json", type=str, default=None)
    cli.add_argument("--save-path", type=str, default=None)
    main(cli.parse_args())
|
|