Spaces:
Running on Zero
def _patch_asyncio_event_loop_del():
    """Silence a harmless asyncio teardown error seen in some Spaces runtimes.

    When an event loop is garbage-collected in certain runtime/container
    combinations, Python may attempt to close a file descriptor that is no
    longer valid. Wrap ``BaseEventLoop.__del__`` so that only that specific
    ``ValueError`` is swallowed; every other error still propagates.
    """
    try:
        import asyncio.base_events as base_events
    except Exception:
        # Best effort: if asyncio internals are unavailable, do nothing.
        return
    loop_cls = base_events.BaseEventLoop
    unpatched = getattr(loop_cls, "__del__", None)
    if unpatched is None:
        return

    def patched_del(self):
        try:
            unpatched(self)
        except ValueError as exc:
            # Re-raise anything other than the known-noisy teardown case.
            if "Invalid file descriptor" not in str(exc):
                raise

    try:
        loop_cls.__del__ = patched_del
    except Exception:
        pass


_patch_asyncio_event_loop_del()
| import os | |
| import sys | |
| import gc | |
| import cv2 | |
| import types | |
| import uuid | |
| from glob import glob | |
| import gradio as gr | |
| import imageio.v2 as imageio | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torchvision.transforms as T | |
| from packaging import version | |
| from PIL import Image | |
| sys.path.insert(0, ".") | |
# -----------------------------------------------------------------------------
# Compatibility shims
# -----------------------------------------------------------------------------
import huggingface_hub

# Newer huggingface_hub releases removed `cached_download`, but some pinned
# dependencies (e.g. older diffusers code paths) still call it. Alias it to
# `hf_hub_download`, which has a compatible download-and-return-path behavior.
if not hasattr(huggingface_hub, "cached_download"):
    from huggingface_hub import hf_hub_download as _hf_hub_download_compat
    huggingface_hub.cached_download = _hf_hub_download_compat

import torchvision.transforms.functional as TVF

# `torchvision.transforms.functional_tensor` no longer exists in recent
# torchvision versions, yet some dependencies still import it. Register a
# minimal stub module exposing the one function those imports need.
if "torchvision.transforms.functional_tensor" not in sys.modules:
    functional_tensor = types.ModuleType("torchvision.transforms.functional_tensor")
    functional_tensor.rgb_to_grayscale = TVF.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = functional_tensor
| # ----------------------------------------------------------------------------- | |
| # Imports that rely on the shims | |
| # ----------------------------------------------------------------------------- | |
| from huggingface_hub import snapshot_download, hf_hub_download | |
| from diffusers.utils import load_image | |
| from diffusers.utils.import_utils import is_xformers_available | |
| from LightGlue.lightglue import LightGlue, SuperPoint | |
| from LightGlue.lightglue.utils import rbd | |
| from cotracker.predictor import CoTrackerPredictor, sample_trajectories_with_ref | |
| from lineart_extractor.annotator.lineart import LineartDetector | |
| from models_diffusers.controlnet_svd import ControlNetSVDModel | |
| from models_diffusers.unet_spatio_temporal_condition import ( | |
| UNetSpatioTemporalConditionModel, | |
| ) | |
| from pipelines.AniDoc import AniDocPipeline | |
| from utils import ( | |
| extract_frames_from_video, | |
| export_gif_side_by_side_complete_ablation, | |
| generate_point_map, | |
| generate_point_map_frames, | |
| load_images_from_folder, | |
| safe_round, | |
| select_multiple_points, | |
| ) | |
| # ----------------------------------------------------------------------------- | |
| # Constants | |
| # ----------------------------------------------------------------------------- | |
# Local paths populated at startup by ensure_weights_downloaded().
PRETRAINED_SVD_DIR = "pretrained_weights/stable-video-diffusion-img2vid-xt"
PRETRAINED_ANIDOC_DIR = "pretrained_weights/anidoc"
PRETRAINED_CONTROLNET_DIR = "pretrained_weights/anidoc/controlnet"
COTRACKER_PATH = "pretrained_weights/cotracker2.pth"
# Generation defaults used by the Gradio callback.
DEFAULT_WIDTH = 512
DEFAULT_HEIGHT = 320
DEFAULT_NUM_FRAMES = 14
DEFAULT_SEED = 42
DEFAULT_MAX_POINTS = 10  # max correspondence points kept for tracking
DEFAULT_FPS = 7  # fallback when the input video reports no usable FPS
DEFAULT_MOTION_BUCKET_ID = 127  # motion-conditioning value passed to the pipeline
DEFAULT_NOISE_AUG = 0.02  # noise_aug_strength passed to the pipeline
DEFAULT_DECODE_CHUNK_SIZE = 8  # decode_chunk_size passed to the pipeline
# All inference runs on CUDA in half precision.
DEVICE = "cuda"
DTYPE = torch.float16
| # ----------------------------------------------------------------------------- | |
| # Startup downloads only | |
| # ----------------------------------------------------------------------------- | |
def ensure_weights_downloaded() -> None:
    """Fetch every pretrained weight this app needs into ./pretrained_weights.

    Safe to call repeatedly: directory creation and Hub downloads are both
    idempotent for already-present files.
    """
    for weights_dir in ("pretrained_weights", PRETRAINED_SVD_DIR):
        os.makedirs(weights_dir, exist_ok=True)
    # Base Stable Video Diffusion (img2vid-xt) weights.
    snapshot_download(
        repo_id="stabilityai/stable-video-diffusion-img2vid-xt",
        local_dir=PRETRAINED_SVD_DIR,
    )
    # AniDoc fine-tuned UNet + ControlNet weights.
    snapshot_download(
        repo_id="Yhmeng1106/anidoc",
        local_dir="./pretrained_weights",
    )
    # CoTracker checkpoint used for point tracking.
    hf_hub_download(
        repo_id="facebook/cotracker",
        filename="cotracker2.pth",
        local_dir="./pretrained_weights",
    )
