# pormungtai's picture
# Update app.py
# 73f49b6 verified
import os
import sys
import subprocess
import tempfile
import logging
import gradio as gr
import spaces
import huggingface_hub
# Log everything at INFO and above with a timestamp and level prefix.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

# Optional Hugging Face token; setup() only downloads the SVD snapshot when it is set.
HF_TOKEN = os.getenv("HF_TOKEN")

# Where the upstream MimicMotion repo is cloned and where weights are stored.
MIMIC_DIR = "./MimicMotion"
MODELS_DIR = "./models"
SVD_DIR = MODELS_DIR + "/SVD"
DWPOSE_DIR = MODELS_DIR + "/DWPose"

# Cap on generated frames; used both for the pre-trim budget and the pose cap.
MAX_OUTPUT_FRAMES = 48
# Setup
def _patch_loader_safe_globals(loader_path):
    """Make MimicMotion's loader pass ``allowed_modules`` to ``safe_globals`` as one argument.

    Does nothing when the file does not exist or no longer contains the
    unpacked ``safe_globals(*allowed_modules)`` call (e.g. already patched).
    """
    if not os.path.exists(loader_path):
        return
    with open(loader_path) as src:
        text = src.read()
    needle = "safe_globals(*allowed_modules)"
    if needle not in text:
        return
    logger.info("Patching loader.py for newer PyTorch")
    with open(loader_path, "w") as dst:
        dst.write(text.replace(needle, "safe_globals(allowed_modules)"))


def setup():
    """Prepare the MimicMotion checkout and the weight files the app uses.

    In order:
      * shallow-clone tencent/MimicMotion into MIMIC_DIR (if absent) and
        prepend it to ``sys.path`` (always);
      * patch its ``mimicmotion/utils/loader.py`` for newer PyTorch;
      * download the two DWPose ONNX models and ``MimicMotion_1-1.pth``
        when they are missing;
      * download the SVD img2vid-xt-1-1 snapshot (without ``*.bin`` files)
        when ``model_index.json`` is missing -- only if HF_TOKEN is set,
        otherwise log a warning.
    """
    if not os.path.exists(MIMIC_DIR):
        logger.info("Cloning tencent/MimicMotion ...")
        clone_cmd = [
            "git", "clone", "--depth=1",
            "https://github.com/tencent/MimicMotion.git", MIMIC_DIR,
        ]
        subprocess.run(clone_cmd, check=True)
    sys.path.insert(0, MIMIC_DIR)

    _patch_loader_safe_globals(os.path.join(MIMIC_DIR, "mimicmotion/utils/loader.py"))

    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(DWPOSE_DIR, exist_ok=True)
    for onnx_name in ("yolox_l.onnx", "dw-ll_ucoco_384.onnx"):
        if os.path.exists(os.path.join(DWPOSE_DIR, onnx_name)):
            continue
        logger.info(f"Downloading DWPose model: {onnx_name}")
        huggingface_hub.hf_hub_download(
            repo_id="yzd-v/DWPose", filename=onnx_name, local_dir=DWPOSE_DIR
        )

    if not os.path.exists(os.path.join(MODELS_DIR, "MimicMotion_1-1.pth")):
        logger.info("Downloading MimicMotion_1-1.pth ...")
        huggingface_hub.hf_hub_download(
            repo_id="tencent/MimicMotion",
            filename="MimicMotion_1-1.pth",
            local_dir=MODELS_DIR,
        )

    if os.path.exists(os.path.join(SVD_DIR, "model_index.json")):
        return
    if not HF_TOKEN:
        logger.warning("HF_TOKEN not set -- SVD model unavailable.")
        return
    logger.info("Downloading stable-video-diffusion-img2vid-xt-1-1 ...")
    huggingface_hub.snapshot_download(
        repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1",
        local_dir=SVD_DIR,
        token=HF_TOKEN,
        ignore_patterns=["*.bin"],
    )


