def _patch_asyncio_event_loop_del():
    """
    Patch a noisy asyncio teardown issue sometimes seen in Spaces environments.

    In some runtime/container combinations, Python may try to close an already
    invalid file descriptor when the event loop is garbage-collected. We
    silence only that specific harmless case.
    """
    try:
        import asyncio.base_events as base_events

        original_del = getattr(base_events.BaseEventLoop, "__del__", None)
        if original_del is None:
            # Nothing to wrap on this Python version.
            return

        def patched_del(self):
            try:
                original_del(self)
            except ValueError as e:
                # Re-raise any ValueError that is not the known harmless
                # "Invalid file descriptor" teardown message.
                if "Invalid file descriptor" not in str(e):
                    raise

        base_events.BaseEventLoop.__del__ = patched_del
    except Exception:
        # Best-effort patch: never let it break application startup.
        pass


_patch_asyncio_event_loop_del()

import os
import sys
import gc
import cv2
import types
import uuid
from glob import glob

import gradio as gr
import imageio.v2 as imageio
import numpy as np
import spaces
import torch
import torchvision.transforms as T
from packaging import version
from PIL import Image

# Make project-local packages importable when launched from the repo root.
sys.path.insert(0, ".")

# -----------------------------------------------------------------------------
# Compatibility shims
# -----------------------------------------------------------------------------
import huggingface_hub

# NOTE(review): presumably some dependency still calls the removed
# `huggingface_hub.cached_download`; alias it to `hf_hub_download`.
if not hasattr(huggingface_hub, "cached_download"):
    from huggingface_hub import hf_hub_download as _hf_hub_download_compat

    huggingface_hub.cached_download = _hf_hub_download_compat

import torchvision.transforms.functional as TVF

# NOTE(review): recreates the `torchvision.transforms.functional_tensor`
# module (absent in newer torchvision) that some dependency apparently
# imports; only `rgb_to_grayscale` is provided.
if "torchvision.transforms.functional_tensor" not in sys.modules:
    functional_tensor = types.ModuleType("torchvision.transforms.functional_tensor")
    functional_tensor.rgb_to_grayscale = TVF.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor

# -----------------------------------------------------------------------------
# Imports that rely on the shims
# -----------------------------------------------------------------------------
from huggingface_hub import snapshot_download, hf_hub_download
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_xformers_available
from LightGlue.lightglue import LightGlue, SuperPoint
from LightGlue.lightglue.utils import rbd
from cotracker.predictor import CoTrackerPredictor, sample_trajectories_with_ref
from lineart_extractor.annotator.lineart import LineartDetector
from models_diffusers.controlnet_svd import ControlNetSVDModel
from models_diffusers.unet_spatio_temporal_condition import (
    UNetSpatioTemporalConditionModel,
)
from pipelines.AniDoc import AniDocPipeline
from utils import (
    extract_frames_from_video,
    export_gif_side_by_side_complete_ablation,
    generate_point_map,
    generate_point_map_frames,
    load_images_from_folder,
    safe_round,
    select_multiple_points,
)

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Local directories the weight downloads below populate.
PRETRAINED_SVD_DIR = "pretrained_weights/stable-video-diffusion-img2vid-xt"
PRETRAINED_ANIDOC_DIR = "pretrained_weights/anidoc"
PRETRAINED_CONTROLNET_DIR = "pretrained_weights/anidoc/controlnet"
COTRACKER_PATH = "pretrained_weights/cotracker2.pth"

# Generation defaults (also used as the fixed size/frame-count in the UI).
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 320
DEFAULT_NUM_FRAMES = 14
DEFAULT_SEED = 42
DEFAULT_MAX_POINTS = 10
DEFAULT_FPS = 7
DEFAULT_MOTION_BUCKET_ID = 127
DEFAULT_NOISE_AUG = 0.02
DEFAULT_DECODE_CHUNK_SIZE = 8

DEVICE = "cuda"
DTYPE = torch.float16


