| import argparse |
| import os |
| import shutil |
| import subprocess |
| import threading |
| from pathlib import Path |
| from queue import Empty, Queue |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import torchvision.transforms.functional as TF |
| from PIL import Image |
| from torch.utils.data import DataLoader, IterableDataset |
| from tqdm import tqdm |
|
|
| from demo import create_rendering_image |
| from sheap import load_sheap_model |
| from sheap.tiny_flame import TinyFlame, pose_components_to_rotmats |
|
|
try:
    import face_alignment
except ImportError as exc:
    raise ImportError(
        "The 'face_alignment' package is required. Install it with 'pip install face-alignment'."
    ) from exc
| from sheap.fa_landmark_utils import detect_face_and_crop |
|
|
|
|
| class RenderingThread(threading.Thread): |
| """Background thread for rendering frames to images.""" |
|
|
| def __init__( |
| self, |
| render_queue: Queue, |
| temp_dir: Path, |
| faces: torch.Tensor, |
| c2w: torch.Tensor, |
| render_size: int, |
| ): |
| """ |
| Initialize rendering thread. |
| |
| Args: |
| render_queue: Queue containing (frame_idx, cropped_frame, verts) tuples |
| temp_dir: Directory to save rendered images |
| faces: Face indices tensor from FLAME model |
| c2w: Camera-to-world transformation matrix |
| render_size: Size of each sub-image in the rendered output |
| """ |
| super().__init__(daemon=True) |
| self.render_queue = render_queue |
| self.temp_dir = temp_dir |
| self.faces = faces |
| self.c2w = c2w |
| self.render_size = render_size |
| self.stop_event = threading.Event() |
| self.frames_rendered = 0 |
|
|
| def run(self): |
| """Process rendering queue until stop signal is received.""" |
| |
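        # Select the EGL platform so off-screen OpenGL rendering works without a display.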
| os.environ["PYOPENGL_PLATFORM"] = "egl" |
|
|
| while not self.stop_event.is_set(): |
| try: |
| |
| try: |
| item = self.render_queue.get(timeout=0.1) |
| except Empty: |
| continue |
| if item is None: |
| break |
|
|
                frame_idx, cropped_frame, verts = item
|
|
| |
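                # Render the mesh for this frame and compose it with the cropped input into one image.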
| cropped_pil = Image.fromarray(cropped_frame) |
| combined = create_rendering_image( |
| original_image=cropped_pil, |
| verts=verts, |
| faces=self.faces, |
| c2w=self.c2w, |
| output_size=self.render_size, |
| ) |
|
|
| |
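                # Zero-padded filenames let ffmpeg's frame_%06d.png pattern read frames in order.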
| output_path = self.temp_dir / f"frame_{frame_idx:06d}.png" |
| combined.save(output_path) |
|
|
| self.frames_rendered += 1 |
| self.render_queue.task_done() |
|
|
            except Exception as e:
                if not self.stop_event.is_set():
                    print(f"Error rendering frame: {e}")
                    import traceback

                    traceback.print_exc()
|
|
| def stop(self): |
| """Signal the thread to stop.""" |
| self.stop_event.set() |
|
|
|
|
| class VideoFrameDataset(IterableDataset): |
| """Iterable dataset for streaming video frames with face detection and cropping. |
| |
| Uses a background thread for video frame loading while face detection runs in the main thread. |
| """ |
|
|
| def __init__( |
| self, |
| video_path: str, |
| fa_model: face_alignment.FaceAlignment, |
| smoothing_alpha: float = 0.3, |
| frame_buffer_size: int = 32, |
| ): |
| """ |
| Initialize video frame dataset. |
| |
| Args: |
| video_path: Path to video file |
| fa_model: FaceAlignment model instance for face detection |
            smoothing_alpha: EMA smoothing factor for the bounding box
                (1 = no smoothing, 0 = box frozen at its first value).
                Lower values give more smoothing.
| frame_buffer_size: Size of the frame buffer queue for the background thread |
| """ |
| super().__init__() |
| self.video_path = video_path |
| self.fa_model = fa_model |
| self.smoothing_alpha = smoothing_alpha |
| self.frame_buffer_size = frame_buffer_size |
| self.prev_bbox: Optional[Tuple[int, int, int, int]] = None |
|
|
| |
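        # Probe the video once for metadata; frames are streamed later in __iter__.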
| cap = cv2.VideoCapture(video_path) |
| if not cap.isOpened(): |
| raise ValueError(f"Could not open video file: {video_path}") |
|
|
| self.fps = cap.get(cv2.CAP_PROP_FPS) |
| self.num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) |
| self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) |
| self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) |
| cap.release() |
|
|
| print( |
| f"Video info: {self.num_frames} frames, {self.fps:.2f} fps, {self.width}x{self.height}" |
| ) |
|
|
| def _video_reader_thread(self, frame_queue: Queue, stop_event: threading.Event): |
| """Background thread that reads video frames and puts them in a queue. |
| |
| Args: |
| frame_queue: Queue to put (frame_idx, frame_rgb) tuples |
| stop_event: Event to signal thread to stop |
| """ |
| cap = cv2.VideoCapture(self.video_path) |
| if not cap.isOpened(): |
| frame_queue.put(("error", f"Could not open video file: {self.video_path}")) |
| return |
|
|
| frame_idx = 0 |
| try: |
| while not stop_event.is_set(): |
| ret, frame_bgr = cap.read() |
| if not ret: |
| break |
|
|
| |
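                # OpenCV decodes to BGR; convert to RGB for PIL/torch downstream.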
| frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) |
| |
| |
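                # put() blocks while the queue is full, throttling decoding to the consumer's pace.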
| frame_queue.put((frame_idx, frame_rgb)) |
| frame_idx += 1 |
|
|
| finally: |
| cap.release() |
| |
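            # Sentinel tells the consumer the stream is exhausted.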
| frame_queue.put(None) |
|
|
| def __iter__(self): |
| """ |
| Iterate through video frames sequentially. |
| |
| Video frame loading happens in a background thread, while face detection |
| and processing happen in the main thread. |
| |
| Yields: |
| Dictionary containing frame_idx, processed image, and bounding box |
| """ |
| |
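        # Reset bbox smoothing so a repeated pass over the video starts fresh.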
| self.prev_bbox = None |
|
|
| |
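        # Decode frames on a background thread; detection and cropping stay on this thread.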
| frame_queue = Queue(maxsize=self.frame_buffer_size) |
| stop_event = threading.Event() |
| reader_thread = threading.Thread( |
| target=self._video_reader_thread, |
| args=(frame_queue, stop_event), |
| daemon=True |
| ) |
| reader_thread.start() |
|
|
| try: |
| while True: |
| |
| item = frame_queue.get() |
| |
| |
| if item is None: |
| break |
| |
| |
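                # The reader thread reports failures as an ("error", message) tuple.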
| if isinstance(item, tuple) and len(item) == 2 and item[0] == "error": |
| raise RuntimeError(item[1]) |
| |
| frame_idx, frame_rgb = item |
|
|
| |
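                # HWC uint8 -> CHW float32 in [0, 1].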
| image = torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0 |
|
|
| |
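                # detect_face_and_crop returns an (x0, y0, x1, y1) face box; the crop happens below.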
| bbox = detect_face_and_crop(image, self.fa_model, margin=0.9, shift_up=0.5) |
|
|
| |
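                # Smooth the box over time so the crop does not jitter between frames.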
| bbox = self._smooth_bbox(bbox) |
| x0, y0, x1, y1 = bbox |
|
|
| cropped = image[:, y0:y1, x0:x1] |
|
|
| |
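                # 224x224 for the SHEAP encoder input; a 512x512 copy is kept for rendering.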
| cropped_resized = TF.resize(cropped, [224, 224], antialias=True) |
| cropped_for_render = TF.resize(cropped, [512, 512], antialias=True) |
|
|
| yield { |
| "frame_idx": frame_idx, |
| "image": cropped_resized, |
| "bbox": bbox, |
| "original_frame": frame_rgb, |
| "cropped_frame": cropped_for_render, |
| } |
|
|
| finally: |
| |
| stop_event.set() |
| reader_thread.join(timeout=1.0) |
|
|
| def _smooth_bbox(self, bbox: Tuple[int, int, int, int]) -> Tuple[int, int, int, int]: |
| """Apply exponential moving average smoothing to bounding box.""" |
| if self.prev_bbox is None: |
| self.prev_bbox = bbox |
| return bbox |
|
|
| x0, y0, x1, y1 = bbox |
| prev_x0, prev_y0, prev_x1, prev_y1 = self.prev_bbox |
|
|
| |
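        # EMA: alpha weights the new box, (1 - alpha) the previous smoothed box.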
| smoothed = ( |
| int(self.smoothing_alpha * x0 + (1 - self.smoothing_alpha) * prev_x0), |
| int(self.smoothing_alpha * y0 + (1 - self.smoothing_alpha) * prev_y0), |
| int(self.smoothing_alpha * x1 + (1 - self.smoothing_alpha) * prev_x1), |
| int(self.smoothing_alpha * y1 + (1 - self.smoothing_alpha) * prev_y1), |
| ) |
|
|
| self.prev_bbox = smoothed |
| return smoothed |
|
|
| def __len__(self) -> int: |
| return self.num_frames |
|
|
|
|
| def process_video( |
| video_path: str, |
| model_type: str = "expressive", |
| batch_size: int = 1, |
| num_workers: int = 0, |
| device: str = "cuda" if torch.cuda.is_available() else "cpu", |
| output_video_path: Optional[str] = None, |
| render_size: int = 512, |
| num_render_workers: int = 1, |
| max_queue_size: int = 128, |
| ) -> List[Dict[str, Any]]: |
| """ |
| Process video frames through SHEAP model and optionally render output video. |
| |
| Uses an IterableDataset for efficient sequential video processing without seeking overhead. |
| Rendering is done in a background thread, and ffmpeg is used to create the final video. |
| |
| Args: |
| video_path: Path to video file |
| model_type: SHEAP model variant ("paper", "expressive", or "lightweight") |
| batch_size: Batch size for processing |
        num_workers: Number of DataLoader workers (0 or 1 only; larger values are clamped to 1)
| device: Device to run model on ("cpu" or "cuda") |
| output_video_path: If provided, render and save output video to this path |
| render_size: Size of each sub-image in the rendered output |
| num_render_workers: Number of background threads for rendering |
| max_queue_size: Maximum size of the rendering queue |
| |
| Returns: |
| List of dictionaries containing frame index, bounding box, and FLAME parameters |
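
    Example (illustrative paths)::

        results = process_video("clip.mp4", output_video_path="clip_rendered.mp4")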
| """ |
| |
    # An IterableDataset streams the video sequentially, so more than one
    # DataLoader worker would replay the full stream; clamp to at most one.
    if num_workers > 1:
        print("Warning: num_workers > 1 is not supported with an IterableDataset. Using num_workers=1.")
    num_workers = min(num_workers, 1)
|
|
| |
| print(f"Loading SHEAP model (type: {model_type})...") |
| sheap_model = load_sheap_model(model_type=model_type) |
| sheap_model.eval() |
| sheap_model = sheap_model.to(device) |
|
|
| |
| |
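    # When a DataLoader worker process is used, face detection runs inside that
    # worker, so keep it on CPU: CUDA does not transfer cleanly into subprocesses.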
| fa_device = "cpu" if num_workers >= 1 else device |
| print(f"Loading face alignment model on {fa_device}...") |
| fa_model = face_alignment.FaceAlignment( |
| face_alignment.LandmarksType.THREE_D, flip_input=False, device=fa_device |
| ) |
|
|
| |
| dataset = VideoFrameDataset(video_path, fa_model) |
| dataloader = DataLoader( |
| dataset, |
| batch_size=batch_size, |
| num_workers=num_workers, |
| pin_memory=torch.cuda.is_available(), |
| ) |
|
|
| print(f"Processing {len(dataset)} frames from {video_path}") |
|
|
| |
| flame = None |
| rendering_threads = [] |
| render_queue = None |
| temp_dir = None |
| c2w = None |
|
|
| if output_video_path: |
| print("Loading FLAME model for rendering...") |
| flame_dir = Path("FLAME2020/") |
| flame = TinyFlame(flame_dir / "generic_model.pt", eyelids_ckpt=flame_dir / "eyelids.pt") |
| flame = flame.to(device) |
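        # Camera-to-world: identity rotation with the camera one unit along +z.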
| c2w = torch.tensor( |
| [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=torch.float32 |
| ) |
|
|
| |
| temp_dir = Path("./temp_sheap_render/") |
| temp_dir.mkdir(parents=True, exist_ok=True) |
| print(f"Using temporary directory: {temp_dir}") |
|
|
| |
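        # Bounded queue: inference blocks rather than running arbitrarily far ahead of rendering.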
| render_queue = Queue(maxsize=max_queue_size) |
| for _ in range(num_render_workers): |
| thread = RenderingThread(render_queue, temp_dir, flame.faces, c2w, render_size) |
| thread.start() |
| rendering_threads.append(thread) |
| print(f"Started {num_render_workers} background rendering threads") |
|
|
| results = [] |
| frame_count = 0 |
|
|
| with torch.no_grad(): |
| progbar = tqdm(total=len(dataset), desc="Processing frames") |
| for batch in dataloader: |
| frame_indices = batch["frame_idx"] |
| images = batch["image"].to(device) |
| bboxes = batch["bbox"] |
|
|
| |
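            # One forward pass maps the batch of crops to FLAME parameters.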
| flame_params_dict = sheap_model(images) |
|
|
| |
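            # Decode FLAME vertices only when a rendered video was requested.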
| if output_video_path and flame is not None: |
| verts = flame( |
| shape=flame_params_dict["shape_from_facenet"], |
| expression=flame_params_dict["expr"], |
| pose=pose_components_to_rotmats(flame_params_dict), |
| eyelids=flame_params_dict["eyelids"], |
| translation=flame_params_dict["cam_trans"], |
| ) |
|
|
| |
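            # Unbatch: store per-frame results and enqueue renders for the background threads.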
| for i in range(len(frame_indices)): |
| frame_idx = _extract_scalar(frame_indices[i]) |
| bbox = tuple(_extract_scalar(b[i]) for b in bboxes) |
|
|
| result = { |
| "frame_idx": frame_idx, |
| "bbox": bbox, |
| "flame_params": {k: v[i].cpu() for k, v in flame_params_dict.items()}, |
| } |
| results.append(result) |
|
|
| |
| if output_video_path: |
| cropped_frame = _tensor_to_numpy_image(batch["cropped_frame"][i]) |
| render_queue.put((frame_idx, cropped_frame, verts[i].cpu())) |
| frame_count += 1 |
|
|
| progbar.update(len(frame_indices)) |
| progbar.close() |
|
|
| |
| if output_video_path and render_queue is not None: |
| _finalize_rendering( |
| rendering_threads, |
| render_queue, |
| num_render_workers, |
| temp_dir, |
| dataset.fps, |
| output_video_path, |
| ) |
|
|
| return results |
|
|
|
|
| def _extract_scalar(value: Any) -> int: |
| """Extract scalar integer from tensor or return as-is.""" |
| return value.item() if isinstance(value, torch.Tensor) else value |
|
|
|
|
| def _tensor_to_numpy_image(tensor: torch.Tensor) -> np.ndarray: |
| """Convert (C, H, W) tensor [0, 1] to numpy (H, W, C) uint8 [0, 255].""" |
| if not isinstance(tensor, torch.Tensor): |
| return tensor |
| return (tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) |
|
|
|
|
| def _finalize_rendering( |
| rendering_threads: List[RenderingThread], |
| render_queue: Queue, |
| num_render_workers: int, |
| temp_dir: Path, |
| fps: float, |
| output_video_path: str, |
| ) -> None: |
| """Finish rendering threads and create final video with ffmpeg.""" |
| print("\nWaiting for rendering threads to complete...") |
|
|
| |
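    # One sentinel per worker so every rendering thread drains the queue and exits.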
| for _ in range(num_render_workers): |
| render_queue.put(None) |
|
|
| |
| for thread in rendering_threads: |
| thread.join() |
|
|
| total_rendered = sum(thread.frames_rendered for thread in rendering_threads) |
| print(f"Rendered {total_rendered} frames") |
|
|
| |
| print("Creating video with ffmpeg...") |
| output_path = Path(output_video_path) |
| output_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
| ffmpeg_cmd = [ |
| "ffmpeg", |
| "-y", |
| "-framerate", |
| str(fps), |
| "-i", |
| str(temp_dir / "frame_%06d.png"), |
| "-c:v", |
| "libx264", |
| "-pix_fmt", |
| "yuv420p", |
| "-preset", |
| "medium", |
| "-crf", |
| "23", |
| str(output_path), |
| ] |
|
|
    try:
        subprocess.run(ffmpeg_cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        # capture_output=True would otherwise swallow ffmpeg's error output.
        print(f"ffmpeg failed:\n{e.stderr}")
        raise
    print(f"Video saved to: {output_video_path}")
|
|
| |
| if temp_dir.exists(): |
| print(f"Removing temporary directory: {temp_dir}") |
| shutil.rmtree(temp_dir) |
| print("Cleanup complete") |
|
|
|
|
| if __name__ == "__main__": |
| |
| |
| parser = argparse.ArgumentParser(description="Process and render video with SHEAP model.") |
| parser.add_argument("in_path", type=str, help="Path to input video file.") |
| parser.add_argument( |
| "--out_path", type=str, help="Path to save rendered output video.", default=None |
| ) |
| args = parser.parse_args() |
|
|
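    # Default output: "<input stem>_rendered.mp4" next to the input file.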
| if args.out_path is None: |
| args.out_path = str(Path(args.in_path).with_name(f"{Path(args.in_path).stem}_rendered.mp4")) |
|
|
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| print(f"Using device: {device}") |
|
|
| results = process_video( |
| video_path=args.in_path, |
| model_type="expressive", |
| device=device, |
| output_video_path=args.out_path, |
| ) |
|
|