# sapiens2-pose/sapiens/engine/datasets/transforms/common_transforms.py
# Rawal Khirodkar
# Pin Python 3.10 + torch 2.1.2; vendor sapiens2 to bypass requires-python
# 5f5f544
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
from typing import Dict, List, Optional, Sequence
import cv2
import numpy as np
import torchvision.transforms as T
from sapiens.registry import TRANSFORMS
from .base_transform import BaseTransform, to_tensor
@TRANSFORMS.register_module()
class ImageResize(BaseTransform):
    """Resize ``results["image"]`` to a fixed ``(image_height, image_width)``.

    The resize always uses ``cv2.INTER_AREA`` interpolation.

    Args:
        image_height: Target height in pixels.
        image_width: Target width in pixels.
    """

    def __init__(
        self,
        image_height: int,
        image_width: int,
    ):
        self.image_height = image_height
        self.image_width = image_width

    def transform(self, results: Dict) -> Optional[Dict]:
        """Replace ``results["image"]`` with its resized version and return ``results``."""
        # cv2.resize expects dsize as (width, height) -- the reverse of numpy's (H, W).
        target_size = (self.image_width, self.image_height)
        results["image"] = cv2.resize(
            results["image"], target_size, interpolation=cv2.INTER_AREA
        )
        return results
@TRANSFORMS.register_module()
class ImagePackInputs(BaseTransform):
    """Pack ``results["image"]`` and selected metadata into model-ready form.

    The returned dict has two entries:

    * ``"inputs"``: the image transposed from HWC to CHW and converted with
      ``to_tensor`` (a 2-D image is first given a trailing channel axis).
    * ``"data_samples"``: every key of ``meta_keys`` that is present in
      ``results``, plus ``"image"``: the untouched source image converted with
      torchvision's ``ToTensor``. This ``"image"`` entry overrides any
      ``"image"`` meta key.

    Args:
        meta_keys: Names of ``results`` entries to copy into ``data_samples``.
    """

    def __init__(self, meta_keys: List[str]):
        self.meta_keys = meta_keys
        self.to_tensor = T.ToTensor()

    def transform(self, results: Dict) -> Optional[Dict]:
        source = results["image"]

        # Private copy so the packed tensor never aliases the caller's array;
        # promote (H, W) to (H, W, 1) so the HWC -> CHW transpose is valid.
        hwc = source.copy()
        if hwc.ndim < 3:
            hwc = hwc[..., np.newaxis]

        chw = hwc.transpose(2, 0, 1)
        if hwc.flags.c_contiguous:
            inputs = to_tensor(chw).contiguous()
        else:
            inputs = to_tensor(np.ascontiguousarray(chw))

        data_samples = {key: results[key] for key in self.meta_keys if key in results}
        data_samples["image"] = self.to_tensor(source)

        return {"inputs": inputs, "data_samples": data_samples}