# -----------------------------------------------------------------------------
# Startup downloads only
# -----------------------------------------------------------------------------
def ensure_weights_downloaded() -> None:
    """Download SVD, AniDoc and CoTracker weights into `pretrained_weights/`.

    Runs at import/startup time; `snapshot_download`/`hf_hub_download` are
    no-ops for files already present in the local cache directories.
    """
    os.makedirs("pretrained_weights", exist_ok=True)
    os.makedirs(PRETRAINED_SVD_DIR, exist_ok=True)
    snapshot_download(
        repo_id="stabilityai/stable-video-diffusion-img2vid-xt",
        local_dir=PRETRAINED_SVD_DIR,
    )
    snapshot_download(
        repo_id="Yhmeng1106/anidoc",
        local_dir="./pretrained_weights",
    )
    hf_hub_download(
        repo_id="facebook/cotracker",
        filename="cotracker2.pth",
        local_dir="./pretrained_weights",
    )
# -----------------------------------------------------------------------------
# Runtime model loading (GPU-only, per run)
# -----------------------------------------------------------------------------
def load_runtime_models():
    """Load the UNet, ControlNet, AniDoc pipeline and auxiliary models on CUDA.

    Returns:
        dict with keys "unet", "controlnet", "pipe", "detector", "extractor",
        "matcher", "tracker". Pass it to `run_anidoc` and finally to
        `cleanup_runtime_models`.

    Raises:
        RuntimeError: if CUDA or xformers is unavailable.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available in the current execution context.")

    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        PRETRAINED_ANIDOC_DIR,
        subfolder="unet",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        custom_resume=True,
    )
    unet.to(DEVICE, DTYPE)

    controlnet = ControlNetSVDModel.from_pretrained(PRETRAINED_CONTROLNET_DIR)
    controlnet.to(DEVICE, DTYPE)

    if is_xformers_available():
        import xformers  # noqa: F401

        # Parse only to fail fast on a broken xformers install.
        _ = version.parse(xformers.__version__)
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise RuntimeError("xformers is not available. Make sure it is installed correctly.")

    pipe = AniDocPipeline.from_pretrained(
        PRETRAINED_SVD_DIR,
        unet=unet,
        controlnet=controlnet,
        low_cpu_mem_usage=False,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe.to(DEVICE)

    detector = LineartDetector(DEVICE)
    extractor = SuperPoint(max_num_keypoints=2000).eval().to(DEVICE)
    matcher = LightGlue(features="superpoint").eval().to(DEVICE)

    tracker = CoTrackerPredictor(
        checkpoint=COTRACKER_PATH,
        shift_grid=0,
    )
    tracker.requires_grad_(False)
    # CoTracker is kept in float32 even though the diffusion stack runs fp16.
    tracker.to(DEVICE, dtype=torch.float32)

    return {
        "unet": unet,
        "controlnet": controlnet,
        "pipe": pipe,
        "detector": detector,
        "extractor": extractor,
        "matcher": matcher,
        "tracker": tracker,
    }


def cleanup_runtime_models(models) -> None:
    """Drop model references and free CUDA memory (best effort, never raises).

    Args:
        models: the dict returned by `load_runtime_models`, or None.
    """
    try:
        if models:
            # Fix: the original looped `del value` over `models.values()`,
            # which only unbinds the loop-local name and frees nothing.
            # Clearing the dict actually releases the references so the
            # gc/empty_cache calls below can reclaim memory.
            models.clear()
    finally:
        gc.collect()
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
            try:
                torch.cuda.ipc_collect()
            except Exception:
                pass


# -----------------------------------------------------------------------------
# Video metadata / final video writer
# -----------------------------------------------------------------------------
def get_video_dimensions(video_path: str) -> tuple[int, int]:
    """Return (width, height) of a video file.

    Raises:
        ValueError: if the file cannot be opened or reports non-positive size.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Unable to open input video: {video_path}")
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()
    if width <= 0 or height <= 0:
        raise ValueError(f"Invalid input video dimensions: {video_path}")
    return width, height


