# AniDoc Gradio app entry point (Hugging Face Space).
# Includes a startup workaround for a gradio/asyncio teardown bug — see below.
def _patch_asyncio_event_loop_del():
    """
    Silence a noisy asyncio teardown issue seen in some Spaces runtimes.

    When an event loop is garbage-collected, certain runtime/container
    combinations make Python try to close a file descriptor that is already
    invalid. Only that specific, harmless ValueError is swallowed; every
    other exception from the original finalizer still propagates.
    """
    try:
        import asyncio.base_events as base_events

        original_del = getattr(base_events.BaseEventLoop, "__del__", None)
        if original_del is None:
            # Nothing to wrap on this interpreter.
            return

        def patched_del(self):
            try:
                original_del(self)
            except ValueError as exc:
                # Re-raise anything that is not the known harmless case.
                if "Invalid file descriptor" not in str(exc):
                    raise

        base_events.BaseEventLoop.__del__ = patched_del
    except Exception:
        # Best-effort patch: never let it break application startup.
        pass


_patch_asyncio_event_loop_del()
import os
import sys
import gc
import cv2
import types
import uuid
from glob import glob
import gradio as gr
import imageio.v2 as imageio
import numpy as np
import spaces
import torch
import torchvision.transforms as T
from packaging import version
from PIL import Image
sys.path.insert(0, ".")
# -----------------------------------------------------------------------------
# Compatibility shims
# -----------------------------------------------------------------------------
import huggingface_hub
# NOTE(review): newer huggingface_hub releases removed `cached_download`;
# alias it to hf_hub_download so older callers keep working (presumably a
# downstream dependency such as an older diffusers — confirm which caller
# actually needs this).
if not hasattr(huggingface_hub, "cached_download"):
    from huggingface_hub import hf_hub_download as _hf_hub_download_compat
    huggingface_hub.cached_download = _hf_hub_download_compat
import torchvision.transforms.functional as TVF
# NOTE(review): torchvision removed the private `functional_tensor` module;
# register a stub exposing rgb_to_grayscale so code that still imports it
# keeps working (likely the line-art extractor stack — confirm).
if "torchvision.transforms.functional_tensor" not in sys.modules:
    functional_tensor = types.ModuleType("torchvision.transforms.functional_tensor")
    functional_tensor.rgb_to_grayscale = TVF.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