setup()
# Video pre-trim (runs before GPU block, fast)
def trim_video_to_budget(video_path, max_output_frames, sample_stride):
    """Trim video with ffmpeg so DWPose only sees as many frames as needed.

    The raw-frame budget is ``(max_output_frames + 1) * sample_stride``,
    presumably enough to yield ``max_output_frames + 1`` frames once the
    video is sampled every ``sample_stride`` frames.

    Behavior:
      * If the video already fits the budget, ``video_path`` is returned as-is.
      * Otherwise, the first ``budget / fps`` seconds are re-encoded (libx264,
        CRF 18, audio dropped) into a newly created temporary ``.mp4``. The
        caller owns that temp file and its path is returned.

    This is best effort. Any failure is logged, any partially written temp
    file is removed, and the original ``video_path`` is returned. Failures
    include ``decord`` missing, an unreadable video, or ffmpeg exiting
    non-zero.
    """
    out_path = None
    try:
        import decord

        reader = decord.VideoReader(video_path)
        total = len(reader)
        max_raw = (max_output_frames + 1) * sample_stride
        if total <= max_raw:
            logger.info(f"Video {total} frames <= budget {max_raw}, no trim.")
            return video_path
        # Fall back to 30 fps when the reported average frame rate is 0/falsy.
        fps = reader.get_avg_fps() or 30.0
        duration = max_raw / fps
        # mkstemp creates the file atomically; tempfile.mktemp is deprecated
        # because another process can grab the name between check and use.
        fd, out_path = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        subprocess.run(
            ["ffmpeg", "-y", "-i", video_path,
             "-t", f"{duration:.3f}",
             "-c:v", "libx264", "-preset", "fast", "-crf", "18",
             "-an", out_path],
            check=True, capture_output=True,
        )
        logger.info(f"Pre-trimmed: {total} -> {max_raw} raw frames ({duration:.1f}s)")
        return out_path
    except Exception as e:
        logger.warning(f"Video pre-trim failed ({e}), using original.")
        if isinstance(e, subprocess.CalledProcessError) and e.stderr:
            # capture_output=True swallows ffmpeg's diagnostics; surface the tail.
            logger.warning(e.stderr.decode(errors="replace")[-1000:])
        if out_path is not None:
            # Do not leave an empty or partially written temp file behind.
            try:
                os.remove(out_path)
            except OSError:
                pass
        return video_path
# Single GPU block: load pipeline + DWPose + SVD
@spaces.GPU(duration=300)
def run_mimicmotion(ref_image_path, ref_video_path,
                    resolution, num_frames, num_inference_steps,
                    noise_aug_strength, guidance_scale,
                    sample_stride, seed):
    """Load the MimicMotion pipeline, run DWPose and SVD, and write the result video.

    Everything that needs the GPU happens inside this single ``@spaces.GPU``
    call (300 s budget). The heavy imports are therefore done lazily here.

    Args:
        ref_image_path: path to the reference image to animate.
        ref_video_path: path to the driving video.
        resolution: resolution passed to ``preprocess`` and the task config.
        num_frames: requested frames per tile. It is capped to the number of
            pose frames actually extracted.
        num_inference_steps, noise_aug_strength, guidance_scale, seed:
            sampling settings forwarded through the task config.
        sample_stride: stride passed to ``preprocess`` when reading the video.

    Returns:
        Path of a newly created ``.mp4`` (saved at 15 fps) with the output.
    """
    import torch
    from omegaconf import OmegaConf
    from mimicmotion.utils.utils import save_to_mp4
    from inference import preprocess, run_pipeline
    from mimicmotion.utils.geglu_patch import patch_geglu_inplace
    # Applied before the loader import; presumably the patch must be in place
    # before create_pipeline builds the model -- keep this ordering.
    patch_geglu_inplace()
    from mimicmotion.utils.loader import create_pipeline

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"GPU call: loading pipeline to {device} ...")
    # Build the config in memory instead of dumping it to a temp YAML file
    # (which was created with delete=False and never removed).
    infer_config = OmegaConf.create({
        "base_model_path": SVD_DIR,
        "ckpt_path": os.path.join(MODELS_DIR, "MimicMotion_1-1.pth"),
    })
    torch.set_default_dtype(torch.float16)
    pipe = create_pipeline(infer_config, device)

    # Forward-only inference: no autograd bookkeeping needed.
    with torch.no_grad():
        logger.info("Pipeline loaded. Running DWPose ...")
        pose_pixels, image_pixels = preprocess(
            ref_video_path, ref_image_path,
            resolution=resolution,
            sample_stride=sample_stride,
        )
        if pose_pixels.shape[0] > MAX_OUTPUT_FRAMES + 1:
            logger.info(f"Trimming pose: {pose_pixels.shape[0]} -> {MAX_OUTPUT_FRAMES + 1}")
            pose_pixels = pose_pixels[: MAX_OUTPUT_FRAMES + 1]
        # Cap num_frames to actual available frames so pipeline chunking never sees an empty index list
        actual_frames = pose_pixels.shape[0]
        if num_frames > actual_frames:
            logger.info(f"Capping num_frames {num_frames} -> {actual_frames} (available pose frames)")
            num_frames = actual_frames
        # The tile overlap must stay below the (possibly capped) tile size. We
        # assume the pipeline advances tiles by num_frames - frames_overlap, so
        # a step <= 0 breaks it (TODO confirm against MimicMotionPipeline).
        # For num_frames > 4 this is the usual overlap of 4.
        frames_overlap = min(4, max(num_frames - 1, 0))
        logger.info(f"DWPose done ({pose_pixels.shape[0]} frames). Running SVD ...")
        task_config = OmegaConf.create({
            "num_frames": num_frames,
            "frames_overlap": frames_overlap,
            "num_inference_steps": num_inference_steps,
            "noise_aug_strength": noise_aug_strength,
            "guidance_scale": guidance_scale,
            "seed": seed,
            "resolution": resolution,
            "sample_stride": sample_stride,
        })
        video_frames = run_pipeline(pipe, image_pixels, pose_pixels, device, task_config)

    # mkstemp reserves the name atomically (tempfile.mktemp is deprecated/racy).
    fd, out_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)
    save_to_mp4(video_frames, out_path, fps=15)
    logger.info(f"Done. Output: {out_path}")
    return out_path