| # ----------------------------------------------------------------------------- | |
| # Runtime model loading (GPU-only, per run) | |
| # ----------------------------------------------------------------------------- | |
def load_runtime_models():
    """Load every model needed for one inference run onto the GPU.

    Returns:
        dict with keys "unet", "controlnet", "pipe", "detector", "extractor",
        "matcher", "tracker" — all moved to DEVICE and ready for inference.

    Raises:
        RuntimeError: if CUDA or xformers is unavailable.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available in the current execution context.")
    # AniDoc's fine-tuned SVD UNet. `custom_resume` presumably selects the
    # project-specific checkpoint loading path — confirm in
    # UNetSpatioTemporalConditionModel.from_pretrained.
    unet = UNetSpatioTemporalConditionModel.from_pretrained(
        PRETRAINED_ANIDOC_DIR,
        subfolder="unet",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        custom_resume=True,
    )
    unet.to(DEVICE, DTYPE)
    controlnet = ControlNetSVDModel.from_pretrained(PRETRAINED_CONTROLNET_DIR)
    controlnet.to(DEVICE, DTYPE)
    if is_xformers_available():
        import xformers  # noqa: F401
        # Parsing the version here just validates the installation; the value
        # itself is unused.
        _ = version.parse(xformers.__version__)
        unet.enable_xformers_memory_efficient_attention()
    else:
        # xformers is treated as a hard requirement for this deployment.
        raise RuntimeError("xformers is not available. Make sure it is installed correctly.")
    # Full SVD pipeline wired with the fine-tuned UNet and the ControlNet.
    pipe = AniDocPipeline.from_pretrained(
        PRETRAINED_SVD_DIR,
        unet=unet,
        controlnet=controlnet,
        low_cpu_mem_usage=False,
        torch_dtype=torch.float16,
        variant="fp16",
    )
    pipe.to(DEVICE)
    # Line-art extractor used to turn control frames into sketches.
    detector = LineartDetector(DEVICE)
    # SuperPoint + LightGlue for sparse keypoint matching between the
    # reference image and the first sketch frame.
    extractor = SuperPoint(max_num_keypoints=2000).eval().to(DEVICE)
    matcher = LightGlue(features="superpoint").eval().to(DEVICE)
    # CoTracker propagates matched points across all sketch frames.
    # Note: the tracker runs in float32 while the rest uses float16.
    tracker = CoTrackerPredictor(
        checkpoint=COTRACKER_PATH,
        shift_grid=0,
    )
    tracker.requires_grad_(False)
    tracker.to(DEVICE, dtype=torch.float32)
    return {
        "unet": unet,
        "controlnet": controlnet,
        "pipe": pipe,
        "detector": detector,
        "extractor": extractor,
        "matcher": matcher,
        "tracker": tracker,
    }
def cleanup_runtime_models(models) -> None:
    """Best-effort release of per-run model references and GPU memory.

    Args:
        models: the dict returned by ``load_runtime_models``, or ``None`` if
            loading failed. The dict is emptied in place.

    Never raises: all CUDA cleanup steps are individually guarded.
    """
    try:
        if models:
            # The previous implementation did `del value` on the loop
            # variable, which only unbinds the local name and frees nothing.
            # Clearing the dict actually drops the model references so the
            # collector below can reclaim them.
            models.clear()
    finally:
        gc.collect()
        if torch.cuda.is_available():
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
            try:
                torch.cuda.ipc_collect()
            except Exception:
                pass
| # ----------------------------------------------------------------------------- | |
| # Video metadata / final video writer | |
| # ----------------------------------------------------------------------------- | |
def get_video_dimensions(video_path: str) -> tuple[int, int]:
    """Return the (width, height) of the given video file.

    Raises:
        ValueError: if the file cannot be opened or reports non-positive
            dimensions.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Unable to open input video: {video_path}")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        # Release on every path (a no-op for an unopened capture) so the
        # error branches cannot leak the capture handle.
        cap.release()
    if width <= 0 or height <= 0:
        raise ValueError(f"Invalid input video dimensions: {video_path}")
    return width, height