# -----------------------------------------------------------------------------
# Imports that rely on the shims
# -----------------------------------------------------------------------------
from huggingface_hub import snapshot_download, hf_hub_download
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
from LightGlue.lightglue import LightGlue, SuperPoint
from LightGlue.lightglue.utils import rbd
from cotracker.predictor import CoTrackerPredictor, sample_trajectories_with_ref
from lineart_extractor.annotator.lineart import LineartDetector
from models_diffusers.controlnet_svd import ControlNetSVDModel
from models_diffusers.unet_spatio_temporal_condition import (
UNetSpatioTemporalConditionModel,
)
from pipelines.AniDoc import AniDocPipeline
from utils import (
extract_frames_from_video,
export_gif_side_by_side_complete_ablation,
generate_point_map,
generate_point_map_frames,
load_images_from_folder,
safe_round,
select_multiple_points,
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Local checkpoint locations, populated at startup by ensure_weights_downloaded().
PRETRAINED_SVD_DIR = "pretrained_weights/stable-video-diffusion-img2vid-xt"
PRETRAINED_ANIDOC_DIR = "pretrained_weights/anidoc"
PRETRAINED_CONTROLNET_DIR = "pretrained_weights/anidoc/controlnet"
COTRACKER_PATH = "pretrained_weights/cotracker2.pth"
# Generation defaults used by the Gradio callback and run_anidoc().
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 320
DEFAULT_NUM_FRAMES = 14
DEFAULT_SEED = 42
DEFAULT_MAX_POINTS = 10
DEFAULT_FPS = 7  # fallback frame rate when the input video reports none
DEFAULT_MOTION_BUCKET_ID = 127
DEFAULT_NOISE_AUG = 0.02
DEFAULT_DECODE_CHUNK_SIZE = 8
# All models run on a single CUDA device in half precision.
DEVICE = "cuda"
DTYPE = torch.float16
# -----------------------------------------------------------------------------
# Startup downloads only
# -----------------------------------------------------------------------------
def ensure_weights_downloaded() -> None:
    """Fetch all pretrained weights to local disk (idempotent, startup-only).

    Downloads the SVD base model, the AniDoc weights (including the
    ControlNet subfolder), and the CoTracker checkpoint into
    ``pretrained_weights/``.
    """
    for directory in ("pretrained_weights", PRETRAINED_SVD_DIR):
        os.makedirs(directory, exist_ok=True)
    snapshot_download(
        repo_id="stabilityai/stable-video-diffusion-img2vid-xt",
        local_dir=PRETRAINED_SVD_DIR,
    )
    snapshot_download(
        repo_id="Yhmeng1106/anidoc",
        local_dir="./pretrained_weights",
    )
    hf_hub_download(
        repo_id="facebook/cotracker",
        filename="cotracker2.pth",
        local_dir="./pretrained_weights",
    )
# -----------------------------------------------------------------------------
# Runtime model loading (GPU-only, per run)
# -----------------------------------------------------------------------------
def load_runtime_models():
    """Instantiate every model needed for one inference run on the GPU.

    Returns:
        Dict with keys "unet", "controlnet", "pipe", "detector",
        "extractor", "matcher", "tracker".

    Raises:
        RuntimeError: if CUDA or xformers is unavailable.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available in the current execution context.")
    unet_model = UNetSpatioTemporalConditionModel.from_pretrained(
        PRETRAINED_ANIDOC_DIR,
        subfolder="unet",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        custom_resume=True,
    )
    unet_model.to(DEVICE, DTYPE)
    control_model = ControlNetSVDModel.from_pretrained(PRETRAINED_CONTROLNET_DIR)
    control_model.to(DEVICE, DTYPE)
    if not is_xformers_available():
        raise RuntimeError("xformers is not available. Make sure it is installed correctly.")
    import xformers  # noqa: F401

    # Sanity-parse the installed version string before enabling the backend.
    _ = version.parse(xformers.__version__)
    unet_model.enable_xformers_memory_efficient_attention()
    anidoc_pipe = AniDocPipeline.from_pretrained(
        PRETRAINED_SVD_DIR,
        unet=unet_model,
        controlnet=control_model,
        low_cpu_mem_usage=False,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    anidoc_pipe.to(DEVICE)
    lineart_detector = LineartDetector(DEVICE)
    keypoint_extractor = SuperPoint(max_num_keypoints=2000).eval().to(DEVICE)
    keypoint_matcher = LightGlue(features="superpoint").eval().to(DEVICE)
    # CoTracker runs in float32, unlike the fp16 diffusion stack.
    point_tracker = CoTrackerPredictor(
        checkpoint=COTRACKER_PATH,
        shift_grid=0,
    )
    point_tracker.requires_grad_(False)
    point_tracker.to(DEVICE, dtype=torch.float32)
    return {
        "unet": unet_model,
        "controlnet": control_model,
        "pipe": anidoc_pipe,
        "detector": lineart_detector,
        "extractor": keypoint_extractor,
        "matcher": keypoint_matcher,
        "tracker": point_tracker,
    }
def cleanup_runtime_models(models) -> None:
    """Release model references and reclaim GPU/CPU memory.

    Safe to call with ``None`` or an empty dict (e.g. when model loading
    failed before anything was created).

    Args:
        models: The dict returned by load_runtime_models(), or None.
    """
    try:
        if models:
            # Clearing the dict drops its strong references so the models can
            # actually be collected. The previous `del value` on the loop
            # variable only unbound a local name and freed nothing.
            models.clear()
    finally:
        gc.collect()
        if torch.cuda.is_available():
            # Best-effort GPU cleanup; failures here must not mask the caller's
            # own exception path.
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
            try:
                torch.cuda.ipc_collect()
            except Exception:
                pass
# -----------------------------------------------------------------------------
# Video metadata / final video writer
# -----------------------------------------------------------------------------
def get_video_dimensions(video_path: str) -> tuple[int, int]:
    """Return (width, height) of the video at *video_path*.

    Raises:
        ValueError: if the video cannot be opened or reports non-positive
            dimensions.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Unable to open input video: {video_path}")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        # Release on every path; the original leaked the handle when the
        # "cannot open" error was raised.
        cap.release()
    if width <= 0 or height <= 0:
        raise ValueError(f"Invalid input video dimensions: {video_path}")
    return width, height
def get_video_fps(video_path: str) -> float:
    """Return the frame rate reported by the container for *video_path*.

    Falls back to DEFAULT_FPS when the container reports a non-positive rate.

    Raises:
        ValueError: if the video cannot be opened.
    """
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise ValueError(f"Unable to open input video: {video_path}")
    reported_fps = capture.get(cv2.CAP_PROP_FPS)
    capture.release()
    if reported_fps <= 0:
        # Some containers report 0 fps; substitute the app default.
        reported_fps = DEFAULT_FPS
    return float(reported_fps)
def write_video_frames_to_mp4(
    video_frames,
    output_path: str,
    target_width: int,
    target_height: int,
    target_fps: float,
) -> str:
    """Encode *video_frames* to an H.264 .mp4 at the requested size and fps.

    Frames may be PIL images or array-likes; grayscale frames are expanded to
    RGB and alpha channels are dropped before resizing.

    Args:
        video_frames: Iterable of PIL images or HxW(xC) arrays.
        output_path: Destination .mp4 path; its parent directory is created.
        target_width / target_height: Output resolution.
        target_fps: Output frame rate.

    Returns:
        output_path, for convenient chaining.
    """
    out_dir = os.path.dirname(output_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the output path actually contains one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    writer = imageio.get_writer(
        output_path,
        fps=target_fps,
        codec="libx264",
        format="FFMPEG",
        pixelformat="yuv420p",
    )
    try:
        for frame in video_frames:
            if isinstance(frame, Image.Image):
                frame_np = np.array(frame.convert("RGB"))
            else:
                frame_np = np.asarray(frame)
            if frame_np.ndim == 2:
                # Grayscale -> RGB.
                frame_np = np.stack([frame_np] * 3, axis=-1)
            elif frame_np.shape[-1] == 4:
                # Drop the alpha channel.
                frame_np = frame_np[..., :3]
            if frame_np.shape[0] != target_height or frame_np.shape[1] != target_width:
                # Only resize when the frame is not already at target size.
                frame_np = cv2.resize(
                    frame_np,
                    (target_width, target_height),
                    interpolation=cv2.INTER_LINEAR,
                )
            writer.append_data(frame_np)
    finally:
        writer.close()
    return output_path
# -----------------------------------------------------------------------------
# Inference helpers
# -----------------------------------------------------------------------------
def load_control_images(control_path: str, width: int, height: int):
    """Load the control sketch frames and resize them to (width, height).

    Args:
        control_path: Either a directory of frame images or an .mp4 file.

    Returns:
        List of resized PIL images.

    Raises:
        ValueError: if the path is neither a folder nor an .mp4, or if no
            frames could be loaded from it.
    """
    if os.path.isdir(control_path):
        frames = load_images_from_folder(control_path)
    elif control_path.lower().endswith(".mp4"):
        frames = extract_frames_from_video(control_path)
    else:
        raise ValueError("Control Sequence must be a folder or an .mp4 file.")
    if not frames:
        raise ValueError("No frames could be loaded from the control sequence.")
    return [frame.resize((width, height)) for frame in frames]
def build_sketch_condition(
    control_images,
    detector,
    width: int,
    height: int,
    quantize_sketch: bool,
):
    """Extract line art from every control frame and build the sketch tensor.

    Args:
        control_images: List of PIL control frames.
        detector: Line-art extractor callable (LineartDetector instance).
        width / height: Target sketch resolution.
        quantize_sketch: Binarize the extracted lines to pure black/white.

    Returns:
        (sketch_pils, sketch_condition): the extracted sketch PIL images and
        a float16 tensor of shape (1, num_frames, 3, H, W) on DEVICE,
        normalized to [-1, 1].
    """

    def _to_sketch(frame):
        # The detector yields a single-channel map; replicate it to 3 channels.
        lineart = detector(np.array(frame), coarse=False)
        lineart = np.repeat(lineart[:, :, np.newaxis], 3, axis=2)
        if quantize_sketch:
            # Hard-threshold to crisp black/white lines.
            lineart = (lineart > 200).astype(np.uint8) * 255
        return Image.fromarray(lineart).resize((width, height))

    sketch_pils = [_to_sketch(frame) for frame in control_images]
    stacked = torch.cat(
        [T.ToTensor()(pil).unsqueeze(0) for pil in sketch_pils], dim=0
    )
    sketch_condition = stacked.unsqueeze(0).to(DEVICE, dtype=torch.float16)
    # Map [0, 1] -> [-1, 1].
    sketch_condition = (sketch_condition - 0.5) / 0.5
    return sketch_pils, sketch_condition
def build_matching_or_tracking_condition(
    ref_image_pil,
    controlnet_images,
    controlnet_sketch_condition,
    extractor,
    matcher,
    tracker,
    width: int,
    height: int,
    num_frames: int,
    max_points: int,
    use_tracking: bool,
    tracker_grid_size: int,
    tracker_backward_tracking: bool,
    repeat_matching: bool,
):
    """Build the combined ControlNet condition (sketch + correspondence maps).

    SuperPoint + LightGlue match keypoints between the reference image and
    the first sketch frame. The matches are rendered into point maps and
    combined with the normalized reference pixels, then, depending on the
    mode flags, either repeated over all frames (`repeat_matching`),
    propagated across frames with CoTracker (`use_tracking`), or placed on
    the first frame only with zeros elsewhere (default "matching only").
    The result is concatenated with the sketch condition along the channel
    dimension (dim=2).

    Returns:
        A float16 tensor on DEVICE whose second dimension is num_frames.
    """
    with torch.no_grad():
        # ToTensor + fp16 round-trip mirrors the rest of the pipeline; the
        # extractor/matcher themselves consume float32.
        ref_img_value = T.ToTensor()(ref_image_pil).to(DEVICE, dtype=torch.float16)
        ref_img_value = ref_img_value.to(torch.float32)
        current_img = T.ToTensor()(controlnet_images[0]).to(DEVICE, dtype=torch.float16)
        current_img = current_img.to(torch.float32)
        # Keypoint matching between the reference and the first control frame.
        feats0 = extractor.extract(ref_img_value)
        feats1 = extractor.extract(current_img)
        matches01 = matcher({"image0": feats0, "image1": feats1})
        # rbd() strips the batch dimension from each result dict.
        feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]]
        matches = matches01["matches"]
        points0 = feats0["keypoints"][matches[..., 0]].cpu().numpy()
        points1 = feats1["keypoints"][matches[..., 1]].cpu().numpy()
        # safe_round: presumably rounds and clamps coords to the image bounds
        # (see utils) — TODO confirm.
        points0 = safe_round(points0, current_img.shape)
        points1 = safe_round(points1, current_img.shape)
        # Keep at most 50 correspondences for the static point maps.
        num_selected_points = min(50, points0.shape[0])
        points0, points1 = select_multiple_points(points0, points1, num_selected_points)
        mask1, mask2 = generate_point_map(
            size=current_img.shape,
            coords0=points0,
            coords1=points1,
        )
        # (H, W) -> (1, 1, 1, H, W) so the maps stack on the channel dim.
        point_map1 = torch.from_numpy(mask1).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(
            DEVICE, dtype=torch.float16
        )
        point_map2 = torch.from_numpy(mask2).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(
            DEVICE, dtype=torch.float16
        )
        point_map = torch.cat([point_map1, point_map2], dim=2)
        # Reference pixels, normalized to [-1, 1], appended as extra channels.
        conditional_pixel_values = ref_img_value.unsqueeze(0).unsqueeze(0)
        conditional_pixel_values = (conditional_pixel_values - 0.5) / 0.5
        point_map_with_ref = torch.cat([point_map, conditional_pixel_values], dim=2)
        if repeat_matching:
            # Mode "Repeat matching": reuse the first-frame matching signal on
            # every frame.
            matching_controlnet_image = point_map_with_ref.repeat(1, num_frames, 1, 1, 1)
            controlnet_condition = torch.cat(
                [controlnet_sketch_condition, matching_controlnet_image], dim=2
            )
            return controlnet_condition
        if use_tracking:
            # Mode "Tracking": propagate the matched points across frames.
            # Undo the [-1, 1] normalization back to pixel values for CoTracker.
            video_for_tracker = (controlnet_sketch_condition * 0.5 + 0.5) * 255.0
            # Prepend the query-frame index (0) to each matched point.
            queries = np.insert(points1, 0, 0, axis=1)
            queries = torch.from_numpy(queries).to(DEVICE, torch.float32).unsqueeze(0)
            if queries.shape[1] == 0:
                # No matches found: fall back to all-zero point maps.
                tracked_mask1 = np.zeros((height, width), dtype=np.uint8)
                tracked_mask2 = np.zeros((num_frames, height, width), dtype=np.uint8)
            else:
                pred_tracks, pred_visibility = tracker(
                    video_for_tracker.to(dtype=torch.float32),
                    queries=queries,
                    grid_size=tracker_grid_size,
                    grid_query_frame=0,
                    backward_tracking=tracker_backward_tracking,
                )
                # Sub-sample trajectories (motion/visibility thresholds filter
                # the kept points) down to at most max_points.
                pred_tracks_sampled, pred_visibility_sampled, points0_sampled = (
                    sample_trajectories_with_ref(
                        pred_tracks.cpu(),
                        pred_visibility.cpu(),
                        torch.from_numpy(points0).unsqueeze(0).cpu(),
                        max_points=max_points,
                        motion_threshold=1,
                        vis_threshold=3,
                    )
                )
                if pred_tracks_sampled is None:
                    # Sampling rejected every trajectory: zero maps again.
                    tracked_mask1 = np.zeros((height, width), dtype=np.uint8)
                    tracked_mask2 = np.zeros((num_frames, height, width), dtype=np.uint8)
                else:
                    pred_tracks_sampled = pred_tracks_sampled.squeeze(0).cpu().numpy()
                    pred_visibility_sampled = pred_visibility_sampled.squeeze(0).cpu().numpy()
                    points0_sampled = points0_sampled.squeeze(0).cpu().numpy()
                    for frame_idx in range(num_frames):
                        pred_tracks_sampled[frame_idx] = safe_round(
                            pred_tracks_sampled[frame_idx],
                            current_img.shape,
                        )
                    points0_sampled = safe_round(points0_sampled, current_img.shape)
                    tracked_mask1, tracked_mask2 = generate_point_map_frames(
                        size=current_img.shape,
                        coords0=points0_sampled,
                        coords1=pred_tracks_sampled,
                        visibility=pred_visibility_sampled,
                    )
            # Reference-side map is static (repeated); per-frame map follows
            # the tracked points.
            tracked_point_map1 = (
                torch.from_numpy(tracked_mask1)
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(1, num_frames, 1, 1, 1)
                .to(DEVICE, dtype=torch.float16)
            )
            tracked_point_map2 = (
                torch.from_numpy(tracked_mask2)
                .unsqueeze(0)
                .unsqueeze(2)
                .to(DEVICE, dtype=torch.float16)
            )
            tracked_point_map = torch.cat([tracked_point_map1, tracked_point_map2], dim=2)
            conditional_pixel_values_repeat = conditional_pixel_values.repeat(
                1, num_frames, 1, 1, 1
            )
            point_map_with_ref = torch.cat(
                [tracked_point_map, conditional_pixel_values_repeat], dim=2
            )
            controlnet_condition = torch.cat(
                [controlnet_sketch_condition, point_map_with_ref], dim=2
            )
            return controlnet_condition
        # Mode "Matching only": the matching signal appears on the first
        # frame only; the remaining num_frames - 1 frames get zeros.
        original_shape = list(point_map_with_ref.shape)
        new_shape = original_shape.copy()
        new_shape[1] = num_frames - 1
        zero_tensor = torch.zeros(new_shape).to(DEVICE, dtype=torch.float16)
        matching_controlnet_image = torch.cat((point_map_with_ref, zero_tensor), dim=1)
        controlnet_condition = torch.cat(
            [controlnet_sketch_condition, matching_controlnet_image], dim=2
        )
        return controlnet_condition
