# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional, Sequence, Union

import torch
import torch.nn.functional as F

from sapiens.registry import MODELS

from .base_preprocessor import BasePreprocessor

@MODELS.register_module()
class ImagePreprocessor(BasePreprocessor):
    """Normalizes images, converts channel order, and pads to a size divisor.

    Accepts a single image (C, H, W), a batch (N, C, H, W), a multi-view
    batch (B, V, C, H, W), or a list of per-image tensors.
    """

    def __init__(
        self,
        mean: Optional[Sequence[Union[float, int]]] = None,
        std: Optional[Sequence[Union[float, int]]] = None,
        pad_size_divisor: int = 1,
        pad_value: Union[float, int] = 0,
        bgr_to_rgb: bool = False,
        rgb_to_bgr: bool = False,
        non_blocking: Optional[bool] = False,
    ):
        super().__init__(non_blocking)
        self._validate_params(mean, std, bgr_to_rgb, rgb_to_bgr)
        self._setup_normalization(mean, std)
        self._channel_conversion = bgr_to_rgb or rgb_to_bgr
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value

    def _validate_params(self, mean, std, bgr_to_rgb, rgb_to_bgr):
        if bgr_to_rgb and rgb_to_bgr:
            raise ValueError("Cannot set both bgr_to_rgb and rgb_to_bgr to True")
        if (mean is None) != (std is None):
            raise ValueError("mean and std must both be None or both be provided")

    def _setup_normalization(self, mean, std):
        if mean is None:
            self._enable_normalize = False
            return
        if len(mean) not in [1, 3] or len(std) not in [1, 3]:
            raise ValueError("mean and std must have 1 or 3 values")
        self._enable_normalize = True
        # Shape (C, 1, 1) so the statistics broadcast over (C, H, W) images;
        # non-persistent so they are not written into checkpoints.
        self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1), persistent=False)
        self.register_buffer("std", torch.tensor(std).view(-1, 1, 1), persistent=False)

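    # Illustrative note (not part of the original source): with the common
    # ImageNet statistics mean=[123.675, 116.28, 103.53] and
    # std=[58.395, 57.12, 57.375], a uint8 value of 255 in the first channel
    # normalizes to (255 - 123.675) / 58.395 ≈ 2.249.
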
    def _process_single_image(self, img: torch.Tensor) -> torch.Tensor:
        if img.dtype not in [torch.uint8, torch.float16, torch.float32, torch.float64]:
            raise TypeError(f"Unsupported image dtype: {img.dtype}")
        if img.dim() == 4:
            # Batched input (N, C, H, W).
            if img.shape[1] != 3:
                raise ValueError(f"Expected 3 channels in dim=1, got {img.shape}")
            img = img.float()
            if self._channel_conversion:
                img = img[:, [2, 1, 0], ...]  # BGR <-> RGB
            if self._enable_normalize:
                img = (img - self.mean[None]) / self.std[None]
            return img
        elif img.dim() == 3:
            # Single image (C, H, W).
            if img.shape[0] != 3:
                raise ValueError(f"Expected 3 channels in dim=0, got {img.shape}")
            img = img.float()
            if self._channel_conversion:
                img = img[[2, 1, 0], ...]  # BGR <-> RGB
            if self._enable_normalize:
                img = (img - self.mean) / self.std
            return img
        else:
            raise ValueError(f"Expected 3D or 4D tensor, got shape {img.shape}")

    def _pad_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        if self.pad_size_divisor <= 1:
            return tensor
        h, w = tensor.shape[-2:]
        target_h = math.ceil(h / self.pad_size_divisor) * self.pad_size_divisor
        target_w = math.ceil(w / self.pad_size_divisor) * self.pad_size_divisor
        pad_h = target_h - h
        pad_w = target_w - w
        if pad_h == 0 and pad_w == 0:
            return tensor
        return F.pad(tensor, (0, pad_w, 0, pad_h), "constant", self.pad_value)

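    # Worked example (illustrative, not in the original source): with
    # pad_size_divisor=32, a (3, 217, 310) tensor pads to (3, 224, 320),
    # since ceil(217 / 32) * 32 = 224 and ceil(310 / 32) * 32 = 320; F.pad
    # with (0, pad_w, 0, pad_h) fills only the right and bottom edges.
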
    def forward(self, data: dict) -> dict:
        # `self.mean` only exists when normalization stats were registered,
        # so guard the buffer access; assumes `cast_data` accepts device=None
        # and falls back to its default device in that case.
        device = self.mean.device if self._enable_normalize else None
        data = self.cast_data(data, device=device)
        inputs = data["inputs"]
        if self.is_seq_of(inputs, torch.Tensor):
            # A list of individual images: process each, then pad to a common
            # size and stack into one batch tensor.
            processed_imgs = [self._process_single_image(img) for img in inputs]
            batch_inputs = self.stack_batch(
                processed_imgs, self.pad_size_divisor, self.pad_value
            )
        elif isinstance(inputs, torch.Tensor):
            if inputs.dim() == 4:
                # Already batched (N, C, H, W).
                batch_inputs = self._process_single_image(inputs)
                batch_inputs = self._pad_tensor(batch_inputs)
            elif inputs.dim() == 5:
                # Multi-view batch (B, V, C, H, W): fold the views into the
                # batch dimension, process, then restore the view axis.
                B, V, C, H, W = inputs.shape
                flat_inputs = inputs.reshape(B * V, C, H, W)
                processed = self._process_single_image(flat_inputs)
                processed = self._pad_tensor(processed)
                batch_inputs = processed.view(
                    B, V, C, processed.shape[-2], processed.shape[-1]
                )
            elif inputs.dim() == 3:
                # Single image (C, H, W): add a batch dimension.
                processed = self._process_single_image(inputs.unsqueeze(0))
                batch_inputs = self._pad_tensor(processed)
            else:
                raise ValueError(
                    f"Expected 3D, 4D or 5D tensor, got shape {inputs.shape}"
                )
        else:
            raise TypeError(f"Expected tensor or list of tensors, got {type(inputs)}")
        data["inputs"] = batch_inputs
        data.setdefault("data_samples", None)
        return data
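
# Usage sketch (illustrative; assumes the BasePreprocessor helpers
# `cast_data`, `is_seq_of`, and `stack_batch` behave as their names suggest):
#
#   preprocessor = ImagePreprocessor(
#       mean=[123.675, 116.28, 103.53],
#       std=[58.395, 57.12, 57.375],
#       pad_size_divisor=32,
#       bgr_to_rgb=True,
#   )
#   data = {"inputs": torch.randint(0, 256, (2, 3, 217, 310), dtype=torch.uint8)}
#   out = preprocessor(data)
#   # out["inputs"].shape == (2, 3, 224, 320): channels flipped BGR -> RGB,
#   # values normalized per channel, and spatial dims padded to multiples of 32.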