def get_video_fps(video_path: str) -> float:
    """Return the frame rate reported by the video, or DEFAULT_FPS if the
    container reports a non-positive value."""
    reader = cv2.VideoCapture(video_path)
    if not reader.isOpened():
        raise ValueError(f"Unable to open input video: {video_path}")
    reported_fps = reader.get(cv2.CAP_PROP_FPS)
    reader.release()
    if reported_fps <= 0:
        reported_fps = DEFAULT_FPS
    return float(reported_fps)
def write_video_frames_to_mp4(
    video_frames,
    output_path: str,
    target_width: int,
    target_height: int,
    target_fps: float,
) -> str:
    """Encode frames as an H.264 MP4 at the requested size and frame rate.

    Args:
        video_frames: iterable of PIL Images or array-likes (grayscale, RGB,
            or RGBA); each frame is normalized to RGB and resized.
        output_path: destination .mp4 path; the parent directory is created
            if needed.
        target_width / target_height: output frame size in pixels.
        target_fps: output frame rate.

    Returns:
        output_path, for caller convenience.
    """
    output_parent = os.path.dirname(output_path)
    if output_parent:
        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # when the path actually contains a directory component.
        os.makedirs(output_parent, exist_ok=True)
    # yuv420p keeps the file playable in browsers and most players.
    writer = imageio.get_writer(
        output_path,
        fps=target_fps,
        codec="libx264",
        format="FFMPEG",
        pixelformat="yuv420p",
    )
    try:
        for frame in video_frames:
            if isinstance(frame, Image.Image):
                frame_np = np.array(frame.convert("RGB"))
            else:
                frame_np = np.asarray(frame)
            # Normalize channel layout: grayscale -> 3 channels, RGBA -> RGB.
            if frame_np.ndim == 2:
                frame_np = np.stack([frame_np] * 3, axis=-1)
            elif frame_np.shape[-1] == 4:
                frame_np = frame_np[..., :3]
            resized = cv2.resize(
                frame_np,
                (target_width, target_height),
                interpolation=cv2.INTER_LINEAR,
            )
            writer.append_data(resized)
    finally:
        # Close the writer even if a frame fails, so the container is
        # finalized / the handle is not leaked.
        writer.close()
    return output_path