def run_anidoc(
    models,
    control_path: str,
    ref_path: str,
    output_dir: str,
    width: int = DEFAULT_WIDTH,
    height: int = DEFAULT_HEIGHT,
    num_frames: int = DEFAULT_NUM_FRAMES,
    seed: int = DEFAULT_SEED,
    max_points: int = DEFAULT_MAX_POINTS,
    noise_aug: float = DEFAULT_NOISE_AUG,
    use_tracking: bool = True,
    repeat_matching: bool = False,
    quantize_sketch: bool = True,
    tracker_grid_size: int = 8,
    tracker_backward_tracking: bool = False,
):
    """Run one full AniDoc colorization pass and write the result video.

    Args:
        models: Dict produced by load_runtime_models().
        control_path: Sketch sequence — a folder of frames or an .mp4 file.
        ref_path: Character design reference image path.
        output_dir: Root directory for this run's artifacts.
        width / height / num_frames: Generation resolution and length.
        seed: RNG seed for reproducibility.
        max_points: Cap on tracked correspondence points.
        noise_aug: Noise augmentation strength passed to the pipeline.
        use_tracking / repeat_matching: Motion-guidance mode flags.
        quantize_sketch: Binarize the extracted line art.
        tracker_grid_size / tracker_backward_tracking: CoTracker settings.

    Returns:
        Path to the final .mp4, re-encoded at the input video's native
        resolution and frame rate when the control input is an .mp4.

    Raises:
        FileNotFoundError: if the final video file was not written.
    """
    pipe = models["pipe"]
    detector = models["detector"]
    extractor = models["extractor"]
    matcher = models["matcher"]
    tracker = models["tracker"]
    os.makedirs(output_dir, exist_ok=True)
    # Prepare inputs: control frames and the reference image at model size.
    control_images = load_control_images(control_path, width=width, height=height)
    ref_image_pil = load_image(ref_path).resize((width, height))
    controlnet_images, controlnet_sketch_condition = build_sketch_condition(
        control_images=control_images,
        detector=detector,
        width=width,
        height=height,
        quantize_sketch=quantize_sketch,
    )
    controlnet_condition = build_matching_or_tracking_condition(
        ref_image_pil=ref_image_pil,
        controlnet_images=controlnet_images,
        controlnet_sketch_condition=controlnet_sketch_condition,
        extractor=extractor,
        matcher=matcher,
        tracker=tracker,
        width=width,
        height=height,
        num_frames=num_frames,
        max_points=max_points,
        use_tracking=use_tracking,
        tracker_grid_size=tracker_grid_size,
        tracker_backward_tracking=tracker_backward_tracking,
        repeat_matching=repeat_matching,
    )
    # When the control input is a video, capture its native fps/size so the
    # final output can be written back with the same characteristics.
    input_fps = DEFAULT_FPS
    target_width = width
    target_height = height
    if control_path.lower().endswith(".mp4"):
        input_fps = get_video_fps(control_path)
        target_width, target_height = get_video_dimensions(control_path)
    generator = torch.manual_seed(seed)
    with torch.inference_mode():
        video_frames = pipe(
            ref_image_pil,
            controlnet_condition,
            height=height,
            width=width,
            num_frames=num_frames,
            decode_chunk_size=DEFAULT_DECODE_CHUNK_SIZE,
            motion_bucket_id=DEFAULT_MOTION_BUCKET_ID,
            fps=input_fps,
            noise_aug_strength=noise_aug,
            generator=generator,
        ).frames[0]
    # Derive per-run output names from the input file names.
    ref_base_name = os.path.splitext(os.path.basename(ref_path))[0]
    control_base_name = os.path.splitext(os.path.basename(control_path))[0]
    support_dir = os.path.join(output_dir, f"{ref_base_name}_{control_base_name}")
    os.makedirs(support_dir, exist_ok=True)
    out_file = support_dir + ".mp4"
    # Side-by-side GIF (reference / sketches / result) for quick inspection.
    export_gif_side_by_side_complete_ablation(
        ref_image_pil,
        controlnet_images,
        video_frames,
        out_file.replace(".mp4", ".gif"),
        support_dir,
        6,
    )
    final_output_path = os.path.join(
        support_dir, f"{ref_base_name}_{control_base_name}_final.mp4"
    )
    final_output_path = write_video_frames_to_mp4(
        video_frames=video_frames,
        output_path=final_output_path,
        target_width=target_width,
        target_height=target_height,
        target_fps=input_fps,
    )
    if not os.path.exists(final_output_path):
        raise FileNotFoundError("Final output .mp4 file was not written.")
    return final_output_path