def get_video_fps(video_path: str) -> float:
    """Return the FPS of a video file, falling back to DEFAULT_FPS when unknown.

    Raises:
        ValueError: if the file cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Unable to open input video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    if fps <= 0:
        # Some containers report 0/NaN fps; use the app default instead.
        fps = DEFAULT_FPS
    return float(fps)


def write_video_frames_to_mp4(
    video_frames,
    output_path: str,
    target_width: int,
    target_height: int,
    target_fps: float,
) -> str:
    """Encode frames (PIL images or ndarrays) to an H.264 mp4.

    Frames are converted to 3-channel RGB (grayscale stacked, alpha dropped)
    and resized to (target_width, target_height) before encoding.

    Returns:
        output_path, for caller convenience.
    """
    # Fix: os.makedirs("") raises FileNotFoundError; only create the parent
    # directory when the output path actually has one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    writer = imageio.get_writer(
        output_path,
        fps=target_fps,
        codec="libx264",
        format="FFMPEG",
        pixelformat="yuv420p",
    )
    try:
        for frame in video_frames:
            if isinstance(frame, Image.Image):
                frame_np = np.array(frame.convert("RGB"))
            else:
                frame_np = np.asarray(frame)
            if frame_np.ndim == 2:
                # Grayscale -> RGB by channel stacking.
                frame_np = np.stack([frame_np] * 3, axis=-1)
            elif frame_np.shape[-1] == 4:
                # Drop the alpha channel.
                frame_np = frame_np[..., :3]
            resized = cv2.resize(
                frame_np,
                (target_width, target_height),
                interpolation=cv2.INTER_LINEAR,
            )
            writer.append_data(resized)
    finally:
        # Always finalize the container, even if a frame conversion fails.
        writer.close()
    return output_path


# -----------------------------------------------------------------------------
# Inference helpers
# -----------------------------------------------------------------------------
def load_control_images(control_path: str, width: int, height: int):
    """Load control frames from a folder or an .mp4 and resize to (width, height).

    Raises:
        ValueError: if the path is neither a directory nor an .mp4 file, or if
            no frames could be loaded from it.
    """
    if os.path.isdir(control_path):
        control_images = load_images_from_folder(control_path)
    elif control_path.lower().endswith(".mp4"):
        control_images = extract_frames_from_video(control_path)
    else:
        raise ValueError("Control Sequence must be a folder or an .mp4 file.")
    if not control_images:
        raise ValueError("No frames could be loaded from the control sequence.")
    return [img.resize((width, height)) for img in control_images]
def build_sketch_condition(
    control_images,
    detector,
    width: int,
    height: int,
    quantize_sketch: bool,
):
    """Extract line art from each control frame and build the sketch condition.

    Args:
        control_images: PIL frames (already resized by the caller).
        detector: LineartDetector instance.
        quantize_sketch: if True, threshold the line art (>200) into pure
            black/white before use.

    Returns:
        (controlnet_images, controlnet_sketch_condition): the sketch PIL
        frames, and a (1, F, 3, H, W) fp16 tensor normalized to [-1, 1]
        on DEVICE.
    """
    controlnet_images = []
    for img in control_images:
        sketch = np.array(img)
        sketch = detector(sketch, coarse=False)
        # Detector output is single-channel; replicate to 3 channels.
        sketch = np.repeat(sketch[:, :, np.newaxis], 3, axis=2)
        if quantize_sketch:
            sketch = (sketch > 200).astype(np.uint8) * 255
        sketch = Image.fromarray(sketch).resize((width, height))
        controlnet_images.append(sketch)
    sketch_tensors = [T.ToTensor()(img).unsqueeze(0) for img in controlnet_images]
    controlnet_sketch_condition = torch.cat(sketch_tensors, dim=0).unsqueeze(0).to(
        DEVICE, dtype=torch.float16
    )
    # Map [0, 1] -> [-1, 1].
    controlnet_sketch_condition = (controlnet_sketch_condition - 0.5) / 0.5
    return controlnet_images, controlnet_sketch_condition


def build_matching_or_tracking_condition(
    ref_image_pil,
    controlnet_images,
    controlnet_sketch_condition,
    extractor,
    matcher,
    tracker,
    width: int,
    height: int,
    num_frames: int,
    max_points: int,
    use_tracking: bool,
    tracker_grid_size: int,
    tracker_backward_tracking: bool,
    repeat_matching: bool,
):
    """Build the full ControlNet condition from sketch + correspondence maps.

    SuperPoint/LightGlue match keypoints between the reference image and the
    first sketch frame, producing point maps that localize the reference
    character in the sketch. Three modes (checked in this order):

    - repeat_matching: the first-frame matching map (+ reference pixels) is
      repeated across all frames.
    - use_tracking: the matched points are tracked through all sketch frames
      with CoTracker, giving a per-frame point map.
    - otherwise: the matching map is placed on frame 0 and zero-padded for
      the remaining frames.

    Returns:
        The conditioning tensor: sketch condition concatenated with the
        point-map/reference channels along dim=2 (the channel axis).
    """
    with torch.no_grad():
        ref_img_value = T.ToTensor()(ref_image_pil).to(DEVICE, dtype=torch.float16)
        # SuperPoint/LightGlue run in float32.
        ref_img_value = ref_img_value.to(torch.float32)
        current_img = T.ToTensor()(controlnet_images[0]).to(DEVICE, dtype=torch.float16)
        current_img = current_img.to(torch.float32)
        feats0 = extractor.extract(ref_img_value)
        feats1 = extractor.extract(current_img)
        matches01 = matcher({"image0": feats0, "image1": feats1})
        # Remove batch dimension from all LightGlue outputs.
        feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]]
        matches = matches01["matches"]
        points0 = feats0["keypoints"][matches[..., 0]].cpu().numpy()
        points1 = feats1["keypoints"][matches[..., 1]].cpu().numpy()
        # Clamp/round coordinates into image bounds.
        points0 = safe_round(points0, current_img.shape)
        points1 = safe_round(points1, current_img.shape)
        # Cap the matched correspondences at 50 before subsampling.
        num_selected_points = min(50, points0.shape[0])
        points0, points1 = select_multiple_points(points0, points1, num_selected_points)
        mask1, mask2 = generate_point_map(
            size=current_img.shape,
            coords0=points0,
            coords1=points1,
        )
        # (H, W) masks -> (1, 1, 1, H, W) fp16 tensors.
        point_map1 = torch.from_numpy(mask1).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(
            DEVICE, dtype=torch.float16
        )
        point_map2 = torch.from_numpy(mask2).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(
            DEVICE, dtype=torch.float16
        )
        point_map = torch.cat([point_map1, point_map2], dim=2)
        # Reference pixels, normalized to [-1, 1], appended as extra channels.
        conditional_pixel_values = ref_img_value.unsqueeze(0).unsqueeze(0)
        conditional_pixel_values = (conditional_pixel_values - 0.5) / 0.5
        point_map_with_ref = torch.cat([point_map, conditional_pixel_values], dim=2)

        if repeat_matching:
            # Reuse the frame-0 matching map on every frame.
            matching_controlnet_image = point_map_with_ref.repeat(1, num_frames, 1, 1, 1)
            controlnet_condition = torch.cat(
                [controlnet_sketch_condition, matching_controlnet_image], dim=2
            )
            return controlnet_condition

        if use_tracking:
            # Undo the [-1, 1] normalization: CoTracker expects 0-255 input.
            video_for_tracker = (controlnet_sketch_condition * 0.5 + 0.5) * 255.0
            # Prepend the query-frame index (0) to each (x, y) point.
            queries = np.insert(points1, 0, 0, axis=1)
            queries = torch.from_numpy(queries).to(DEVICE, torch.float32).unsqueeze(0)
            if queries.shape[1] == 0:
                # No matches at all: fall back to empty point maps.
                tracked_mask1 = np.zeros((height, width), dtype=np.uint8)
                tracked_mask2 = np.zeros((num_frames, height, width), dtype=np.uint8)
            else:
                pred_tracks, pred_visibility = tracker(
                    video_for_tracker.to(dtype=torch.float32),
                    queries=queries,
                    grid_size=tracker_grid_size,
                    grid_query_frame=0,
                    backward_tracking=tracker_backward_tracking,
                )
                # Keep at most max_points stable, visible trajectories.
                pred_tracks_sampled, pred_visibility_sampled, points0_sampled = (
                    sample_trajectories_with_ref(
                        pred_tracks.cpu(),
                        pred_visibility.cpu(),
                        torch.from_numpy(points0).unsqueeze(0).cpu(),
                        max_points=max_points,
                        motion_threshold=1,
                        vis_threshold=3,
                    )
                )
                if pred_tracks_sampled is None:
                    # Sampling rejected every trajectory: empty point maps.
                    tracked_mask1 = np.zeros((height, width), dtype=np.uint8)
                    tracked_mask2 = np.zeros((num_frames, height, width), dtype=np.uint8)
                else:
                    pred_tracks_sampled = pred_tracks_sampled.squeeze(0).cpu().numpy()
                    pred_visibility_sampled = pred_visibility_sampled.squeeze(0).cpu().numpy()
                    points0_sampled = points0_sampled.squeeze(0).cpu().numpy()
                    for frame_idx in range(num_frames):
                        pred_tracks_sampled[frame_idx] = safe_round(
                            pred_tracks_sampled[frame_idx],
                            current_img.shape,
                        )
                    points0_sampled = safe_round(points0_sampled, current_img.shape)
                    tracked_mask1, tracked_mask2 = generate_point_map_frames(
                        size=current_img.shape,
                        coords0=points0_sampled,
                        coords1=pred_tracks_sampled,
                        visibility=pred_visibility_sampled,
                    )
            # Reference-side map is static: repeat across all frames.
            tracked_point_map1 = (
                torch.from_numpy(tracked_mask1)
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(1, num_frames, 1, 1, 1)
                .to(DEVICE, dtype=torch.float16)
            )
            # Sketch-side map is per-frame: add batch and channel dims.
            tracked_point_map2 = (
                torch.from_numpy(tracked_mask2)
                .unsqueeze(0)
                .unsqueeze(2)
                .to(DEVICE, dtype=torch.float16)
            )
            tracked_point_map = torch.cat([tracked_point_map1, tracked_point_map2], dim=2)
            conditional_pixel_values_repeat = conditional_pixel_values.repeat(
                1, num_frames, 1, 1, 1
            )
            point_map_with_ref = torch.cat(
                [tracked_point_map, conditional_pixel_values_repeat], dim=2
            )
            controlnet_condition = torch.cat(
                [controlnet_sketch_condition, point_map_with_ref], dim=2
            )
            return controlnet_condition

        # Matching-only mode: frame 0 carries the matching map, the remaining
        # num_frames - 1 frames are zero-padded along the frame axis.
        original_shape = list(point_map_with_ref.shape)
        new_shape = original_shape.copy()
        new_shape[1] = num_frames - 1
        zero_tensor = torch.zeros(new_shape).to(DEVICE, dtype=torch.float16)
        matching_controlnet_image = torch.cat((point_map_with_ref, zero_tensor), dim=1)
        controlnet_condition = torch.cat(
            [controlnet_sketch_condition, matching_controlnet_image], dim=2
        )
        return controlnet_condition


def run_anidoc(
    models,
    control_path: str,
    ref_path: str,
    output_dir: str,
    width: int = DEFAULT_WIDTH,
    height: int = DEFAULT_HEIGHT,
    num_frames: int = DEFAULT_NUM_FRAMES,
    seed: int = DEFAULT_SEED,
    max_points: int = DEFAULT_MAX_POINTS,
    noise_aug: float = DEFAULT_NOISE_AUG,
    use_tracking: bool = True,
    repeat_matching: bool = False,
    quantize_sketch: bool = True,
    tracker_grid_size: int = 8,
    tracker_backward_tracking: bool = False,
):
    """Run the full AniDoc pipeline for one control sequence + reference image.

    Args:
        models: dict from `load_runtime_models`.
        control_path: folder of frames or an .mp4 file.
        ref_path: reference (character design) image path.
        output_dir: directory where results are written.

    Returns:
        Path to the final colorized .mp4, re-encoded at the input video's
        original resolution and fps when the control input is an .mp4.

    Raises:
        FileNotFoundError: if the final .mp4 was not written.
    """
    pipe = models["pipe"]
    detector = models["detector"]
    extractor = models["extractor"]
    matcher = models["matcher"]
    tracker = models["tracker"]

    os.makedirs(output_dir, exist_ok=True)

    control_images = load_control_images(control_path, width=width, height=height)
    ref_image_pil = load_image(ref_path).resize((width, height))

    controlnet_images, controlnet_sketch_condition = build_sketch_condition(
        control_images=control_images,
        detector=detector,
        width=width,
        height=height,
        quantize_sketch=quantize_sketch,
    )

    controlnet_condition = build_matching_or_tracking_condition(
        ref_image_pil=ref_image_pil,
        controlnet_images=controlnet_images,
        controlnet_sketch_condition=controlnet_sketch_condition,
        extractor=extractor,
        matcher=matcher,
        tracker=tracker,
        width=width,
        height=height,
        num_frames=num_frames,
        max_points=max_points,
        use_tracking=use_tracking,
        tracker_grid_size=tracker_grid_size,
        tracker_backward_tracking=tracker_backward_tracking,
        repeat_matching=repeat_matching,
    )

    # Preserve the source video's fps/size for the final re-encoded output;
    # folder inputs fall back to the defaults.
    input_fps = DEFAULT_FPS
    target_width = width
    target_height = height
    if control_path.lower().endswith(".mp4"):
        input_fps = get_video_fps(control_path)
        target_width, target_height = get_video_dimensions(control_path)

    generator = torch.manual_seed(seed)
    with torch.inference_mode():
        video_frames = pipe(
            ref_image_pil,
            controlnet_condition,
            height=height,
            width=width,
            num_frames=num_frames,
            decode_chunk_size=DEFAULT_DECODE_CHUNK_SIZE,
            motion_bucket_id=DEFAULT_MOTION_BUCKET_ID,
            fps=input_fps,
            noise_aug_strength=noise_aug,
            generator=generator,
        ).frames[0]

    ref_base_name = os.path.splitext(os.path.basename(ref_path))[0]
    control_base_name = os.path.splitext(os.path.basename(control_path))[0]
    support_dir = os.path.join(output_dir, f"{ref_base_name}_{control_base_name}")
    os.makedirs(support_dir, exist_ok=True)
    out_file = support_dir + ".mp4"
    # Side-by-side gif + per-frame dumps for inspection/debugging.
    export_gif_side_by_side_complete_ablation(
        ref_image_pil,
        controlnet_images,
        video_frames,
        out_file.replace(".mp4", ".gif"),
        support_dir,
        6,
    )

    final_output_path = os.path.join(
        support_dir, f"{ref_base_name}_{control_base_name}_final.mp4"
    )
    final_output_path = write_video_frames_to_mp4(
        video_frames=video_frames,
        output_path=final_output_path,
        target_width=target_width,
        target_height=target_height,
        target_fps=input_fps,
    )
    if not os.path.exists(final_output_path):
        raise FileNotFoundError("Final output .mp4 file was not written.")
    return final_output_path


# -----------------------------------------------------------------------------
# Gradio callback
# -----------------------------------------------------------------------------
@spaces.GPU()
def generate(
    control_sequence,
    ref_image,
    seed,
    max_points,
    mode,
    noise_aug,
    quantize_sketch,
    tracker_grid_size,
    tracker_backward_tracking,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio entry point: validate inputs, run AniDoc, and clean up models.

    Models are loaded per call and released in `finally`, so GPU memory is
    freed even when inference fails. Any exception is re-raised as gr.Error
    so the UI shows a readable message.
    """
    # `progress` exists only so Gradio tracks tqdm progress; unused here.
    del progress
    if not control_sequence:
        raise gr.Error("Please provide a control video.")
    if not ref_image:
        raise gr.Error("Please provide a reference image.")
    models = None
    # Unique per-request directory to avoid collisions between sessions.
    output_dir = f"results_{uuid.uuid4()}"
    use_tracking = mode == "Tracking"
    repeat_matching = mode == "Repeat matching"
    try:
        models = load_runtime_models()
        output_path = run_anidoc(
            models=models,
            control_path=control_sequence,
            ref_path=ref_image,
            output_dir=output_dir,
            width=DEFAULT_WIDTH,
            height=DEFAULT_HEIGHT,
            num_frames=DEFAULT_NUM_FRAMES,
            seed=int(seed),
            max_points=int(max_points),
            noise_aug=float(noise_aug),
            use_tracking=use_tracking,
            repeat_matching=repeat_matching,
            quantize_sketch=bool(quantize_sketch),
            tracker_grid_size=int(tracker_grid_size),
            tracker_backward_tracking=bool(tracker_backward_tracking),
        )
        return output_path
    except Exception as e:
        raise gr.Error(f"Error during inference: {e}") from e
    finally:
        cleanup_runtime_models(models)


