| from typing import List, Union |
|
|
| import av |
| import numpy as np |
| import torch |
| from diffusers.modular_pipelines import ( |
| ComponentSpec, |
| InputParam, |
| ModularPipelineBlocks, |
| OutputParam, |
| PipelineState, |
| ) |
| from matplotlib import colormaps |
| from PIL import Image |
| from transformers import DepthProForDepthEstimation, DepthProImageProcessor |
|
|
# Matplotlib "turbo" colormap, used to render normalized depth maps in color.
TURBO_CMAP = colormaps["turbo"]
|
|
|
|
def save_video(frames: List[Image.Image], fps: float, output_path: str) -> None:
    """Encode a list of PIL Image frames as an H.264 MP4 video.

    Args:
        frames: Frames to encode. All frames are assumed to have the size of
            the first frame. NOTE(review): yuv420p requires even width/height;
            odd-sized frames may fail to encode — confirm inputs upstream.
        fps: Target frame rate. Fractional rates (e.g. 29.97) are preserved.
        output_path: Destination file path; the container format is inferred
            by PyAV from the extension.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    from fractions import Fraction

    if not frames:
        raise ValueError("save_video() requires at least one frame")

    container = av.open(output_path, mode="w")
    try:
        # Fraction preserves non-integer rates; int(fps) would truncate
        # 29.97 down to 29 and subtly desync audio/video timing.
        stream = container.add_stream(
            "libx264", rate=Fraction(fps).limit_denominator(1000000)
        )
        stream.pix_fmt = "yuv420p"
        stream.width = frames[0].width
        stream.height = frames[0].height

        for frame in frames:
            video_frame = av.VideoFrame.from_image(frame)
            for packet in stream.encode(video_frame):
                container.mux(packet)

        # Flush any packets still buffered inside the encoder.
        for packet in stream.encode():
            container.mux(packet)
    finally:
        # Close even on encode errors so the file handle is never leaked.
        container.close()
|
|
|
|
class DepthProEstimatorBlock(ModularPipelineBlocks):
    """Modular pipeline block that estimates depth with Apple's DepthPro model.

    Two modes are supported:

    * **Single-image mode** — reads ``image`` and emits ``depth_image``,
      ``predicted_depth``, ``field_of_view`` and ``focal_length``.
    * **Video mode** — reads ``video_path`` and emits ``depth_frames`` and
      ``fps``. Video mode takes precedence when both inputs are provided.

    Depth maps are visualized either as grayscale or via the matplotlib
    "turbo" colormap, selected by the ``colormap`` input.
    """

    # Minimum library versions this block was written against.
    _requirements = {
        "transformers": ">=5.1.0",
        "torch": ">=2.9.0",
        "torchvision": ">=0.16.0",
        "av": ">=12.0.0",
        "matplotlib": ">=3.7.0",
    }

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Model and processor components loaded from the DepthPro checkpoint."""
        return [
            ComponentSpec(
                name="depth_estimator",
                type_hint=DepthProForDepthEstimation,
                pretrained_model_name_or_path="apple/DepthPro-hf",
            ),
            ComponentSpec(
                name="depth_estimator_processor",
                type_hint=DepthProImageProcessor,
                pretrained_model_name_or_path="apple/DepthPro-hf",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        """Declared pipeline inputs; at least one of image/video_path is required."""
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, List[Image.Image]],
                required=False,
                description="Image(s) to estimate depth for",
            ),
            InputParam(
                "video_path",
                type_hint=str,
                required=False,
                description="Path to input video file. When provided, image is ignored.",
            ),
            InputParam(
                "output_type",
                type_hint=str,
                default="depth_image",
                description="Output type: 'depth_image', 'depth_tensor', or 'depth_and_fov'",
            ),
            InputParam(
                "colormap",
                type_hint=str,
                default="grayscale",
                description="Depth visualization format: 'grayscale' or 'turbo' (colormapped)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Declared outputs; unused-mode outputs are set to None at runtime."""
        return [
            OutputParam(
                "depth_image",
                type_hint=Image.Image,
                description="Normalized depth map as a grayscale PIL image (single image mode)",
            ),
            OutputParam(
                "predicted_depth",
                type_hint=torch.Tensor,
                description="Raw metric depth tensor (H x W) (single image mode)",
            ),
            OutputParam(
                "field_of_view",
                type_hint=float,
                description="Estimated horizontal field of view (single image mode)",
            ),
            OutputParam(
                "focal_length",
                type_hint=float,
                description="Estimated focal length (single image mode)",
            ),
            OutputParam(
                "depth_frames",
                type_hint=list,
                description="List of per-frame depth PIL images (video mode)",
            ),
            OutputParam(
                "fps",
                type_hint=float,
                description="Source video frame rate (video mode)",
            ),
        ]

    def _estimate_depth(self, image: Image.Image, processor, model) -> dict:
        """Run DepthPro on a single image.

        Returns the post-processed result dict for that image. It contains
        "predicted_depth" (tensor resized to the input resolution) and, per
        the transformers DepthPro post-processing API, may also contain
        "field_of_view" and "focal_length".

        Note: the original annotation (``-> np.ndarray``) was wrong — callers
        index the returned dict.
        """
        inputs = processor(images=[image], return_tensors="pt").to(model.device)
        outputs = model(**inputs)
        post_processed = processor.post_process_depth_estimation(
            outputs, target_sizes=[(image.height, image.width)]
        )
        return post_processed[0]

    def _normalize_depth(self, depth: np.ndarray) -> np.ndarray:
        """Map metric depth to inverse depth normalized to [0, 1].

        Depth is clamped to [0.1, 250] meters before inversion so near/far
        outliers do not dominate the dynamic range; the epsilon guards
        against division by zero on constant-depth inputs.
        """
        inverse_depth = 1.0 / np.clip(depth, 0.1, 250.0)
        inv_min = inverse_depth.min()
        inv_max = inverse_depth.max()
        return (inverse_depth - inv_min) / (inv_max - inv_min + 1e-8)

    def _apply_colormap(self, normalized: np.ndarray, mode: str) -> np.ndarray:
        """Render a [0, 1] depth map as uint8 pixels.

        Returns an (H, W, 3) RGB array for "turbo", otherwise an (H, W)
        grayscale array.
        """
        if mode == "turbo":
            # Drop the alpha channel the colormap emits; keep RGB only.
            return (TURBO_CMAP(normalized)[..., :3] * 255).astype(np.uint8)
        return (normalized * 255.0).astype(np.uint8)

    def _process_video(self, video_path, processor, model, colormap):
        """Decode a video and estimate per-frame depth.

        Returns (depth_frames, fps) where depth_frames is a list of PIL
        images and fps is the stream's average rate.
        NOTE(review): ``average_rate`` can be None for some containers,
        which would make the caller's ``float(fps)`` fail — confirm inputs.
        """
        input_container = av.open(video_path)
        try:
            video_stream = input_container.streams.video[0]
            fps = video_stream.average_rate

            depth_frames = []
            for frame in input_container.decode(video=0):
                pil_image = frame.to_image().convert("RGB")

                result = self._estimate_depth(pil_image, processor, model)
                depth_np = result["predicted_depth"].float().cpu().numpy()
                normalized = self._normalize_depth(depth_np)
                colored = self._apply_colormap(normalized, colormap)

                if colormap == "turbo":
                    depth_frame = Image.fromarray(colored, mode="RGB")
                else:
                    depth_frame = Image.fromarray(colored, mode="L")
                depth_frames.append(depth_frame)
        finally:
            # Close even if decoding or inference raises mid-stream.
            input_container.close()

        return depth_frames, fps

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Execute the block: video mode if ``video_path`` is set, else image mode.

        Raises:
            ValueError: If neither ``video_path`` nor ``image`` is provided,
                or if ``image`` is an empty list.
        """
        block_state = self.get_block_state(state)

        processor = components.depth_estimator_processor
        model = components.depth_estimator

        video_path = getattr(block_state, "video_path", None)

        if video_path:
            depth_frames, fps = self._process_video(
                video_path, processor, model, block_state.colormap
            )
            block_state.depth_frames = depth_frames
            block_state.fps = float(fps)
            # Single-image outputs are not produced in video mode.
            block_state.depth_image = None
            block_state.predicted_depth = None
            block_state.field_of_view = None
            block_state.focal_length = None
        else:
            images = block_state.image
            if images is None:
                raise ValueError(
                    "DepthProEstimatorBlock requires either `image` or `video_path`."
                )
            if not isinstance(images, list):
                images = [images]
            if not images:
                raise ValueError("`image` list must contain at least one image.")

            # Only the first image is processed; batching is not supported here.
            result = self._estimate_depth(images[0], processor, model)
            predicted_depth = result["predicted_depth"]

            block_state.predicted_depth = predicted_depth
            # Outputs are declared as floats; convert 0-dim tensors accordingly.
            fov = result.get("field_of_view")
            focal = result.get("focal_length")
            block_state.field_of_view = float(fov) if fov is not None else None
            block_state.focal_length = float(focal) if focal is not None else None

            depth_np = predicted_depth.float().cpu().numpy()
            normalized = self._normalize_depth(depth_np)
            colored = self._apply_colormap(normalized, block_state.colormap)
            if block_state.colormap == "turbo":
                block_state.depth_image = Image.fromarray(colored, mode="RGB")
            else:
                block_state.depth_image = Image.fromarray(colored, mode="L")

            # Video-mode outputs are not produced in single-image mode.
            block_state.depth_frames = None
            block_state.fps = None

        self.set_block_state(state, block_state)

        return components, state
|
|