# -----------------------------------------------------------------------------
# Gradio callback
# -----------------------------------------------------------------------------
@spaces.GPU()
def generate(
    control_sequence,
    ref_image,
    seed,
    max_points,
    mode,
    noise_aug,
    quantize_sketch,
    tracker_grid_size,
    tracker_backward_tracking,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio callback: validate inputs, load models, run inference, clean up.

    Models are loaded per request and always released in the finally block,
    even when inference fails.
    """
    # `progress` only exists so gradio tracks tqdm bars; unused otherwise.
    del progress
    if not control_sequence:
        raise gr.Error("Please provide a control video.")
    if not ref_image:
        raise gr.Error("Please provide a reference image.")
    run_dir = f"results_{uuid.uuid4()}"
    loaded_models = None
    try:
        loaded_models = load_runtime_models()
        result_path = run_anidoc(
            models=loaded_models,
            control_path=control_sequence,
            ref_path=ref_image,
            output_dir=run_dir,
            width=DEFAULT_WIDTH,
            height=DEFAULT_HEIGHT,
            num_frames=DEFAULT_NUM_FRAMES,
            seed=int(seed),
            max_points=int(max_points),
            noise_aug=float(noise_aug),
            use_tracking=(mode == "Tracking"),
            repeat_matching=(mode == "Repeat matching"),
            quantize_sketch=bool(quantize_sketch),
            tracker_grid_size=int(tracker_grid_size),
            tracker_backward_tracking=bool(tracker_backward_tracking),
        )
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback.
        raise gr.Error(f"Error during inference: {e}") from e
    finally:
        cleanup_runtime_models(loaded_models)
    return result_path
# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------
# One-time startup cost: make sure all weights are on disk before serving.
ensure_weights_downloaded()

# Page-level styling: centered column plus a smaller "tips" note.
CSS = """
div#col-container {
    margin: 0 auto;
    max-width: 982px;
}
.small-note {
    font-size: 0.92em;
    opacity: 0.9;
}
"""

with gr.Blocks(css=CSS) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# AniDoc: Animation Creation Made Easier")
        gr.Markdown(
            "AniDoc colorizes a sequence of sketches based on a character design "
            "reference with high fidelity, even when the sketches significantly "
            "differ in pose and scale."
        )
        # Badge links: repo, project page, paper, duplicate-space, follow.
        gr.HTML(
            """
            <div style="display:flex;column-gap:4px;">
                <a href="https://github.com/yihao-meng/AniDoc">
                    <img src="https://img.shields.io/badge/GitHub-Repo-blue">
                </a>
                <a href="https://yihao-meng.github.io/AniDoc_demo/">
                    <img src="https://img.shields.io/badge/Project-Page-green">
                </a>
                <a href="https://arxiv.org/pdf/2412.14173">
                    <img src="https://img.shields.io/badge/ArXiv-Paper-red">
                </a>
                <a href="https://huggingface.co/spaces/fffiloni/AniDoc?duplicate=true">
                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
                </a>
                <a href="https://huggingface.co/fffiloni">
                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
                </a>
            </div>
            """
        )
        with gr.Row():
            # Left column: inputs and advanced settings.
            with gr.Column():
                control_sequence = gr.Video(label="Control Sequence", format="mp4")
                ref_image = gr.Image(label="Reference Image", type="filepath")
                submit_btn = gr.Button("Submit", variant="primary")
                with gr.Accordion("Advanced settings", open=False):
                    seed = gr.Number(
                        value=DEFAULT_SEED,
                        precision=0,
                        label="Seed",
                        info="Use the same seed to reproduce a similar result with the same inputs and settings.",
                    )
                    max_points = gr.Slider(
                        minimum=5,
                        maximum=30,
                        value=DEFAULT_MAX_POINTS,
                        step=1,
                        label="Tracking points",
                        info="Maximum number of motion points kept for correspondence tracking. More points can improve motion guidance but may add instability or compute cost.",
                    )
                    mode = gr.Radio(
                        choices=["Tracking", "Matching only", "Repeat matching"],
                        value="Tracking",
                        label="Motion guidance mode",
                        info="Tracking usually gives the best temporal consistency. Matching only is simpler. Repeat matching reuses the same matching signal across frames.",
                    )
                    noise_aug = gr.Slider(
                        minimum=0.0,
                        maximum=0.10,
                        value=DEFAULT_NOISE_AUG,
                        step=0.01,
                        label="Noise augmentation",
                        info="Controls how much noise is added before generation. Small changes can affect stylization and stability.",
                    )
                    quantize_sketch = gr.Checkbox(
                        value=True,
                        label="Quantize extracted sketch",
                        info="Turns the extracted line art into a stronger black-and-white sketch. Usually helps keep lines cleaner and more stable.",
                    )
                    tracker_grid_size = gr.Slider(
                        minimum=4,
                        maximum=16,
                        value=8,
                        step=1,
                        label="Tracker grid size",
                        info="Controls the density of the tracking grid. Smaller values can track more densely but may cost more computation.",
                    )
                    tracker_backward_tracking = gr.Checkbox(
                        value=False,
                        label="Backward tracking",
                        info="Tracks points in both directions instead of only forward. Can sometimes improve difficult motion, but may be slower.",
                    )
            # Right column: result video, canned examples, usage tips.
            with gr.Column():
                video_result = gr.Video(label="Result")
                gr.Examples(
                    examples=[
                        ["data_test/sample1.mp4", "custom_examples/sample1.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False],
                        ["data_test/sample2.mp4", "custom_examples/sample2.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False],
                        ["data_test/sample3.mp4", "custom_examples/sample3.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False],
                        ["data_test/sample4.mp4", "custom_examples/sample4.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False],
                    ],
                    inputs=[
                        control_sequence,
                        ref_image,
                        seed,
                        max_points,
                        mode,
                        noise_aug,
                        quantize_sketch,
                        tracker_grid_size,
                        tracker_backward_tracking,
                    ],
                )
                gr.Markdown(
                    '<div class="small-note">'
                    "<b>Tips:</b> Tracking mode is the best default choice. "
                    "If the result feels too unstable, try lowering Tracking points or disabling Backward tracking. "
                    "If the sketch looks too soft, keep sketch quantization enabled."
                    "</div>"
                )
    # Wire the submit button to the GPU-decorated callback.
    submit_btn.click(
        fn=generate,
        inputs=[
            control_sequence,
            ref_image,
            seed,
            max_points,
            mode,
            noise_aug,
            quantize_sketch,
            tracker_grid_size,
            tracker_backward_tracking,
        ],
        outputs=[video_result],
        api_visibility="private",
    )

demo.queue().launch(ssr_mode=False, show_error=True)