@TRANSFORMS.register_module()
class PhotoMetricDistortion(BaseTransform):
    """Apply random photometric distortions to ``results["img"]``.

    Each of the four distortions is applied independently with probability
    0.5 (``random.randint(0, 1)``):

    1. brightness: add a delta drawn from
       ``[-brightness_delta, brightness_delta]``;
    2. contrast: multiply by a factor drawn from ``contrast_range``. Whether
       this runs right after brightness or as the final step is itself chosen
       at random;
    3. saturation: scale the S channel (in HSV space) by a factor drawn from
       ``saturation_range``;
    4. hue: shift the H channel (in HSV space) by an integer in
       ``[-hue_delta, hue_delta]``, wrapping modulo 180.

    Scaled/shifted values go through ``convert``, which clips to [0, 255] and
    casts to uint8. HSV conversions use ``cv2.COLOR_BGR2HSV`` /
    ``cv2.COLOR_HSV2BGR``, so the image is treated as BGR.
    NOTE(review): the HSV path presumably expects a 3-channel uint8 image --
    confirm against the loading pipeline.

    Args:
        brightness_delta: Maximum absolute brightness offset.
        contrast_range: ``(lower, upper)`` bounds of the contrast factor.
        saturation_range: ``(lower, upper)`` bounds of the saturation factor.
        hue_delta: Maximum absolute hue shift.
    """

    def __init__(
        self,
        brightness_delta: int = 32,
        contrast_range: Sequence[float] = (0.5, 1.5),
        saturation_range: Sequence[float] = (0.5, 1.5),
        hue_delta: int = 18,
    ):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(
        self, img: np.ndarray, alpha: float = 1, beta: float = 0
    ) -> np.ndarray:
        """Return ``clip(img * alpha + beta, 0, 255)`` cast to uint8.

        ``alpha`` and ``beta`` are floats in practice: every caller draws them
        with ``random.uniform``.
        """
        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def brightness(self, img: np.ndarray) -> np.ndarray:
        """With probability 0.5, add a random offset to every pixel."""
        if random.randint(0, 1):
            return self.convert(
                img, beta=random.uniform(-self.brightness_delta, self.brightness_delta)
            )
        return img

    def contrast(self, img: np.ndarray) -> np.ndarray:
        """With probability 0.5, multiply every pixel by a random factor."""
        if random.randint(0, 1):
            return self.convert(
                img, alpha=random.uniform(self.contrast_lower, self.contrast_upper)
            )
        return img

    def saturation(self, img: np.ndarray) -> np.ndarray:
        """With probability 0.5, scale the HSV saturation channel by a random factor."""
        if random.randint(0, 1):
            img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            img[:, :, 1] = self.convert(
                img[:, :, 1],
                alpha=random.uniform(self.saturation_lower, self.saturation_upper),
            )
            img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
        return img

    def hue(self, img: np.ndarray) -> np.ndarray:
        """With probability 0.5, rotate the HSV hue channel by a random integer."""
        if random.randint(0, 1):
            img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            # Widen to int before adding so negative shifts do not wrap in
            # uint8; the modulo keeps hue inside OpenCV's 8-bit range [0, 180).
            img[:, :, 0] = (
                img[:, :, 0].astype(int)
                + random.randint(-self.hue_delta, self.hue_delta)
            ) % 180
            img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
        return img

    def transform(self, results: dict) -> dict:
        """Distort ``results["img"]`` in sequence and store the result back."""
        img = results["img"]
        # random brightness
        img = self.brightness(img)
        # mode == 0 --> do random contrast last
        # mode == 1 --> do random contrast first
        mode = random.randint(0, 1)
        if mode == 1:
            img = self.contrast(img)
        # random saturation
        img = self.saturation(img)
        # random hue
        img = self.hue(img)
        # random contrast
        if mode == 0:
            img = self.contrast(img)
        results["img"] = img
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (
            f"(brightness_delta={self.brightness_delta}, "
            f"contrast_range=({self.contrast_lower}, "
            f"{self.contrast_upper}), "
            f"saturation_range=({self.saturation_lower}, "
            f"{self.saturation_upper}), "
            f"hue_delta={self.hue_delta})"
        )
        return repr_str
