import os
from pathlib import Path
import gc

import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from omegaconf import OmegaConf, DictConfig
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
import gradio as gr

from musepose.models.pose_guider import PoseGuider
from musepose.models.unet_2d_condition import UNet2DConditionModel
from musepose.models.unet_3d import UNet3DConditionModel
from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from musepose.utils.util import get_fps, read_frames, save_videos_grid
from downloading_weights import download_models

import spaces

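# Thin wrapper around the MusePose pose-to-video pipeline: resolves checkpoint
# paths, lazily builds the models on the GPU, and animates a reference image
# with a driving pose video.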
class MusePoseInference:
    def __init__(self,
                 model_dir,
                 output_dir):
        self.image_gen_model_paths = {
            "pretrained_base_model": os.path.join(model_dir, "sd-image-variations-diffusers"),
            "pretrained_vae": os.path.join(model_dir, "sd-vae-ft-mse"),
            "image_encoder": os.path.join(model_dir, "image_encoder"),
        }
        self.musepose_model_paths = {
            "denoising_unet": os.path.join(model_dir, "MusePose", "denoising_unet.pth"),
            "reference_unet": os.path.join(model_dir, "MusePose", "reference_unet.pth"),
            "pose_guider": os.path.join(model_dir, "MusePose", "pose_guider.pth"),
            "motion_module": os.path.join(model_dir, "MusePose", "motion_module.pth"),
        }
        self.inference_config_path = os.path.join("configs", "inference_v2.yaml")
        self.vae = None
        self.reference_unet = None
        self.denoising_unet = None
        self.pose_guider = None
        self.image_enc = None
        self.pipe = None
        self.model_dir = model_dir
        self.output_dir = os.path.join(output_dir, "musepose_inference")
        os.makedirs(self.output_dir, exist_ok=True)

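    # Single-letter arguments mirror the MusePose CLI options: W/H are the output
    # width and height, L is the number of pose frames to render, S is the
    # sliding-window (slice) frame length, and O is the overlap between slices.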
    @spaces.GPU(duration=180)
    def infer_musepose(
        self,
        ref_image_path: str,
        pose_video_path: str,
        weight_dtype: str,
        W: int,
        H: int,
        L: int,
        S: int,
        O: int,
        cfg: float,
        seed: int,
        steps: int,
        fps: int,
        skip: int,
        gradio_progress=gr.Progress()
    ):
        download_models(model_dir=self.model_dir)
        print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
        print(f"Input Image Path: {ref_image_path}")
        print(f"Pose Video Path: {pose_video_path}")
        print(f"Dtype: {weight_dtype}")
        print(f"Width: {W}")
        print(f"Height: {H}")
        print(f"Video Frame Length: {L}")
        print(f"Video Slice Frame Length: {S}")
        print(f"Video Slice Overlap Frame Number: {O}")
        print(f"CFG: {cfg}")
        print(f"Seed: {seed}")
        print(f"Steps: {steps}")
        print(f"FPS: {fps}")
        print(f"Skip: {skip}")

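        # Outputs are written under a fixed name, so each run overwrites the previous result.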
        output_filename = "output_temp"
        output_path = os.path.abspath(os.path.join(self.output_dir, f"{output_filename}.mp4"))
        output_path_demo = os.path.abspath(os.path.join(self.output_dir, f"{output_filename}_demo.mp4"))

        if weight_dtype == "fp16":
            weight_dtype = torch.float16
        else:
            weight_dtype = torch.float32

        inference_config_path = self.inference_config_path
        infer_config = OmegaConf.load(inference_config_path)

        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)

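        # torch.manual_seed returns the seeded default generator, which the
        # pipeline accepts for reproducible sampling.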
        generator = torch.manual_seed(seed)

        width, height = W, H

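        # Load (or reuse already-cached) model weights, then assemble the pose-to-video pipeline.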
        self.init_model(weight_dtype=weight_dtype, infer_config=infer_config)

        self.pipe = Pose2VideoPipeline(
            vae=self.vae,
            image_encoder=self.image_enc,
            reference_unet=self.reference_unet,
            denoising_unet=self.denoising_unet,
            pose_guider=self.pose_guider,
            scheduler=scheduler,
            gradio_progress=gradio_progress
        )
        self.pipe = self.pipe.to("cuda", dtype=weight_dtype)

        print(f"image: {ref_image_path}, pose_video: {pose_video_path}")

        ref_image_pil = Image.open(ref_image_path).convert("RGB")

        pose_list = []
        pose_tensor_list = []
        pose_images = read_frames(pose_video_path)
        src_fps = get_fps(pose_video_path)
        print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
        L = min(L, len(pose_images))
        pose_transform = transforms.Compose(
            [transforms.Resize((height, width)), transforms.ToTensor()]
        )
        original_width, original_height = 0, 0

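        # Subsample the pose frames: keeping every (skip + 1)-th frame reduces the
        # effective frame count and fps by the same factor.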
        pose_images = pose_images[::skip + 1]
        print("processing length:", len(pose_images))
        src_fps = src_fps // (skip + 1)
        print("fps", src_fps)
        L = L // (skip + 1)

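        # Collect the first L frames both as PIL images (input to the pipeline) and
        # as tensors (for the side-by-side demo video), remembering the source resolution.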
        for pose_image_pil in pose_images[:L]:
            pose_tensor_list.append(pose_transform(pose_image_pil))
            pose_list.append(pose_image_pil)
            original_width, original_height = pose_image_pil.size

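        # The pipeline denoises the video in sliding windows of S frames with an
        # overlap of O, so every window after the first contributes S - O new frames.
        # Pad the sequence by repeating the last frame until (L - S) is a multiple
        # of (S - O), so the final window is completely filled.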
        last_segment_frame_num = (L - S) % (S - O)
        repeat_frame_num = (S - O - last_segment_frame_num) % (S - O)
        for _ in range(repeat_frame_num):
            pose_list.append(pose_list[-1])
            pose_tensor_list.append(pose_tensor_list[-1])

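        # Repeat the reference image along the frame axis so it can be shown next
        # to the pose and result videos in the demo grid.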
        ref_image_tensor = pose_transform(ref_image_pil)  # (c, h, w)
        ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
        ref_image_tensor = repeat(ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=L)

        pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
        pose_tensor = pose_tensor.transpose(0, 1)  # (c, f, h, w)
        pose_tensor = pose_tensor.unsqueeze(0)  # (1, c, f, h, w)

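        # Run the long-video pipeline: frames are generated in windows of S frames
        # (context_frames) with O overlapping frames (context_overlap) between them.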
        video = self.pipe(
            ref_image_pil,
            pose_list,
            width,
            height,
            len(pose_list),
            steps,
            cfg,
            generator=generator,
            context_frames=S,
            context_stride=1,
            context_overlap=O,
        ).videos

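        # Trim the padding frames, scale back to the source resolution, and save the result.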
        result = self.scale_video(video[:, :, :L], original_width, original_height)
        save_videos_grid(
            result,
            output_path,
            n_rows=1,
            fps=src_fps if fps is None or fps < 0 else fps,
        )

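        # Demo video: reference image, pose video, and generated video side by side.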
        video = torch.cat([ref_image_tensor, pose_tensor[:, :, :L], video[:, :, :L]], dim=0)
        video = self.scale_video(video, original_width, original_height)
        save_videos_grid(
            video,
            output_path_demo,
            n_rows=3,
            fps=src_fps if fps is None or fps < 0 else fps,
        )
        return output_path, output_path_demo

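    # Models are loaded lazily and cached on the instance, so repeated runs skip
    # the slow checkpoint loading.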
    @spaces.GPU(duration=120)
    def init_model(self,
                   weight_dtype: torch.dtype,
                   infer_config: DictConfig
                   ):
        if self.vae is None:
            self.vae = AutoencoderKL.from_pretrained(
                self.image_gen_model_paths["pretrained_vae"],
            ).to("cuda", dtype=weight_dtype)

        if self.reference_unet is None:
            self.reference_unet = UNet2DConditionModel.from_pretrained(
                self.image_gen_model_paths["pretrained_base_model"],
                subfolder="unet",
            ).to(dtype=weight_dtype, device="cuda")
            self.reference_unet.load_state_dict(
                torch.load(self.musepose_model_paths["reference_unet"], map_location="cpu"),
            )

        if self.denoising_unet is None:
            self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
                Path(self.image_gen_model_paths["pretrained_base_model"]),
                Path(self.musepose_model_paths["motion_module"]),
                subfolder="unet",
                unet_additional_kwargs=infer_config.unet_additional_kwargs,
            ).to(dtype=weight_dtype, device="cuda")
            self.denoising_unet.load_state_dict(
                torch.load(self.musepose_model_paths["denoising_unet"], map_location="cpu"),
                strict=False,
            )

        if self.pose_guider is None:
            self.pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
                dtype=weight_dtype, device="cuda"
            )
            self.pose_guider.load_state_dict(
                torch.load(self.musepose_model_paths["pose_guider"], map_location="cpu"),
            )

        if self.image_enc is None:
            self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
                self.image_gen_model_paths["image_encoder"]
            ).to(dtype=weight_dtype, device="cuda")

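    # Release every cached model so VRAM is freed between runs.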
    def release_vram(self):
        models = [
            'vae', 'reference_unet', 'denoising_unet',
            'pose_guider', 'image_enc', 'pipe'
        ]

        for model_name in models:
            model = getattr(self, model_name, None)
            if model is not None:
                del model  # drop the local reference
                setattr(self, model_name, None)  # drop the instance reference

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

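    # Resize a (b, c, f, h, w) video tensor spatially to (height, width) with
    # bilinear interpolation.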
    @staticmethod
    def scale_video(video, width, height):
        # Flatten batch and channel dims so frames act as interpolation channels: (b*c, f, h, w).
        video_reshaped = video.view(-1, *video.shape[2:])
        scaled_video = F.interpolate(video_reshaped, size=(height, width), mode='bilinear', align_corners=False)
        # Restore the original (b, c, f, h, w) layout at the new spatial size.
        scaled_video = scaled_video.view(*video.shape[:2], scaled_video.shape[1], height, width)
        return scaled_video