# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional, Sequence, Union

import torch
import torch.nn.functional as F

from sapiens.registry import MODELS

from .base_preprocessor import BasePreprocessor


@MODELS.register_module()
class ImagePreprocessor(BasePreprocessor):
    """Preprocessor that casts, normalizes, channel-converts, and pads image
    tensors into model-ready batches.

    ``mean`` and ``std`` must either both be provided (1 or 3 values each) or
    both be None to disable normalization. ``bgr_to_rgb`` and ``rgb_to_bgr``
    are mutually exclusive channel-order swaps, and ``pad_size_divisor`` pads
    heights and widths up to the nearest multiple using ``pad_value``.
    """

def __init__(
self,
mean: Optional[Sequence[Union[float, int]]] = None,
std: Optional[Sequence[Union[float, int]]] = None,
pad_size_divisor: int = 1,
pad_value: Union[float, int] = 0,
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False,
non_blocking: Optional[bool] = False,
):
super().__init__(non_blocking)
self._validate_params(mean, std, bgr_to_rgb, rgb_to_bgr)
self._setup_normalization(mean, std)
self._channel_conversion = bgr_to_rgb or rgb_to_bgr
self.pad_size_divisor = pad_size_divisor
        self.pad_value = pad_value

def _validate_params(self, mean, std, bgr_to_rgb, rgb_to_bgr):
if bgr_to_rgb and rgb_to_bgr:
raise ValueError("Cannot set both bgr_to_rgb and rgb_to_bgr to True")
if (mean is None) != (std is None):
raise ValueError("mean and std must both be None or both be provided")
def _setup_normalization(self, mean, std):
if mean is None:
self._enable_normalize = False
return
if len(mean) not in [1, 3] or len(std) not in [1, 3]:
raise ValueError("mean and std must have 1 or 3 values")
self._enable_normalize = True
self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1), False)
self.register_buffer("std", torch.tensor(std).view(-1, 1, 1), False)
    def _process_single_image(self, img: torch.Tensor) -> torch.Tensor:
        """Convert channel order and normalize a CHW image or NCHW batch."""
if img.dtype not in [torch.uint8, torch.float16, torch.float32, torch.float64]:
raise TypeError(f"Unsupported image dtype: {img.dtype}")
# Handle batched input (NCHW)
if img.dim() == 4:
if img.shape[1] != 3:
raise ValueError(f"Expected 3 channels in dim=1, got {img.shape}")
img = img.float()
if self._channel_conversion:
img = img[:, [2, 1, 0], ...] # BGR<->RGB
if self._enable_normalize:
img = (img - self.mean[None]) / self.std[None]
return img
# Handle single image (CHW)
elif img.dim() == 3:
if img.shape[0] != 3:
raise ValueError(f"Expected 3 channels in dim=0, got {img.shape}")
img = img.float()
if self._channel_conversion:
img = img[[2, 1, 0], ...]
if self._enable_normalize:
img = (img - self.mean) / self.std
return img
else:
raise ValueError(f"Expected 3D or 4D tensor, got shape {img.shape}")
    def _pad_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Pad the last two dims up to multiples of ``pad_size_divisor``."""
        if self.pad_size_divisor <= 1:
            return tensor
h, w = tensor.shape[-2:]
target_h = math.ceil(h / self.pad_size_divisor) * self.pad_size_divisor
target_w = math.ceil(w / self.pad_size_divisor) * self.pad_size_divisor
pad_h = target_h - h
pad_w = target_w - w
if pad_h == 0 and pad_w == 0:
return tensor
return F.pad(tensor, (0, pad_w, 0, pad_h), "constant", self.pad_value)
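
    # Worked example for _pad_tensor (illustrative sizes): with
    # pad_size_divisor=32 and an input of H=217, W=333, the targets are
    # ceil(217 / 32) * 32 = 224 and ceil(333 / 32) * 32 = 352, so F.pad adds
    # 7 rows at the bottom and 19 columns at the right, keeping the image
    # anchored at the top-left corner.
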
    def forward(self, data: dict) -> dict:
        """Cast, normalize, and pad ``data["inputs"]`` into a batched tensor."""
        # The ``mean`` buffer only exists when normalization is enabled, so
        # guard the device lookup (this assumes the base class falls back to
        # its own default device when ``device`` is None).
        device = self.mean.device if self._enable_normalize else None
        data = self.cast_data(data, device=device)
inputs = data["inputs"]
if self.is_seq_of(inputs, torch.Tensor):
# Process list of individual images
processed_imgs = [self._process_single_image(img) for img in inputs]
batch_inputs = self.stack_batch(
processed_imgs, self.pad_size_divisor, self.pad_value
)
elif isinstance(inputs, torch.Tensor):
# Process batched tensor
if inputs.dim() == 4:
batch_inputs = self._process_single_image(inputs)
batch_inputs = self._pad_tensor(batch_inputs)
elif inputs.dim() == 5:
# inputs: (B, V, C, H, W)
B, V, C, H, W = inputs.shape
flat_inputs = inputs.view(B * V, C, H, W)
processed = self._process_single_image(flat_inputs)
processed = self._pad_tensor(processed)
batch_inputs = processed.view(
B, V, C, processed.shape[-2], processed.shape[-1]
)
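                # e.g. (2, 4, 3, 217, 333) is flattened to (8, 3, 217, 333),
                # padded to (8, 3, 224, 352) with pad_size_divisor=32, then
                # reshaped back to (2, 4, 3, 224, 352) (illustrative sizes).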
elif inputs.dim() == 3:
# Single image (C, H, W), unsqueeze to (1, C, H, W)
img = inputs.unsqueeze(0)
processed = self._process_single_image(img)
batch_inputs = self._pad_tensor(processed)
else:
raise ValueError(
f"Expected 3D, 4D or 5D tensor, got shape {inputs.shape}"
)
else:
raise TypeError(f"Expected tensor or list of tensors, got {type(inputs)}")
data["inputs"] = batch_inputs
data.setdefault("data_samples", None)
return data
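

# Minimal usage sketch (illustrative, not part of the library API; the
# mean/std values below are the common ImageNet statistics in 0-255 range,
# shown only as an example):
#
#     preprocessor = ImagePreprocessor(
#         mean=[123.675, 116.28, 103.53],
#         std=[58.395, 57.12, 57.375],
#         pad_size_divisor=32,
#         bgr_to_rgb=True,
#     )
#     data = {"inputs": torch.randint(0, 256, (2, 3, 217, 333), dtype=torch.uint8)}
#     out = preprocessor(data)
#     # out["inputs"]: float32, RGB, normalized, padded to (2, 3, 224, 352)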