@TRANSFORMS.register_module()
class RandomPhotoMetricDistortion(PhotoMetricDistortion):
    """Run the ``PhotoMetricDistortion`` pipeline with probability ``prob``.

    Args:
        prob: Probability of applying the distortion pipeline at all; the
            sample is returned untouched otherwise.
        **kwargs: Forwarded unchanged to ``PhotoMetricDistortion``.
    """

    def __init__(
        self,
        prob: float = 0.5,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.prob = prob

    def transform(self, results: Dict) -> Optional[Dict]:
        # Skip with probability (1 - prob).
        if np.random.rand() > self.prob:
            return results
        return super().transform(results)

    def __repr__(self):
        # The inherited repr does not mention ``prob``; report it the same way
        # the other Random* transforms do.
        return (
            f"{self.__class__.__name__}(prob={self.prob}, "
            f"brightness_delta={self.brightness_delta}, "
            f"contrast_range=({self.contrast_lower}, {self.contrast_upper}), "
            f"saturation_range=({self.saturation_lower}, "
            f"{self.saturation_upper}), "
            f"hue_delta={self.hue_delta})"
        )
@TRANSFORMS.register_module()
class RandomDownUpSampleImage(BaseTransform):
    """Degrade ``results["img"]`` by shrinking it and resizing it back.

    With probability ``prob`` the image is downsampled by a factor drawn
    uniformly from ``scale_range`` and then resized to its original size. The
    interpolation of each of the two resizes is chosen independently at random
    from ``_INTERP_LIST``.

    Args:
        scale_range: ``(min_scale, max_scale)`` bounds of the shrink factor.
        prob: Probability of applying the degradation.
    """

    _INTERP_LIST = [
        cv2.INTER_NEAREST,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
    ]

    def __init__(self, scale_range=(0.1, 0.5), prob=0.4):
        super().__init__()
        self.scale_range = scale_range
        self.prob = prob

    def transform(self, results: dict) -> dict:
        # Leave the sample untouched with probability (1 - prob).
        if np.random.rand() > self.prob:
            return results

        source = results["img"]
        height, width = source.shape[:2]

        low, high = self.scale_range
        factor = np.random.uniform(low, high)
        # The two interpolation modes come from Python's ``random`` module,
        # while the scale factor above comes from numpy's global RNG.
        shrink_mode = random.choice(self._INTERP_LIST)
        grow_mode = random.choice(self._INTERP_LIST)

        # cv2 sizes are (width, height); never collapse a side to 0 pixels.
        small_size = (max(1, int(width * factor)), max(1, int(height * factor)))
        small = cv2.resize(source, small_size, interpolation=shrink_mode)
        results["img"] = cv2.resize(small, (width, height), interpolation=grow_mode)
        return results

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(scale_range={self.scale_range}, "
            f"prob={self.prob})"
        )
@TRANSFORMS.register_module()
class RandomGaussianBlur(BaseTransform):
    """Blur ``results["img"]`` with ``cv2.GaussianBlur`` with probability ``prob``.

    Args:
        prob: Probability of applying the blur.
        kernel_size: Kernel size passed to ``cv2.GaussianBlur`` as
            ``(width, height)``. A single int ``k`` is accepted as shorthand
            for a square ``(k, k)`` kernel.
        sigma_range: ``(low, high)`` bounds from which sigma is drawn
            uniformly. If ``None``, sigma is 0 and OpenCV derives it from the
            kernel size.
    """

    def __init__(self, prob=0.4, kernel_size=(3, 3), sigma_range=(0.1, 2.0)):
        super().__init__()
        self.prob = prob
        # Generalization: allow ``kernel_size=5`` as well as ``(5, 5)``.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = kernel_size
        self.sigma_range = sigma_range

    def transform(self, results: dict) -> dict:
        if np.random.rand() > self.prob:
            return results
        img = results["img"]
        if self.sigma_range is not None:
            sigma = np.random.uniform(self.sigma_range[0], self.sigma_range[1])
        else:
            sigma = 0  # OpenCV auto-calculates sigma from the kernel size
        # tuple() so list-valued configs (e.g. [3, 3]) are passed as a size pair.
        blurred = cv2.GaussianBlur(img, tuple(self.kernel_size), sigma)
        results["img"] = blurred
        return results

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(prob={self.prob}, "
            f"kernel_size={self.kernel_size}, sigma_range={self.sigma_range})"
        )
@TRANSFORMS.register_module()
class RandomJPEGCompression(BaseTransform):
    """Round-trip ``results["img"]`` through JPEG encoding with probability ``prob``.

    The JPEG quality is drawn uniformly from the inclusive ``quality_range``.
    This is best effort: if encoding or decoding fails, the original image is
    kept. The decoded image keeps the input's channel layout (grayscale stays
    grayscale).

    Args:
        prob: Probability of applying the compression.
        quality_range: Inclusive ``(q_min, q_max)`` JPEG quality bounds.
    """

    def __init__(self, prob=0.4, quality_range=(30, 60)):
        super().__init__()
        self.prob = prob
        self.quality_range = quality_range

    def transform(self, results: dict) -> dict:
        if np.random.rand() > self.prob:
            return results
        img = results["img"]
        q_min, q_max = self.quality_range
        # randint's upper bound is exclusive, so +1 makes q_max reachable.
        # Cast to a plain int: OpenCV's params argument is a vector of ints.
        quality = int(np.random.randint(q_min, q_max + 1))
        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
        success, enc_img = cv2.imencode(".jpg", img, encode_param)
        if not success:
            return results

        # Decode with the input's channel layout; IMREAD_COLOR alone would
        # silently expand a grayscale image to 3 channels.
        is_gray = img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1)
        flag = cv2.IMREAD_GRAYSCALE if is_gray else cv2.IMREAD_COLOR
        dec_img = cv2.imdecode(enc_img, flag)
        # imdecode signals failure by returning None; never store that.
        if dec_img is None:
            return results
        # Restore an explicit (H, W, 1) channel axis if the input had one.
        if img.ndim == 3 and dec_img.ndim == 2:
            dec_img = dec_img[..., np.newaxis]
        results["img"] = dec_img
        return results

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(prob={self.prob}, "
            f"quality_range={self.quality_range})"
        )
