# Hugging Face Space: running on ZeroGPU ("Running on Zero").
import contextlib
import logging
import os
import subprocess
import sys
import tempfile

import gradio as gr
import spaces
import huggingface_hub

# Timestamped log lines so the setup and inference phases are easy to follow.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

# Hugging Face access token; None when the environment variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Local checkout of the MimicMotion repository and the model directories.
MIMIC_DIR = "./MimicMotion"
MODELS_DIR = "./models"
SVD_DIR = f"{MODELS_DIR}/SVD"
DWPOSE_DIR = f"{MODELS_DIR}/DWPose"

# Upper bound on the number of frames generated per request.
MAX_OUTPUT_FRAMES = 48

# Setup
def _ensure_repo():
    """Clone the MimicMotion repository if it is missing and make it importable."""
    if not os.path.exists(MIMIC_DIR):
        logger.info("Cloning tencent/MimicMotion ...")
        subprocess.run(
            ["git", "clone", "--depth=1",
             "https://github.com/tencent/MimicMotion.git", MIMIC_DIR],
            check=True,
        )
    # Avoid stacking duplicate entries if setup() runs more than once.
    if MIMIC_DIR not in sys.path:
        sys.path.insert(0, MIMIC_DIR)


def _patch_loader():
    """Rewrite the ``safe_globals`` call in the cloned ``loader.py``.

    Replaces ``safe_globals(*allowed_modules)`` with
    ``safe_globals(allowed_modules)``. Once patched, the search string no
    longer matches, so repeated runs leave the file untouched.
    """
    loader_path = os.path.join(MIMIC_DIR, "mimicmotion/utils/loader.py")
    if not os.path.exists(loader_path):
        return
    # Explicit encoding so the read/write round-trip does not depend on the
    # platform's default locale.
    with open(loader_path, encoding="utf-8") as f:
        content = f.read()
    old_call = "safe_globals(*allowed_modules)"
    if old_call not in content:
        return
    logger.info("Patching loader.py for newer PyTorch")
    with open(loader_path, "w", encoding="utf-8") as f:
        f.write(content.replace(old_call, "safe_globals(allowed_modules)"))


def _download_models():
    """Download the DWPose, MimicMotion and SVD files that are not on disk yet."""
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(DWPOSE_DIR, exist_ok=True)

    for fname in ("yolox_l.onnx", "dw-ll_ucoco_384.onnx"):
        if not os.path.exists(os.path.join(DWPOSE_DIR, fname)):
            logger.info(f"Downloading DWPose model: {fname}")
            huggingface_hub.hf_hub_download(
                repo_id="yzd-v/DWPose", filename=fname, local_dir=DWPOSE_DIR
            )

    if not os.path.exists(os.path.join(MODELS_DIR, "MimicMotion_1-1.pth")):
        logger.info("Downloading MimicMotion_1-1.pth ...")
        huggingface_hub.hf_hub_download(
            repo_id="tencent/MimicMotion",
            filename="MimicMotion_1-1.pth",
            local_dir=MODELS_DIR,
        )

    # model_index.json is used as the marker that the SVD snapshot is present.
    if os.path.exists(os.path.join(SVD_DIR, "model_index.json")):
        return
    if not HF_TOKEN:
        logger.warning("HF_TOKEN not set -- SVD model unavailable.")
        return
    logger.info("Downloading stable-video-diffusion-img2vid-xt-1-1 ...")
    huggingface_hub.snapshot_download(
        repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1",
        local_dir=SVD_DIR,
        token=HF_TOKEN,
        ignore_patterns=["*.bin"],
    )


def setup():
    """Prepare the working directory: repository checkout, loader patch, model files.

    Every step is skipped when its result already exists on disk, so calling
    this again after a successful run performs no downloads.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` fails.
    """
    _ensure_repo()
    _patch_loader()
    _download_models()


setup()