| # ----------------------------------------------------------------------------- | |
| # Inference helpers | |
| # ----------------------------------------------------------------------------- | |
def load_control_images(control_path: str, width: int, height: int):
    """Load control frames from a folder of images or an .mp4 clip and
    resize each frame to (width, height).

    Raises:
        ValueError: if the path is neither a directory nor an .mp4 file, or
            if no frames could be loaded from it.
    """
    if os.path.isdir(control_path):
        frames = load_images_from_folder(control_path)
    elif control_path.lower().endswith(".mp4"):
        frames = extract_frames_from_video(control_path)
    else:
        raise ValueError("Control Sequence must be a folder or an .mp4 file.")
    if not frames:
        raise ValueError("No frames could be loaded from the control sequence.")
    resized_frames = []
    for frame in frames:
        resized_frames.append(frame.resize((width, height)))
    return resized_frames
def build_sketch_condition(
    control_images,
    detector,
    width: int,
    height: int,
    quantize_sketch: bool,
):
    """Extract line art from each control frame and build the sketch tensor.

    Returns:
        (controlnet_images, condition): the per-frame sketch PIL images and a
        float16 tensor of shape (1, num_frames, 3, H, W) normalized to [-1, 1]
        on DEVICE.
    """
    controlnet_images = []
    for frame in control_images:
        lineart = detector(np.array(frame), coarse=False)
        # Replicate the single-channel line art across 3 channels.
        lineart = np.repeat(lineart[:, :, np.newaxis], 3, axis=2)
        if quantize_sketch:
            # Binarize: pixels brighter than 200 become white, the rest black.
            lineart = (lineart > 200).astype(np.uint8) * 255
        controlnet_images.append(Image.fromarray(lineart).resize((width, height)))
    stacked = torch.cat(
        [T.ToTensor()(img).unsqueeze(0) for img in controlnet_images], dim=0
    )
    condition = stacked.unsqueeze(0).to(DEVICE, dtype=torch.float16)
    # Rescale from [0, 1] to [-1, 1].
    condition = (condition - 0.5) / 0.5
    return controlnet_images, condition