@TRANSFORMS.register_module()
class RandomGaussianNoise(BaseTransform):
    """Add i.i.d. Gaussian noise to ``results["img"]`` with probability ``prob``.

    The noise variance is drawn uniformly from ``var_range`` per call; its
    standard deviation is ``sqrt(variance)``. The noisy image is clipped to
    [0, 255] and returned as uint8.

    Args:
        prob: Probability of adding noise.
        mean: Mean of the Gaussian noise.
        var_range: ``(low, high)`` bounds of the noise variance.
    """

    def __init__(self, prob=0.4, mean=0.0, var_range=(5.0, 20.0)):
        super().__init__()
        self.prob = prob
        self.mean = mean
        self.var_range = var_range

    def transform(self, results: dict) -> dict:
        if np.random.rand() > self.prob:
            return results

        # Work in float32 so adding negative noise cannot wrap around.
        canvas = results["img"].astype(np.float32)
        variance = np.random.uniform(self.var_range[0], self.var_range[1])
        std_dev = variance**0.5
        canvas += np.random.normal(self.mean, std_dev, canvas.shape).astype(np.float32)

        results["img"] = np.clip(canvas, 0, 255).astype(np.uint8)
        return results

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(prob={self.prob}, "
            f"mean={self.mean}, var_range={self.var_range})"
        )
@TRANSFORMS.register_module()
class RandomGamma(BaseTransform):
    """Apply a random gamma curve to ``results["img"]`` with probability ``prob``.

    Gamma is drawn uniformly from ``gamma_range``. The curve is applied through
    a 256-entry lookup table with ``cv2.LUT``, so the input is expected to be
    8-bit.

    Args:
        prob: Probability of applying the gamma curve.
        gamma_range: ``(low, high)`` bounds of the gamma exponent.
    """

    def __init__(self, prob=0.4, gamma_range=(0.7, 1.3)):
        super().__init__()
        self.prob = prob
        self.gamma_range = gamma_range

    def transform(self, results: dict) -> dict:
        if np.random.rand() > self.prob:
            return results
        img = results["img"]
        gamma = np.random.uniform(self.gamma_range[0], self.gamma_range[1])
        # LUT: out[i] = round(255 * (i / 255) ** gamma). Round before the uint8
        # cast; truncating would bias every level downward (even gamma == 1
        # would not be an identity mapping).
        levels = np.arange(256, dtype=np.float64) / 255.0
        table = np.clip(np.rint(np.power(levels, gamma) * 255.0), 0, 255).astype(
            np.uint8
        )
        results["img"] = cv2.LUT(img, table)
        return results

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(prob={self.prob}, "
            f"gamma_range={self.gamma_range})"
        )