# Video pre-trim (runs before GPU block, fast)
def trim_video_to_budget(video_path, max_output_frames, sample_stride):
    """Trim video with ffmpeg so DWPose only sees as many frames as needed.

    The raw-frame budget is ``(max_output_frames + 1) * sample_stride``. A
    video within that budget is returned untouched; a longer one is
    re-encoded (libx264, audio dropped) into a new temporary ``.mp4`` that
    covers only the budgeted duration.

    Args:
        video_path: Path of the driving video.
        max_output_frames: Maximum number of output frames wanted.
        sample_stride: Stride used when sampling frames from the video.

    Returns:
        Path of the trimmed temporary file, or ``video_path`` itself when no
        trim is needed or when trimming fails (this step is best-effort).
    """
    out_path = None
    try:
        # Imported inside the try so a missing decord also falls back to the
        # original video instead of raising.
        import decord

        vr = decord.VideoReader(video_path)
        total = len(vr)
        max_raw = (max_output_frames + 1) * sample_stride
        if total <= max_raw:
            logger.info(f"Video {total} frames <= budget {max_raw}, no trim.")
            return video_path

        fps = vr.get_avg_fps() or 30.0  # fall back to 30 if fps is reported as 0
        duration = max_raw / fps

        # mkstemp creates the file atomically (mktemp is deprecated and racy);
        # ffmpeg's -y then overwrites the empty placeholder.
        fd, out_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        subprocess.run(
            ["ffmpeg", "-y", "-i", video_path,
             "-t", f"{duration:.3f}",
             "-c:v", "libx264", "-preset", "fast", "-crf", "18",
             "-an", out_path],
            check=True, capture_output=True,
        )
        logger.info(f"Pre-trimmed: {total} -> {max_raw} raw frames ({duration:.1f}s)")
        return out_path
    except Exception as e:
        # Best-effort step: drop any partial output and use the original video.
        if out_path is not None:
            with contextlib.suppress(OSError):
                os.remove(out_path)
        logger.warning(f"Video pre-trim failed ({e}), using original.")
        return video_path

# Single GPU block: load pipeline + DWPose + SVD
# NOTE(review): `spaces` is imported at file level, but no `@spaces.GPU`
# decorator is visible on this function in the reviewed text -- confirm
# whether one is expected here.
def run_mimicmotion(ref_image_path, ref_video_path,
                    resolution, num_frames, num_inference_steps,
                    noise_aug_strength, guidance_scale,
                    sample_stride, seed):
    """Load the MimicMotion pipeline, extract poses and render the output video.

    Args:
        ref_image_path: Path of the reference image.
        ref_video_path: Path of the driving video.
        resolution: Value forwarded to ``preprocess`` and the task config.
        num_frames: Requested ``num_frames`` for the task config; lowered to
            the number of available pose frames when it exceeds them.
        num_inference_steps: Forwarded to the task config.
        noise_aug_strength: Forwarded to the task config.
        guidance_scale: Forwarded to the task config.
        sample_stride: Forwarded to ``preprocess`` and the task config.
        seed: Forwarded to the task config.

    Returns:
        Path of a temporary ``.mp4`` file written at 15 fps.
    """
    import torch
    from omegaconf import OmegaConf
    from mimicmotion.utils.utils import save_to_mp4
    from inference import preprocess, run_pipeline
    from mimicmotion.utils.geglu_patch import patch_geglu_inplace
    # Import order kept from the original: the GEGLU patch is applied before
    # the loader module is imported (presumably required -- confirm).
    patch_geglu_inplace()
    from mimicmotion.utils.loader import create_pipeline

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"GPU call: loading pipeline to {device} ...")

    # Build the config in memory; the previous YAML round-trip left a
    # temporary file behind on every call.
    infer_config = OmegaConf.create({
        "base_model_path": SVD_DIR,
        "ckpt_path": os.path.join(MODELS_DIR, "MimicMotion_1-1.pth"),
    })
    torch.set_default_dtype(torch.float16)
    pipe = create_pipeline(infer_config, device)

    logger.info("Pipeline loaded. Running DWPose ...")
    pose_pixels, image_pixels = preprocess(
        ref_video_path, ref_image_path,
        resolution=resolution,
        sample_stride=sample_stride,
    )

    max_pose_frames = MAX_OUTPUT_FRAMES + 1
    if pose_pixels.shape[0] > max_pose_frames:
        logger.info(f"Trimming pose: {pose_pixels.shape[0]} -> {MAX_OUTPUT_FRAMES + 1}")
        pose_pixels = pose_pixels[:max_pose_frames]

    # Cap num_frames to actual available frames so pipeline chunking never sees an empty index list
    actual_frames = pose_pixels.shape[0]
    if num_frames > actual_frames:
        logger.info(f"Capping num_frames {num_frames} -> {actual_frames} (available pose frames)")
        num_frames = actual_frames

    logger.info(f"DWPose done ({pose_pixels.shape[0]} frames). Running SVD ...")
    task_config = OmegaConf.create({
        "num_frames": num_frames,
        "frames_overlap": 4,
        "num_inference_steps": num_inference_steps,
        "noise_aug_strength": noise_aug_strength,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "resolution": resolution,
        "sample_stride": sample_stride,
    })
    video_frames = run_pipeline(pipe, image_pixels, pose_pixels, device, task_config)

    # mkstemp reserves the output path atomically (mktemp is deprecated and racy).
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    save_to_mp4(video_frames, out_path, fps=15)
    logger.info(f"Done. Output: {out_path}")
    return out_path