def build_matching_or_tracking_condition(
    ref_image_pil,
    controlnet_images,
    controlnet_sketch_condition,
    extractor,
    matcher,
    tracker,
    width: int,
    height: int,
    num_frames: int,
    max_points: int,
    use_tracking: bool,
    tracker_grid_size: int,
    tracker_backward_tracking: bool,
    repeat_matching: bool,
):
    """Build the full ControlNet conditioning tensor (sketch + point guidance).

    The reference image is matched against the first sketch frame with
    SuperPoint + LightGlue; the matched points are then either repeated across
    frames (repeat_matching), propagated with CoTracker (use_tracking), or
    used only on the first frame with zero padding for the rest.

    Returns:
        A tensor of shape (1, num_frames, 8, H, W): 3 sketch channels,
        2 point-map channels, and 3 reference-image channels stacked on dim 2.
    """
    with torch.no_grad():
        # --- Sparse matching between the reference and the first sketch ---
        ref_img_value = T.ToTensor()(ref_image_pil).to(DEVICE, dtype=torch.float16)
        ref_img_value = ref_img_value.to(torch.float32)
        current_img = T.ToTensor()(controlnet_images[0]).to(DEVICE, dtype=torch.float16)
        current_img = current_img.to(torch.float32)
        feats0 = extractor.extract(ref_img_value)
        feats1 = extractor.extract(current_img)
        matches01 = matcher({"image0": feats0, "image1": feats1})
        # rbd: remove batch dimension from the LightGlue outputs.
        feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]]
        matches = matches01["matches"]
        # points0: matched keypoints on the reference; points1: on the sketch.
        points0 = feats0["keypoints"][matches[..., 0]].cpu().numpy()
        points1 = feats1["keypoints"][matches[..., 1]].cpu().numpy()
        points0 = safe_round(points0, current_img.shape)
        points1 = safe_round(points1, current_img.shape)
        # Keep at most 50 matches for the initial point maps.
        num_selected_points = min(50, points0.shape[0])
        points0, points1 = select_multiple_points(points0, points1, num_selected_points)
        mask1, mask2 = generate_point_map(
            size=current_img.shape,
            coords0=points0,
            coords1=points1,
        )
        # Shape each mask to (1, 1, 1, H, W) for concatenation along dim 2.
        point_map1 = torch.from_numpy(mask1).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(
            DEVICE, dtype=torch.float16
        )
        point_map2 = torch.from_numpy(mask2).unsqueeze(0).unsqueeze(0).unsqueeze(0).to(
            DEVICE, dtype=torch.float16
        )
        point_map = torch.cat([point_map1, point_map2], dim=2)
        # Reference image normalized to [-1, 1], shaped (1, 1, 3, H, W).
        conditional_pixel_values = ref_img_value.unsqueeze(0).unsqueeze(0)
        conditional_pixel_values = (conditional_pixel_values - 0.5) / 0.5
        point_map_with_ref = torch.cat([point_map, conditional_pixel_values], dim=2)
        if repeat_matching:
            # Reuse the same first-frame matching signal on every frame.
            matching_controlnet_image = point_map_with_ref.repeat(1, num_frames, 1, 1, 1)
            controlnet_condition = torch.cat(
                [controlnet_sketch_condition, matching_controlnet_image], dim=2
            )
            return controlnet_condition
        if use_tracking:
            # CoTracker expects pixel values in [0, 255]; undo the [-1, 1]
            # normalization of the sketch condition.
            video_for_tracker = (controlnet_sketch_condition * 0.5 + 0.5) * 255.0
            # Prepend the query frame index (0) to each (x, y) point.
            queries = np.insert(points1, 0, 0, axis=1)
            queries = torch.from_numpy(queries).to(DEVICE, torch.float32).unsqueeze(0)
            if queries.shape[1] == 0:
                # No matches at all: fall back to empty point maps.
                tracked_mask1 = np.zeros((height, width), dtype=np.uint8)
                tracked_mask2 = np.zeros((num_frames, height, width), dtype=np.uint8)
            else:
                pred_tracks, pred_visibility = tracker(
                    video_for_tracker.to(dtype=torch.float32),
                    queries=queries,
                    grid_size=tracker_grid_size,
                    grid_query_frame=0,
                    backward_tracking=tracker_backward_tracking,
                )
                # Subsample trajectories to at most max_points, filtering by
                # motion and visibility (thresholds per the project helper).
                pred_tracks_sampled, pred_visibility_sampled, points0_sampled = (
                    sample_trajectories_with_ref(
                        pred_tracks.cpu(),
                        pred_visibility.cpu(),
                        torch.from_numpy(points0).unsqueeze(0).cpu(),
                        max_points=max_points,
                        motion_threshold=1,
                        vis_threshold=3,
                    )
                )
                if pred_tracks_sampled is None:
                    # No trajectory survived the filtering: use empty maps.
                    tracked_mask1 = np.zeros((height, width), dtype=np.uint8)
                    tracked_mask2 = np.zeros((num_frames, height, width), dtype=np.uint8)
                else:
                    pred_tracks_sampled = pred_tracks_sampled.squeeze(0).cpu().numpy()
                    pred_visibility_sampled = pred_visibility_sampled.squeeze(0).cpu().numpy()
                    points0_sampled = points0_sampled.squeeze(0).cpu().numpy()
                    # Clamp/round tracked coordinates into the image bounds
                    # frame by frame before rasterizing them.
                    for frame_idx in range(num_frames):
                        pred_tracks_sampled[frame_idx] = safe_round(
                            pred_tracks_sampled[frame_idx],
                            current_img.shape,
                        )
                    points0_sampled = safe_round(points0_sampled, current_img.shape)
                    tracked_mask1, tracked_mask2 = generate_point_map_frames(
                        size=current_img.shape,
                        coords0=points0_sampled,
                        coords1=pred_tracks_sampled,
                        visibility=pred_visibility_sampled,
                    )
            # Reference-side map repeated over frames: (1, F, 1, H, W).
            tracked_point_map1 = (
                torch.from_numpy(tracked_mask1)
                .unsqueeze(0)
                .unsqueeze(0)
                .repeat(1, num_frames, 1, 1, 1)
                .to(DEVICE, dtype=torch.float16)
            )
            # Per-frame tracked map: (F, H, W) -> (1, F, 1, H, W).
            tracked_point_map2 = (
                torch.from_numpy(tracked_mask2)
                .unsqueeze(0)
                .unsqueeze(2)
                .to(DEVICE, dtype=torch.float16)
            )
            tracked_point_map = torch.cat([tracked_point_map1, tracked_point_map2], dim=2)
            conditional_pixel_values_repeat = conditional_pixel_values.repeat(
                1, num_frames, 1, 1, 1
            )
            point_map_with_ref = torch.cat(
                [tracked_point_map, conditional_pixel_values_repeat], dim=2
            )
            controlnet_condition = torch.cat(
                [controlnet_sketch_condition, point_map_with_ref], dim=2
            )
            return controlnet_condition
        # Matching-only mode: use the first-frame signal and zero-pad the
        # remaining num_frames - 1 frames along the frame dimension.
        original_shape = list(point_map_with_ref.shape)
        new_shape = original_shape.copy()
        new_shape[1] = num_frames - 1
        zero_tensor = torch.zeros(new_shape).to(DEVICE, dtype=torch.float16)
        matching_controlnet_image = torch.cat((point_map_with_ref, zero_tensor), dim=1)
        controlnet_condition = torch.cat(
            [controlnet_sketch_condition, matching_controlnet_image], dim=2
        )
        return controlnet_condition
