Rawal Khirodkar
Pin Python 3.10 + torch 2.1.2; vendor sapiens2 to bypass requires-python
5f5f544
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, Sequence, Union
import torch
import torch.nn.functional as F
from sapiens.registry import MODELS
from .base_preprocessor import BasePreprocessor
@MODELS.register_module()
class ImagePreprocessor(BasePreprocessor):
    """Move image data to the module's device, normalize it and pad it.

    The preprocessor optionally swaps the first and last colour channel
    (BGR <-> RGB), optionally normalizes with a per-channel ``mean`` / ``std``
    and pads the spatial dimensions on the bottom/right so that height and
    width become multiples of ``pad_size_divisor``.

    Args:
        mean: Per-channel mean with 1 or 3 values, or ``None`` to disable
            normalization. Must be given together with ``std``.
        std: Per-channel standard deviation with 1 or 3 values, or ``None``
            to disable normalization. Must be given together with ``mean``.
        pad_size_divisor: Height and width are padded up to a multiple of
            this value. Values ``<= 1`` disable padding of tensor inputs.
        pad_value: Constant used to fill the padded area.
        bgr_to_rgb: Reverse the channel order of the input.
        rgb_to_bgr: Reverse the channel order of the input. Mutually
            exclusive with ``bgr_to_rgb`` (both perform the same swap).
        non_blocking: Forwarded unchanged to the base class constructor.

    Raises:
        ValueError: If both channel flags are set, if only one of
            ``mean`` / ``std`` is given, or if either does not have 1 or 3
            values.
    """

    def __init__(
        self,
        mean: Optional[Sequence[Union[float, int]]] = None,
        std: Optional[Sequence[Union[float, int]]] = None,
        pad_size_divisor: int = 1,
        pad_value: Union[float, int] = 0,
        bgr_to_rgb: bool = False,
        rgb_to_bgr: bool = False,
        non_blocking: Optional[bool] = False,
    ):
        super().__init__(non_blocking)
        self._validate_params(mean, std, bgr_to_rgb, rgb_to_bgr)
        # Zero-size, non-persistent buffer that follows ``.to()`` / ``.cuda()``.
        # It lets ``forward`` find the module's device even when normalization
        # is disabled and no ``mean`` / ``std`` buffers exist. Being
        # non-persistent, it never appears in the state dict.
        self.register_buffer("_device_anchor", torch.empty(0), False)
        self._setup_normalization(mean, std)
        # Both flags request the same operation: reversing the channel axis.
        self._channel_conversion = bgr_to_rgb or rgb_to_bgr
        self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value

    def _validate_params(self, mean, std, bgr_to_rgb, rgb_to_bgr):
        """Reject contradictory constructor arguments.

        Raises:
            ValueError: If both channel-swap flags are ``True`` or if exactly
                one of ``mean`` / ``std`` is ``None``.
        """
        if bgr_to_rgb and rgb_to_bgr:
            raise ValueError("Cannot set both bgr_to_rgb and rgb_to_bgr to True")
        if (mean is None) != (std is None):
            raise ValueError("mean and std must both be None or both be provided")

    def _setup_normalization(self, mean, std):
        """Register the ``mean`` / ``std`` buffers, or disable normalization.

        When ``mean`` is ``None`` no buffers are registered and
        ``_enable_normalize`` is set to ``False``.

        Raises:
            ValueError: If ``mean`` or ``std`` does not have 1 or 3 values.
        """
        if mean is None:
            self._enable_normalize = False
            return
        if len(mean) not in [1, 3] or len(std) not in [1, 3]:
            raise ValueError("mean and std must have 1 or 3 values")
        self._enable_normalize = True
        # Shape (C, 1, 1) broadcasts against both CHW and NCHW images.
        self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1), False)
        self.register_buffer("std", torch.tensor(std).view(-1, 1, 1), False)

    def _target_device(self) -> torch.device:
        """Return the device that incoming data should be moved to."""
        if self._enable_normalize:
            return self.mean.device
        # No ``mean`` buffer exists when normalization is disabled.
        return self._device_anchor.device

    def _process_single_image(self, img: torch.Tensor) -> torch.Tensor:
        """Convert an image (CHW) or image batch (NCHW) to normalized float32.

        Args:
            img: A 3D ``(3, H, W)`` or 4D ``(N, 3, H, W)`` tensor of dtype
                uint8, float16, float32 or float64.

        Returns:
            A float32 tensor of the same shape, with channels optionally
            reversed and values optionally normalized.

        Raises:
            TypeError: If the dtype is not one of the supported dtypes.
            ValueError: If the tensor is not 3D/4D or the channel axis does
                not have size 3.
        """
        if img.dtype not in [torch.uint8, torch.float16, torch.float32, torch.float64]:
            raise TypeError(f"Unsupported image dtype: {img.dtype}")
        if img.dim() not in (3, 4):
            raise ValueError(f"Expected 3D or 4D tensor, got shape {img.shape}")

        # The channel axis is always third from the end: 0 for CHW, 1 for NCHW.
        channel_dim = img.dim() - 3
        if img.shape[channel_dim] != 3:
            raise ValueError(
                f"Expected 3 channels in dim={channel_dim}, got {img.shape}"
            )

        img = img.float()
        if self._channel_conversion:
            img = img[..., [2, 1, 0], :, :]  # BGR<->RGB
        if self._enable_normalize:
            # (C, 1, 1) buffers broadcast over an optional leading batch axis.
            img = (img - self.mean) / self.std
        return img

    def _pad_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Pad the last two dims (bottom/right) to a multiple of the divisor.

        Returns the input unchanged when ``pad_size_divisor <= 1`` or when
        both dimensions are already aligned.
        """
        divisor = self.pad_size_divisor
        if divisor <= 1:
            return tensor
        h, w = tensor.shape[-2:]
        # Distance to the next multiple, in exact integer arithmetic.
        pad_h = -h % divisor
        pad_w = -w % divisor
        if pad_h == 0 and pad_w == 0:
            return tensor
        return F.pad(tensor, (0, pad_w, 0, pad_h), "constant", self.pad_value)

    def forward(self, data: dict) -> dict:
        """Preprocess ``data["inputs"]`` and return the updated dict.

        Accepted inputs:

        * a sequence of tensors - each is processed individually and the
          results are combined with the inherited ``stack_batch`` helper
          (its padding/stacking behaviour is defined outside this class);
        * a 3D ``(C, H, W)`` tensor - treated as a batch of one;
        * a 4D ``(N, C, H, W)`` tensor;
        * a 5D ``(B, V, C, H, W)`` tensor - flattened to ``(B * V, C, H, W)``
          for processing and restored to five dimensions afterwards.

        Args:
            data: Dict with an ``"inputs"`` entry. It is first passed through
                the inherited ``cast_data`` with the module's device.

        Returns:
            The dict with ``"inputs"`` replaced by the processed batch and
            ``"data_samples"`` set to ``None`` if it was absent.

        Raises:
            TypeError: If ``inputs`` is neither a tensor nor a sequence of
                tensors.
            ValueError: If a tensor input is not 3D, 4D or 5D.
        """
        data = self.cast_data(data, device=self._target_device())
        inputs = data["inputs"]

        if self.is_seq_of(inputs, torch.Tensor):
            # Process list of individual images
            processed_imgs = [self._process_single_image(img) for img in inputs]
            batch_inputs = self.stack_batch(
                processed_imgs, self.pad_size_divisor, self.pad_value
            )
        elif isinstance(inputs, torch.Tensor):
            ndim = inputs.dim()
            if ndim == 3:
                # Single image (C, H, W), unsqueeze to (1, C, H, W)
                batch_inputs = self._pad_tensor(
                    self._process_single_image(inputs.unsqueeze(0))
                )
            elif ndim == 4:
                batch_inputs = self._pad_tensor(self._process_single_image(inputs))
            elif ndim == 5:
                # inputs: (B, V, C, H, W)
                B, V, C, H, W = inputs.shape
                # reshape (not view) so non-contiguous inputs are accepted.
                flat_inputs = inputs.reshape(B * V, C, H, W)
                processed = self._pad_tensor(self._process_single_image(flat_inputs))
                batch_inputs = processed.reshape(
                    B, V, C, processed.shape[-2], processed.shape[-1]
                )
            else:
                raise ValueError(
                    f"Expected 3D, 4D or 5D tensor, got shape {inputs.shape}"
                )
        else:
            raise TypeError(f"Expected tensor or list of tensors, got {type(inputs)}")

        data["inputs"] = batch_inputs
        data.setdefault("data_samples", None)
        return data