@TRANSFORMS.register_module()
class RandomGrayscale(BaseTransform):
    """Desaturate ``results["img"]`` with probability ``prob``.

    The image is converted BGR -> gray and then back to a 3-channel BGR image,
    so the channel count is 3 on output.

    Args:
        prob: Probability of converting to grayscale.
    """

    def __init__(self, prob=0.4):
        super().__init__()
        self.prob = prob

    def transform(self, results: dict) -> dict:
        if np.random.rand() > self.prob:
            return results
        luminance = cv2.cvtColor(results["img"], cv2.COLOR_BGR2GRAY)
        results["img"] = cv2.cvtColor(luminance, cv2.COLOR_GRAY2BGR)
        return results

    def __repr__(self):
        return f"{self.__class__.__name__}(prob={self.prob})"
@TRANSFORMS.register_module()
class RandomChannelShuffle(BaseTransform):
    """Randomly permute the first three channels of ``results["img"]``.

    Applied with probability ``prob``. Images without at least three channels
    (2-D or ``(H, W, C<3)``) are returned unchanged. Channels beyond the third
    (e.g. an alpha channel) keep their position instead of being dropped.

    Args:
        prob: Probability of shuffling.
    """

    def __init__(self, prob=0.4):
        super().__init__()
        self.prob = prob

    def transform(self, results: dict) -> dict:
        if np.random.rand() > self.prob:
            return results
        img = results["img"]
        # Fixed [0, 1, 2] indexing would raise on these inputs.
        if img.ndim < 3 or img.shape[2] < 3:
            return results
        channels = [0, 1, 2]
        np.random.shuffle(channels)
        # Keep any extra channels in place rather than slicing them away.
        channels += list(range(3, img.shape[2]))
        results["img"] = img[..., channels]
        return results

    def __repr__(self):
        return f"{self.__class__.__name__}(prob={self.prob})"
@TRANSFORMS.register_module()
class RandomInvert(BaseTransform):
    """Replace ``results["img"]`` with ``255 - img`` with probability ``prob``.

    Args:
        prob: Probability of inverting the image.
    """

    def __init__(self, prob=0.4):
        super().__init__()
        self.prob = prob

    def transform(self, results: dict) -> dict:
        if np.random.rand() > self.prob:
            return results
        original = results["img"]
        results["img"] = 255 - original
        return results

    def __repr__(self):
        return f"{self.__class__.__name__}(prob={self.prob})"
@TRANSFORMS.register_module()
class RandomSolarize(BaseTransform):
    """Invert pixels strictly above ``threshold`` with probability ``prob``.

    Every value ``v > threshold`` becomes ``255 - v``; other values are kept.
    A new array is produced, so the caller's image is never modified in place.

    Args:
        prob: Probability of solarizing.
        threshold: Values strictly greater than this are inverted.
    """

    def __init__(self, prob=0.4, threshold=128):
        super().__init__()
        self.prob = prob
        self.threshold = threshold

    def transform(self, results: dict) -> dict:
        if np.random.rand() > self.prob:
            return results
        img = results["img"]
        # np.where builds a fresh array; the previous masked assignment wrote
        # through ``img`` and corrupted any other holder of that array.
        results["img"] = np.where(img > self.threshold, 255 - img, img)
        return results

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(prob={self.prob}, threshold={self.threshold})"
        )
@TRANSFORMS.register_module()
class RandomPosterize(BaseTransform):
    """Reduce the bit depth of ``results["img"]`` with probability ``prob``.

    The number of kept bits is drawn uniformly from the inclusive range
    ``bits``; the remaining low-order bits of each 8-bit value are zeroed.

    Args:
        prob: Probability of posterizing.
        bits: Inclusive ``(low, high)`` bounds of the number of bits to keep.
    """

    def __init__(self, prob=0.4, bits=(2, 5)):
        super().__init__()
        self.prob = prob
        self.bits = bits

    def transform(self, results: dict) -> dict:
        if np.random.rand() > self.prob:
            return results
        kept_bits = random.randint(self.bits[0], self.bits[1])
        dropped = 8 - kept_bits
        # Shift right then left to clear the `dropped` least-significant bits.
        results["img"] = (results["img"] >> dropped) << dropped
        return results

    def __repr__(self):
        return f"{self.__class__.__name__}(prob={self.prob}, bits={self.bits})"