# Gradio wrapper
def generate(ref_image, ref_video, resolution, num_frames, num_inference_steps,
             noise_aug_strength, guidance_scale, sample_stride, seed):
    """Validate the UI inputs, pre-trim the driving video and run MimicMotion.

    Args:
        ref_image: Reference image as a file path, or a dict carrying the
            path under ``"path"`` or ``"name"``.
        ref_video: Driving video as a file path, or a dict carrying the path
            under ``"video"``, ``"name"`` or ``"path"``.
        resolution, num_frames, num_inference_steps, sample_stride, seed:
            Numeric UI values; converted to ``int`` before use.
        noise_aug_strength, guidance_scale: Numeric UI values; converted to
            ``float`` before use.

    Returns:
        Path of the generated video, as returned by ``run_mimicmotion``.

    Raises:
        gr.Error: if the reference image or the driving video is missing.
    """
    # Normalise dict payloads first so an empty dict is reported as a missing
    # upload instead of failing later with an obscure error.
    if isinstance(ref_video, dict):
        ref_video = ref_video.get("video") or ref_video.get("name") or ref_video.get("path")
    if isinstance(ref_image, dict):
        ref_image = ref_image.get("path") or ref_image.get("name")

    if not ref_image:
        raise gr.Error("Please upload a reference image.")
    if not ref_video:
        raise gr.Error("Please upload a driving video.")

    stride = int(sample_stride)
    # Trim BEFORE the GPU block so DWPose only sees ~49 frames
    trimmed_video = trim_video_to_budget(ref_video, MAX_OUTPUT_FRAMES, stride)
    try:
        return run_mimicmotion(
            ref_image, trimmed_video,
            int(resolution), int(num_frames), int(num_inference_steps),
            float(noise_aug_strength), float(guidance_scale),
            stride, int(seed),
        )
    finally:
        # A differing path means a temporary trimmed copy was created; it is
        # no longer needed once inference has finished or failed.
        if trimmed_video != ref_video:
            with contextlib.suppress(OSError):
                os.remove(trimmed_video)

# UI
with gr.Blocks(title="MimicMotion") as demo:
    gr.Markdown(
        """
# MimicMotion
Upload a **reference image** and a **driving video** -- MimicMotion will animate
the person in the image following the motion in the video.
> **Tips:**
> - Use a **short video (3-5 seconds)**. Longer videos are auto-trimmed to 48 frames.
> - Keep **Sample stride = 4** (default). Lower values = smoother but slower.
> - Generation takes about 3-4 minutes total.
"""
    )
    with gr.Row():
        # Left column: inputs and tuning controls.
        with gr.Column():
            ref_image = gr.Image(label="Reference Image", type="filepath")
            ref_video = gr.Video(label="Driving Video")
            with gr.Accordion("Advanced settings", open=False):
                resolution = gr.Slider(
                    minimum=256, maximum=768, value=576, step=64, label="Resolution"
                )
                num_frames = gr.Slider(
                    minimum=8, maximum=72, value=16, step=8, label="Frames per tile"
                )
                num_steps = gr.Slider(
                    minimum=5, maximum=50, value=20, step=1, label="Inference steps"
                )
                noise_aug = gr.Slider(
                    minimum=0.0, maximum=0.1, value=0.0563, step=0.001,
                    label="Noise aug strength",
                )
                guidance = gr.Slider(
                    minimum=1.0, maximum=10.0, value=2.0, step=0.5, label="Guidance scale"
                )
                sample_stride = gr.Slider(
                    minimum=1, maximum=4, value=4, step=1,
                    label="Sample stride (4=fast, 1=smooth)",
                )
                seed = gr.Number(value=42, label="Seed", precision=0)
            run_btn = gr.Button("Generate", variant="primary")
        # Right column: generated result.
        with gr.Column():
            output_video = gr.Video(label="Output Video", autoplay=True)

    # Input order must match the parameter order of generate().
    run_btn.click(
        fn=generate,
        inputs=[ref_image, ref_video, resolution, num_frames, num_steps,
                noise_aug, guidance, sample_stride, seed],
        outputs=output_video,
    )

demo.launch()