| """ |
| Gradio demo for SHeaP (Self-Supervised Head Geometry Predictor). |
| Accepts video or image input and renders the SHEAP output overlayed. |
| """ |

import os

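# pyrender needs an offscreen OpenGL context on headless machines; selecting
# EGL before any OpenGL import avoids requiring an X display (e.g. on Spaces).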
| if "PYOPENGL_PLATFORM" not in os.environ: |
| |
| os.environ["PYOPENGL_PLATFORM"] = "egl" |
|
import shutil
import subprocess
import tempfile
from pathlib import Path
from queue import Queue
from typing import Optional

import face_alignment
import gradio as gr
import numpy as np
import torch
import torch.hub
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import DataLoader

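# On Hugging Face Spaces, the `spaces` package provides the @spaces.GPU
# decorator for GPU allocation; fall back to a no-op stub when running locally.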
try:
    import spaces

    HAS_SPACES = True
except ImportError:
    HAS_SPACES = False

    class spaces:
        @staticmethod
        def GPU(func):
            return func


from demo import create_rendering_image
from sheap import load_sheap_model
from sheap.fa_landmark_utils import detect_face_and_crop
from sheap.tiny_flame import TinyFlame, pose_components_to_rotmats
from video_demo import RenderingThread, VideoFrameDataset, _tensor_to_numpy_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sheap_model = None
flame = None
fa_model = None
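# Fixed camera-to-world pose used for rendering: identity rotation with the
# camera placed one unit along +z.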
c2w = torch.tensor([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.float32)


def initialize_models():
    """Initialize all models (called lazily on first use)."""
    global sheap_model, flame, fa_model

    if sheap_model is not None:
        return

    print("Loading SHeaP model...", flush=True)
    sheap_model = load_sheap_model(model_type="expressive").to(device)
    sheap_model.eval()

    print("Loading FLAME model...", flush=True)
    flame_dir = Path("FLAME2020/")
    flame = TinyFlame(flame_dir / "generic_model.pt", eyelids_ckpt=flame_dir / "eyelids.pt").to(
        device
    )

    print("Loading face alignment model...", flush=True)
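    # Point torch.hub's cache at a local folder so the face-alignment weights
    # persist alongside the app instead of the default hub cache.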
    torch.hub.set_dir(str(Path(__file__).parent / "face_alignment_cache"))
    fa_model = face_alignment.FaceAlignment(
        face_alignment.LandmarksType.TWO_D, device=str(device), flip_input=False
    )

    print("Models loaded successfully!", flush=True)


@spaces.GPU
def process_image(image: np.ndarray) -> Image.Image:
    """
    Process a single image and return the rendered output.

    Args:
        image: Input image as a numpy array (RGB, uint8).

    Returns:
        PIL Image with three views side by side (original, mesh, blended).
    """
    initialize_models()

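    # HWC uint8 in [0, 255] -> CHW float in [0, 1].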
    image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0

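    # Detect the face and get an expanded crop box around it.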
    x0, y0, x1, y1 = detect_face_and_crop(image_tensor, fa_model, margin=0.9, shift_up=0.5)
    cropped_tensor = image_tensor[:, y0:y1, x0:x1]

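    # 224x224 input for the network; a 512x512 copy is kept for visualization.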
    cropped_resized = TF.resize(cropped_tensor, [224, 224], antialias=True)
    img_tensor = cropped_resized.unsqueeze(0).to(device)
    cropped_for_render = TF.resize(cropped_tensor, [512, 512], antialias=True)

    with torch.no_grad():
        predictions = sheap_model(img_tensor)

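    # Decode the predicted FLAME parameters (shape, expression, pose) into
    # mesh vertices.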
    verts = flame(
        shape=predictions["shape_from_facenet"],
        expression=predictions["expr"],
        pose=pose_components_to_rotmats(predictions),
        eyelids=predictions["eyelids"],
        translation=predictions["cam_trans"],
    )
    verts = verts.cpu()

    cropped_pil = TF.to_pil_image(cropped_for_render)
    combined = create_rendering_image(
        original_image=cropped_pil,
        verts=verts[0],
        faces=flame.faces,
        c2w=c2w,
        output_size=512,
    )
    return combined

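# Minimal usage sketch (outside Gradio), using the bundled example image:
#   img = np.asarray(Image.open("example_images/00000206.jpg").convert("RGB"))
#   process_image(img).save("output.png")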


@spaces.GPU
def process_video_frames(video_path: str, temp_dir: Path, progress=gr.Progress()):
    """
    Run inference and rendering for all video frames on the GPU.

    Returns the source fps and the number of frames processed.
    """
    initialize_models()

    render_size = 512
    dataset = VideoFrameDataset(video_path, fa_model)
    dataloader = DataLoader(dataset, batch_size=8, num_workers=0)
    fps = dataset.fps
    num_frames = len(dataset)

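    # Rendering happens on background threads so the GPU can keep running
    # inference on the next batch while frames are rasterized and saved.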
    render_queue = Queue(maxsize=32)
    num_render_workers = 4
    rendering_threads = []
    for _ in range(num_render_workers):
        thread = RenderingThread(render_queue, temp_dir, flame.faces, c2w, render_size)
        thread.start()
        rendering_threads.append(thread)

    progress(0, desc="Processing video frames...")
    frame_idx = 0
    with torch.no_grad():
        for batch in dataloader:
            images = batch["image"].to(device)
            cropped_frames = batch["cropped_frame"]

            predictions = sheap_model(images)
            verts = flame(
                shape=predictions["shape_from_facenet"],
                expression=predictions["expr"],
                pose=pose_components_to_rotmats(predictions),
                eyelids=predictions["eyelids"],
                translation=predictions["cam_trans"],
            )
            verts = verts.cpu()
            for i in range(images.shape[0]):
                cropped_frame = _tensor_to_numpy_image(cropped_frames[i])
                render_queue.put((frame_idx, cropped_frame, verts[i]))
                frame_idx += 1
                progress(
                    0.95 * frame_idx / num_frames,
                    desc=f"Processing frame {frame_idx}/{num_frames}",
                )

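    # One None sentinel per worker tells the rendering threads to exit once
    # the queue drains.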
    for _ in range(num_render_workers):
        render_queue.put(None)
    for thread in rendering_threads:
        thread.join()

    if frame_idx == 0:
        raise ValueError("No frames were successfully processed!")

    return fps, frame_idx


def process_video(video_path: str, progress=gr.Progress()) -> str:
    """Process a video and return the path to the rendered output video."""
    temp_dir = Path(tempfile.mkdtemp())
    try:
        fps, num_frames = process_video_frames(video_path, temp_dir, progress)

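        # Stitch the rendered frame_%06d.png images into an H.264 MP4.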
| progress(0.95, desc="Encoding video...") |
| output_path = temp_dir / "output.mp4" |
| ffmpeg_cmd = [ |
| "ffmpeg", |
| "-y", |
| "-framerate", |
| str(fps), |
| "-i", |
| str(temp_dir / "frame_%06d.png"), |
| "-c:v", |
| "libx264", |
| "-pix_fmt", |
| "yuv420p", |
| "-crf", |
| "18", |
| str(output_path), |
| ] |
| subprocess.run(ffmpeg_cmd, check=True, capture_output=True) |
| progress(1.0, desc="Done!") |
| return str(output_path) |
| except Exception as e: |
| shutil.rmtree(temp_dir, ignore_errors=True) |
| raise e |
|
|
|
|
def process_input(image: Optional[np.ndarray], video: Optional[str]):
    """
    Process either an image or a video input.

    Args:
        image: Input image (if provided).
        video: Input video path (if provided).

    Returns:
        An (image, video) output tuple; the slot that does not apply is None.
    """
    if image is not None:
        return process_image(image), None
    elif video is not None:
        return None, process_video(video)
    else:
        raise ValueError("Please provide either an image or a video!")


with gr.Blocks(title="SHeaP Demo") as demo:
    gr.Markdown(
        """
# 🐑 SHeaP: Self-Supervised Head Geometry Predictor 🐑

Upload an image or video to predict head geometry and render a 3D mesh overlay!

The output shows three views:
- **Left**: Original cropped face
- **Center**: Rendered FLAME mesh
- **Right**: Mesh overlaid on the original

[Project Page](https://nlml.github.io/sheap) | [Paper](https://arxiv.org/abs/2504.12292) | [GitHub](https://github.com/nlml/sheap)
"""
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input")
            image_input = gr.Image(label="Upload Image", type="numpy")
            video_input = gr.Video(label="Upload Video")
            process_btn = gr.Button("Process", variant="primary")

        with gr.Column():
            gr.Markdown("### Output")
            image_output = gr.Image(label="Rendered Image", type="pil")
            video_output = gr.Video(label="Rendered Video")

    gr.Markdown(
        """
### Tips:
- For best results, use images/videos with clearly visible faces
- The model works best with frontal face views
- Video processing may take a few minutes depending on length
"""
    )

    process_btn.click(
        fn=process_input,
        inputs=[image_input, video_input],
        outputs=[image_output, video_output],
    )

    gr.Examples(
        examples=[
            ["example_images/00000206.jpg", None],
            [None, "example_videos/dafoe.mp4"],
        ],
        inputs=[image_input, video_input],
        outputs=[image_output, video_output],
        fn=process_input,
        cache_examples=False,
    )

| if __name__ == "__main__": |
| demo.launch() |
|
|