def run_anidoc(
    models,
    control_path: str,
    ref_path: str,
    output_dir: str,
    width: int = DEFAULT_WIDTH,
    height: int = DEFAULT_HEIGHT,
    num_frames: int = DEFAULT_NUM_FRAMES,
    seed: int = DEFAULT_SEED,
    max_points: int = DEFAULT_MAX_POINTS,
    noise_aug: float = DEFAULT_NOISE_AUG,
    use_tracking: bool = True,
    repeat_matching: bool = False,
    quantize_sketch: bool = True,
    tracker_grid_size: int = 8,
    tracker_backward_tracking: bool = False,
):
    """Run one full AniDoc colorization pass and return the final MP4 path.

    Args:
        models: dict returned by load_runtime_models.
        control_path: sketch source — a folder of images or an .mp4 file.
        ref_path: reference (character design) image path.
        output_dir: root directory for this run's outputs.

    Raises:
        FileNotFoundError: if the final video file was not written.
    """
    pipe = models["pipe"]
    detector = models["detector"]
    extractor = models["extractor"]
    matcher = models["matcher"]
    tracker = models["tracker"]
    os.makedirs(output_dir, exist_ok=True)
    control_images = load_control_images(control_path, width=width, height=height)
    ref_image_pil = load_image(ref_path).resize((width, height))
    # Stage 1: line-art sketches + normalized sketch tensor.
    controlnet_images, controlnet_sketch_condition = build_sketch_condition(
        control_images=control_images,
        detector=detector,
        width=width,
        height=height,
        quantize_sketch=quantize_sketch,
    )
    # Stage 2: matching/tracking point guidance stacked onto the sketch.
    controlnet_condition = build_matching_or_tracking_condition(
        ref_image_pil=ref_image_pil,
        controlnet_images=controlnet_images,
        controlnet_sketch_condition=controlnet_sketch_condition,
        extractor=extractor,
        matcher=matcher,
        tracker=tracker,
        width=width,
        height=height,
        num_frames=num_frames,
        max_points=max_points,
        use_tracking=use_tracking,
        tracker_grid_size=tracker_grid_size,
        tracker_backward_tracking=tracker_backward_tracking,
        repeat_matching=repeat_matching,
    )
    # The final video mirrors the input clip's FPS and dimensions when the
    # control sequence is a video; otherwise defaults are used.
    input_fps = DEFAULT_FPS
    target_width = width
    target_height = height
    if control_path.lower().endswith(".mp4"):
        input_fps = get_video_fps(control_path)
        target_width, target_height = get_video_dimensions(control_path)
    generator = torch.manual_seed(seed)
    # Stage 3: diffusion sampling. `.frames[0]` takes the single generated clip.
    with torch.inference_mode():
        video_frames = pipe(
            ref_image_pil,
            controlnet_condition,
            height=height,
            width=width,
            num_frames=num_frames,
            decode_chunk_size=DEFAULT_DECODE_CHUNK_SIZE,
            motion_bucket_id=DEFAULT_MOTION_BUCKET_ID,
            fps=input_fps,
            noise_aug_strength=noise_aug,
            generator=generator,
        ).frames[0]
    # Stage 4: write outputs under <output_dir>/<ref>_<control>/.
    ref_base_name = os.path.splitext(os.path.basename(ref_path))[0]
    control_base_name = os.path.splitext(os.path.basename(control_path))[0]
    support_dir = os.path.join(output_dir, f"{ref_base_name}_{control_base_name}")
    os.makedirs(support_dir, exist_ok=True)
    # Side-by-side GIF (debug/preview artifact) next to the support dir.
    out_file = support_dir + ".mp4"
    export_gif_side_by_side_complete_ablation(
        ref_image_pil,
        controlnet_images,
        video_frames,
        out_file.replace(".mp4", ".gif"),
        support_dir,
        6,
    )
    final_output_path = os.path.join(
        support_dir, f"{ref_base_name}_{control_base_name}_final.mp4"
    )
    final_output_path = write_video_frames_to_mp4(
        video_frames=video_frames,
        output_path=final_output_path,
        target_width=target_width,
        target_height=target_height,
        target_fps=input_fps,
    )
    if not os.path.exists(final_output_path):
        raise FileNotFoundError("Final output .mp4 file was not written.")
    return final_output_path