# -----------------------------------------------------------------------------
# UI
# -----------------------------------------------------------------------------
ensure_weights_downloaded()

CSS = """
div#col-container {
    margin: 0 auto;
    max-width: 982px;
}
.small-note {
    font-size: 0.92em;
    opacity: 0.9;
}
"""

# NOTE(review): the Gradio Blocks UI definition below is truncated in this
# chunk of the file; its opening fragment is preserved as-is.
with gr.Blocks(css=CSS) as demo: with gr.Column(elem_id="col-container"): gr.Markdown("# AniDoc: Animation Creation Made Easier") gr.Markdown( "AniDoc colorizes a sequence 
of sketches based on a character design " "reference with high fidelity, even when the sketches significantly " "differ in pose and scale." ) gr.HTML( """
""" ) with gr.Row(): with gr.Column(): control_sequence = gr.Video(label="Control Sequence", format="mp4") ref_image = gr.Image(label="Reference Image", type="filepath") submit_btn = gr.Button("Submit", variant="primary") with gr.Accordion("Advanced settings", open=False): seed = gr.Number( value=DEFAULT_SEED, precision=0, label="Seed", info="Use the same seed to reproduce a similar result with the same inputs and settings.", ) max_points = gr.Slider( minimum=5, maximum=30, value=DEFAULT_MAX_POINTS, step=1, label="Tracking points", info="Maximum number of motion points kept for correspondence tracking. More points can improve motion guidance but may add instability or compute cost.", ) mode = gr.Radio( choices=["Tracking", "Matching only", "Repeat matching"], value="Tracking", label="Motion guidance mode", info="Tracking usually gives the best temporal consistency. Matching only is simpler. Repeat matching reuses the same matching signal across frames.", ) noise_aug = gr.Slider( minimum=0.0, maximum=0.10, value=DEFAULT_NOISE_AUG, step=0.01, label="Noise augmentation", info="Controls how much noise is added before generation. Small changes can affect stylization and stability.", ) quantize_sketch = gr.Checkbox( value=True, label="Quantize extracted sketch", info="Turns the extracted line art into a stronger black-and-white sketch. Usually helps keep lines cleaner and more stable.", ) tracker_grid_size = gr.Slider( minimum=4, maximum=16, value=8, step=1, label="Tracker grid size", info="Controls the density of the tracking grid. Smaller values can track more densely but may cost more computation.", ) tracker_backward_tracking = gr.Checkbox( value=False, label="Backward tracking", info="Tracks points in both directions instead of only forward. 
Can sometimes improve difficult motion, but may be slower.", ) with gr.Column(): video_result = gr.Video(label="Result") gr.Examples( examples=[ ["data_test/sample1.mp4", "custom_examples/sample1.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False], ["data_test/sample2.mp4", "custom_examples/sample2.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False], ["data_test/sample3.mp4", "custom_examples/sample3.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False], ["data_test/sample4.mp4", "custom_examples/sample4.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False], ], inputs=[ control_sequence, ref_image, seed, max_points, mode, noise_aug, quantize_sketch, tracker_grid_size, tracker_backward_tracking, ], ) gr.Markdown( '