# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
GradioPipeline: ActionMesh pipeline with Gradio progress tracking.

This module provides a subclass of ActionMeshPipeline that adds progress
callbacks for integration with Gradio's progress bar.
"""

from typing import Callable, Optional

import torch
import trimesh

from actionmesh.io.video_input import ActionMeshInput
from actionmesh.pipeline import ActionMeshPipeline

# Signature: (overall_fraction in [0, 1], human-readable status message) -> None
ProgressCallback = Callable[[float, str], None]


class GradioPipeline(ActionMeshPipeline):
    """
    ActionMesh pipeline with Gradio progress tracking support.

    Progress breakdown:
    - 0% -> 10%: Anchor 3D generation (image_to_3d)
    - 10% -> 90%: Stage 1 - Flow matching denoising (step-by-step)
    - 90% -> 100%: Stage 2 - Mesh decoding (step-by-step)
    """

    def _apply_overrides(
        self,
        stage_0_steps: int | None,
        stage_1_steps: int | None,
        guidance_scales: list[float] | None,
        face_decimation: int | None,
        floaters_threshold: float | None,
        anchor_idx: int | None,
    ) -> None:
        """Apply any non-None parameter overrides onto the pipeline config."""
        if stage_0_steps is not None:
            self.cfg.model.image_to_3D_denoiser.num_inference_steps = stage_0_steps
        if stage_1_steps is not None:
            self.scheduler.num_inference_steps = stage_1_steps
        if guidance_scales is not None:
            self.cf_guidance.guidance_scales = guidance_scales
        if face_decimation is not None:
            self.mesh_process.face_decimation = face_decimation
        if floaters_threshold is not None:
            self.mesh_process.floaters_threshold = floaters_threshold
        if anchor_idx is not None:
            self.cfg.anchor_idx = anchor_idx

    @staticmethod
    def _make_window_callback(
        progress_callback: ProgressCallback | None,
        start: float,
        span: float,
        label: str,
    ) -> Callable[[int, int, int, int], None]:
        """Build a per-step callback mapping windowed denoising progress onto
        the [start, start + span] slice of the overall progress bar.

        The returned callable has the (step, total_steps, window_idx,
        total_windows) signature expected by the stage runners; it is a no-op
        when ``progress_callback`` is None.
        """

        def callback(
            step: int, total_steps: int, window_idx: int, total_windows: int
        ) -> None:
            if progress_callback is not None:
                # Fraction of all windows completed, counting partial steps
                # within the current window.
                window_progress = (window_idx + step / total_steps) / total_windows
                progress_callback(
                    start + span * window_progress,
                    f"{label}: step {step}/{total_steps} ",
                )

        return callback

    def __call__(
        self,
        input: ActionMeshInput,
        seed: int = 44,
        stage_0_steps: int | None = None,
        face_decimation: int | None = None,
        floaters_threshold: float | None = None,
        stage_1_steps: int | None = None,
        guidance_scales: list[float] | None = None,
        anchor_idx: int | None = None,
        progress_callback: ProgressCallback | None = None,
    ) -> list[trimesh.Trimesh]:
        """Generate an animated mesh sequence with progress tracking.

        Args:
            input: Video frames (and metadata) to animate; ``input.frames``
                is preprocessed in place (background removal, crop/pad).
            seed: RNG seed for anchor generation and latent denoising.
            stage_0_steps: Override for anchor image-to-3D inference steps.
            face_decimation: Override for mesh face-decimation target.
            floaters_threshold: Override for floating-fragment removal.
            stage_1_steps: Override for Stage 1 scheduler step count.
            guidance_scales: Override for classifier-free guidance scales.
            anchor_idx: Override for which frame is used as the 3D anchor.
            progress_callback: Optional ``(fraction, message)`` hook; called
                with values in [0, 1] as the pipeline advances.

        Returns:
            The ordered list of per-frame meshes, moved to CPU.
        """
        # Config overrides: only non-None arguments take effect.
        self._apply_overrides(
            stage_0_steps=stage_0_steps,
            stage_1_steps=stage_1_steps,
            guidance_scales=guidance_scales,
            face_decimation=face_decimation,
            floaters_threshold=floaters_threshold,
            anchor_idx=anchor_idx,
        )

        # -- Preprocessing: remove background
        input.frames = self.background_removal.process_images(input.frames)
        # -- Preprocessing: grouped cropping & padding
        input.frames = self.image_process.process_images(input.frames)

        with torch.inference_mode():
            # -- Stage 0: generate anchor 3D mesh & latent from single frame
            latent_bank, mesh_bank = self.init_banks_from_anchor(input, seed)
            if progress_callback is not None:
                progress_callback(0.10, "Anchor 3D generated, starting Stage 1...")

            # Stage 1 occupies the 10% -> 90% slice; Stage 2 the 90% -> 100% slice.
            stage1_callback = self._make_window_callback(
                progress_callback, 0.10, 0.80, "Stage 1"
            )
            stage2_callback = self._make_window_callback(
                progress_callback, 0.90, 0.10, "Stage 2"
            )

            # NOTE(review): device_type="cuda" is hard-coded — this pipeline
            # presumably requires a CUDA device; confirm against base class.
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                # -- Stage I: denoise synchronized 3D latents
                latent_bank = self.generate_3d_latents(
                    input,
                    latent_bank=latent_bank,
                    seed=seed,
                    step_callback=stage1_callback,
                )

                # -- Stage II: decode latents into mesh displacements
                mesh_bank = self.generate_mesh_animation(
                    latent_bank=latent_bank,
                    mesh_bank=mesh_bank,
                    step_callback=stage2_callback,
                )

        if progress_callback is not None:
            progress_callback(1.0, "Pipeline complete!")

        return mesh_bank.get_ordered(device="cpu")[0]