# Gradio wrapper
def _as_path(value, keys):
    """Return a filesystem path from a Gradio file value.

    Non-dict values are returned unchanged. For a dict, the first truthy entry
    among *keys* (in order) is taken. If that entry is itself a dict, it is
    unwrapped the same way. Returns None when no entry matches.
    """
    while isinstance(value, dict):
        value = next((value[k] for k in keys if value.get(k)), None)
    return value


def generate(ref_image, ref_video, resolution, num_frames, num_inference_steps,
             noise_aug_strength, guidance_scale, sample_stride, seed):
    """Gradio click handler: validate inputs, pre-trim the video, run MimicMotion.

    Numeric UI values are coerced to ``int``/``float`` before the GPU call.
    The temporary file produced by ``trim_video_to_budget`` (if any) is
    deleted once the GPU call returns or fails.

    Returns:
        Path of the generated ``.mp4`` (from ``run_mimicmotion``).

    Raises:
        gr.Error: when the reference image, driving video or seed is missing.
    """
    if ref_image is None:
        raise gr.Error("Please upload a reference image.")
    if ref_video is None:
        raise gr.Error("Please upload a driving video.")
    ref_video = _as_path(ref_video, ("video", "name", "path"))
    ref_image = _as_path(ref_image, ("path", "name"))
    # A dict payload without a usable path would otherwise crash deep in the GPU call.
    if not ref_image:
        raise gr.Error("Please upload a reference image.")
    if not ref_video:
        raise gr.Error("Please upload a driving video.")
    # A cleared Number field may deliver None, which int() cannot convert.
    if seed is None:
        raise gr.Error("Please enter a seed.")

    stride = int(sample_stride)
    # Trim BEFORE the GPU block so DWPose only sees ~MAX_OUTPUT_FRAMES + 1 sampled frames
    trimmed_video = trim_video_to_budget(ref_video, MAX_OUTPUT_FRAMES, stride)
    try:
        return run_mimicmotion(
            ref_image, trimmed_video,
            int(resolution), int(num_frames), int(num_inference_steps),
            float(noise_aug_strength), float(guidance_scale),
            stride, int(seed),
        )
    finally:
        # A different path means trim_video_to_budget created a temp file we own.
        if trimmed_video != ref_video:
            try:
                os.remove(trimmed_video)
            except OSError as e:
                logger.warning(f"Could not remove trimmed video {trimmed_video}: {e}")
# UI
# Intro text shown at the top of the page (content must stay verbatim).
_INTRO_MD = """
# MimicMotion
Upload a **reference image** and a **driving video** -- MimicMotion will animate
the person in the image following the motion in the video.
> **Tips:**
> - Use a **short video (3-5 seconds)**. Longer videos are auto-trimmed to 48 frames.
> - Keep **Sample stride = 4** (default). Lower values = smoother but slower.
> - Generation takes about 3-4 minutes total.
"""

with gr.Blocks(title="MimicMotion") as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        # Left column: uploads, collapsible advanced settings, run button.
        with gr.Column():
            ref_image = gr.Image(label="Reference Image", type="filepath")
            ref_video = gr.Video(label="Driving Video")
            with gr.Accordion("Advanced settings", open=False):
                resolution = gr.Slider(minimum=256, maximum=768, value=576, step=64,
                                       label="Resolution")
                num_frames = gr.Slider(minimum=8, maximum=72, value=16, step=8,
                                       label="Frames per tile")
                num_steps = gr.Slider(minimum=5, maximum=50, value=20, step=1,
                                      label="Inference steps")
                noise_aug = gr.Slider(minimum=0.0, maximum=0.1, value=0.0563, step=0.001,
                                      label="Noise aug strength")
                guidance = gr.Slider(minimum=1.0, maximum=10.0, value=2.0, step=0.5,
                                     label="Guidance scale")
                sample_stride = gr.Slider(minimum=1, maximum=4, value=4, step=1,
                                          label="Sample stride (4=fast, 1=smooth)")
                seed = gr.Number(value=42, label="Seed", precision=0)
            run_btn = gr.Button("Generate", variant="primary")

        # Right column: the rendered result.
        with gr.Column():
            output_video = gr.Video(label="Output Video", autoplay=True)

    # Order must match generate()'s positional parameters.
    _generate_inputs = [
        ref_image, ref_video, resolution, num_frames, num_steps,
        noise_aug, guidance, sample_stride, seed,
    ]
    run_btn.click(fn=generate, inputs=_generate_inputs, outputs=output_video)

demo.launch()