| # ----------------------------------------------------------------------------- | |
| # Gradio callback | |
| # ----------------------------------------------------------------------------- | |
def generate(
    control_sequence,
    ref_image,
    seed,
    max_points,
    mode,
    noise_aug,
    quantize_sketch,
    tracker_grid_size,
    tracker_backward_tracking,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio callback: validate inputs, load models, run AniDoc once, and
    always release GPU resources afterwards."""
    del progress  # only needed so Gradio tracks tqdm progress
    if not control_sequence:
        raise gr.Error("Please provide a control video.")
    if not ref_image:
        raise gr.Error("Please provide a reference image.")
    run_models = None
    run_output_dir = f"results_{uuid.uuid4()}"
    try:
        run_models = load_runtime_models()
        return run_anidoc(
            models=run_models,
            control_path=control_sequence,
            ref_path=ref_image,
            output_dir=run_output_dir,
            width=DEFAULT_WIDTH,
            height=DEFAULT_HEIGHT,
            num_frames=DEFAULT_NUM_FRAMES,
            seed=int(seed),
            max_points=int(max_points),
            noise_aug=float(noise_aug),
            use_tracking=(mode == "Tracking"),
            repeat_matching=(mode == "Repeat matching"),
            quantize_sketch=bool(quantize_sketch),
            tracker_grid_size=int(tracker_grid_size),
            tracker_backward_tracking=bool(tracker_backward_tracking),
        )
    except Exception as e:
        # Surface any failure as a user-visible Gradio error.
        raise gr.Error(f"Error during inference: {e}") from e
    finally:
        # Free models and GPU memory even when inference failed.
        cleanup_runtime_models(run_models)
| # ----------------------------------------------------------------------------- | |
| # UI | |
| # ----------------------------------------------------------------------------- | |
# Download all weights once at startup so each generation only loads them.
ensure_weights_downloaded()

# Page-level styling: centered, fixed-width main column plus a small-note class.
CSS = """
div#col-container {
    margin: 0 auto;
    max-width: 982px;
}
.small-note {
    font-size: 0.92em;
    opacity: 0.9;
}
"""

with gr.Blocks(css=CSS) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# AniDoc: Animation Creation Made Easier")
        gr.Markdown(
            "AniDoc colorizes a sequence of sketches based on a character design "
            "reference with high fidelity, even when the sketches significantly "
            "differ in pose and scale."
        )
        # Badge row: repo, project page, paper, duplicate-space, follow links.
        gr.HTML(
            """
            <div style="display:flex;column-gap:4px;">
                <a href="https://github.com/yihao-meng/AniDoc">
                    <img src="https://img.shields.io/badge/GitHub-Repo-blue">
                </a>
                <a href="https://yihao-meng.github.io/AniDoc_demo/">
                    <img src="https://img.shields.io/badge/Project-Page-green">
                </a>
                <a href="https://arxiv.org/pdf/2412.14173">
                    <img src="https://img.shields.io/badge/ArXiv-Paper-red">
                </a>
                <a href="https://huggingface.co/spaces/fffiloni/AniDoc?duplicate=true">
                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
                </a>
                <a href="https://huggingface.co/fffiloni">
                    <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
                </a>
            </div>
            """
        )
        with gr.Row():
            # Left column: inputs and advanced settings.
            with gr.Column():
                control_sequence = gr.Video(label="Control Sequence", format="mp4")
                ref_image = gr.Image(label="Reference Image", type="filepath")
                submit_btn = gr.Button("Submit", variant="primary")
                with gr.Accordion("Advanced settings", open=False):
                    seed = gr.Number(
                        value=DEFAULT_SEED,
                        precision=0,
                        label="Seed",
                        info="Use the same seed to reproduce a similar result with the same inputs and settings.",
                    )
                    max_points = gr.Slider(
                        minimum=5,
                        maximum=30,
                        value=DEFAULT_MAX_POINTS,
                        step=1,
                        label="Tracking points",
                        info="Maximum number of motion points kept for correspondence tracking. More points can improve motion guidance but may add instability or compute cost.",
                    )
                    mode = gr.Radio(
                        choices=["Tracking", "Matching only", "Repeat matching"],
                        value="Tracking",
                        label="Motion guidance mode",
                        info="Tracking usually gives the best temporal consistency. Matching only is simpler. Repeat matching reuses the same matching signal across frames.",
                    )
                    noise_aug = gr.Slider(
                        minimum=0.0,
                        maximum=0.10,
                        value=DEFAULT_NOISE_AUG,
                        step=0.01,
                        label="Noise augmentation",
                        info="Controls how much noise is added before generation. Small changes can affect stylization and stability.",
                    )
                    quantize_sketch = gr.Checkbox(
                        value=True,
                        label="Quantize extracted sketch",
                        info="Turns the extracted line art into a stronger black-and-white sketch. Usually helps keep lines cleaner and more stable.",
                    )
                    tracker_grid_size = gr.Slider(
                        minimum=4,
                        maximum=16,
                        value=8,
                        step=1,
                        label="Tracker grid size",
                        info="Controls the density of the tracking grid. Smaller values can track more densely but may cost more computation.",
                    )
                    tracker_backward_tracking = gr.Checkbox(
                        value=False,
                        label="Backward tracking",
                        info="Tracks points in both directions instead of only forward. Can sometimes improve difficult motion, but may be slower.",
                    )
            # Right column: result video and clickable example inputs.
            with gr.Column():
                video_result = gr.Video(label="Result")
                gr.Examples(
                    examples=[
                        ["data_test/sample1.mp4", "custom_examples/sample1.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False],
                        ["data_test/sample2.mp4", "custom_examples/sample2.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False],
                        ["data_test/sample3.mp4", "custom_examples/sample3.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False],
                        ["data_test/sample4.mp4", "custom_examples/sample4.png", DEFAULT_SEED, 10, "Tracking", 0.02, True, 8, False],
                    ],
                    inputs=[
                        control_sequence,
                        ref_image,
                        seed,
                        max_points,
                        mode,
                        noise_aug,
                        quantize_sketch,
                        tracker_grid_size,
                        tracker_backward_tracking,
                    ],
                )
        gr.Markdown(
            '<div class="small-note">'
            "<b>Tips:</b> Tracking mode is the best default choice. "
            "If the result feels too unstable, try lowering Tracking points or disabling Backward tracking. "
            "If the sketch looks too soft, keep sketch quantization enabled."
            "</div>"
        )
    # Wire the submit button to the generation callback; the endpoint is kept
    # out of the public API surface via api_visibility.
    submit_btn.click(
        fn=generate,
        inputs=[
            control_sequence,
            ref_image,
            seed,
            max_points,
            mode,
            noise_aug,
            quantize_sketch,
            tracker_grid_size,
            tracker_backward_tracking,
        ],
        outputs=[video_result],
        api_visibility="private",
    )

# Queueing is required for long-running GPU jobs; SSR is disabled for this Space.
demo.queue().launch(ssr_mode=False, show_error=True)