# Rawal Khirodkar
# Initial sapiens2-pointmap Space (HF download at startup, all 4 sizes, 3D viewer)
# bff20b3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import random
import warnings
import cv2
import numpy as np
from sapiens.engine.datasets import BaseTransform, to_tensor
from sapiens.registry import TRANSFORMS
@TRANSFORMS.register_module()
class SegRandomRotate(BaseTransform):
    """Randomly rotate the image and its segmentation map about the image center.

    The transform is skipped when ``random.random() > prob``. Otherwise an
    angle is drawn uniformly from ``[-degree, degree]`` and the same affine
    warp is applied to ``results["img"]`` (bilinear interpolation) and
    ``results["gt_seg"]`` (nearest neighbour, so label ids are never blended).
    The output keeps the input width and height.

    Args:
        prob: Probability of applying the rotation; must lie in ``[0, 1]``.
        degree: Maximum absolute rotation angle in degrees; must be positive.
        pad_val: ``borderValue`` handed to ``cv2.warpAffine`` for the image.
        seg_pad_val: ``borderValue`` handed to ``cv2.warpAffine`` for the
            segmentation map.
    """

    def __init__(
        self,
        prob=0.5,
        degree=60,
        pad_val=0,
        seg_pad_val=255,
    ):
        super().__init__()
        assert 0 <= prob <= 1
        assert degree > 0, f"degree {degree} should be positive"
        self.prob = prob
        # Stored as a symmetric (min, max) range around zero.
        self.degree = (-degree, degree)
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val

    def transform(self, results: dict) -> dict:
        """Rotate ``results["img"]`` and ``results["gt_seg"]`` in place.

        Args:
            results: Sample dict holding at least ``"img"`` and ``"gt_seg"``.

        Returns:
            The same dict, with both entries replaced by their rotated
            versions when the transform fires, otherwise untouched.
        """
        if random.random() > self.prob:
            return results

        low, high = min(self.degree), max(self.degree)
        angle = random.uniform(low, high)

        img = results["img"]
        gt_seg = results["gt_seg"]
        h, w = img.shape[:2]

        # One matrix for both warps keeps image and labels pixel-aligned.
        matrix = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)

        results["img"] = cv2.warpAffine(
            img, matrix, (w, h), flags=cv2.INTER_LINEAR, borderValue=self.pad_val
        )
        # Nearest-neighbour so that no new (interpolated) label ids appear.
        results["gt_seg"] = cv2.warpAffine(
            gt_seg,
            matrix,
            (w, h),
            flags=cv2.INTER_NEAREST,
            borderValue=self.seg_pad_val,
        )
        return results

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"prob={self.prob}, "
            f"degree={self.degree}, "
            f"pad_val={self.pad_val}, "
            f"seg_pad_val={self.seg_pad_val})"
        )
@TRANSFORMS.register_module()
class SegRandomHorizontalFlip(BaseTransform):
    """Randomly mirror the image and segmentation map left-to-right.

    The transform is skipped when ``random.random() > prob``. After flipping,
    each ``(a, b)`` pair in ``swap_seg_labels`` has its two label ids
    exchanged in the segmentation map.

    Args:
        prob: Probability of applying the flip.
        swap_seg_labels: Optional iterable of 2-element label-id pairs to
            exchange after the flip, or ``None`` to leave labels unchanged.
    """

    def __init__(self, prob=0.5, swap_seg_labels=None):
        super().__init__()
        self.swap_seg_labels = swap_seg_labels
        self.prob = prob

    def transform(self, results: dict) -> dict:
        """Flip ``results["img"]`` and ``results["gt_seg"]`` horizontally."""
        skip = random.random() > self.prob
        if skip:
            return results

        # flipCode=1 mirrors around the vertical axis.
        flipped_img = cv2.flip(results["img"], 1)
        flipped_seg = cv2.flip(results["gt_seg"], 1)

        label_pairs = self.swap_seg_labels
        if label_pairs is not None:
            # Compare against a frozen snapshot so that the first assignment
            # of a pair does not feed into the second one.
            snapshot = flipped_seg.copy()
            for pair in label_pairs:
                assert len(pair) == 2
                first, second = pair[0], pair[1]
                flipped_seg[snapshot == first] = second
                flipped_seg[snapshot == second] = first

        results["img"] = flipped_img
        results["gt_seg"] = flipped_seg
        return results
@TRANSFORMS.register_module()
class SegPackInputs(BaseTransform):
    """Pack a sample dict into ``{"inputs": ..., "data_samples": ...}``.

    The image is moved to channel-first layout and converted with
    ``to_tensor``; the segmentation map gets a leading axis and is cast to
    ``int64``; selected metadata keys are copied into
    ``data_samples["meta"]``.

    Args:
        test_mode: When ``False``, samples whose foreground covers less than
            1% of the segmentation map are dropped (``transform`` returns
            ``None``).
        meta_keys: Keys copied from ``results`` into the ``"meta"`` dict when
            present.
    """

    def __init__(
        self,
        test_mode: bool = False,
        meta_keys=(
            "img_path",
            "ori_shape",
            "img_shape",
            "pad_shape",
            "flip",
        ),
    ):
        super().__init__()
        self.meta_keys = meta_keys
        self.test_mode = test_mode

    def transform(self, results: dict) -> dict:
        """Build the packed sample, or return ``None`` for a near-empty mask."""
        packed = {}

        if "img" in results:
            image = results["img"]
            if image.ndim < 3:
                # Give 2-D (grayscale) input a trailing channel axis.
                image = np.expand_dims(image, -1)
            chw = image.transpose(2, 0, 1)
            if image.flags.c_contiguous:
                tensor = to_tensor(chw).contiguous()
            else:
                tensor = to_tensor(np.ascontiguousarray(chw))
            packed["inputs"] = tensor

        sample = {}
        if "gt_seg" in results:
            seg = results["gt_seg"]
            assert len(seg.shape) == 2  # H x W
            # Foreground = any label other than 0 and 255.
            foreground = (seg > 0) * (seg != 255)
            fg_fraction = foreground.sum() / (
                foreground.shape[0] * foreground.shape[1]
            )
            if fg_fraction < 0.01 and self.test_mode == False:
                return None
            sample["gt_seg"] = to_tensor(seg[None, ...].astype(np.int64))

        meta = {}
        for key in self.meta_keys:
            if key not in results:
                continue
            value = results[key]
            # bool is a subclass of int, so boolean entries also become float32.
            if isinstance(value, (int, float)):
                meta[key] = np.float32(value)
            elif isinstance(value, np.ndarray):
                meta[key] = value.astype(np.float32)
            else:
                meta[key] = value

        sample["meta"] = meta
        packed["data_samples"] = sample
        return packed

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(meta_keys={self.meta_keys})"
@TRANSFORMS.register_module()
class SegRandomResize(BaseTransform):
    """Resize to a randomly scaled version of a base resolution.

    On every call a ratio is drawn uniformly from ``ratio_range``, the base
    height and width are multiplied by it (truncated to ``int``), and the
    actual resizing is delegated to an internal ``SegResize`` instance.

    Args:
        base_height: Target height before random scaling.
        base_width: Target width before random scaling.
        ratio_range: ``(min_ratio, max_ratio)`` bounds of the scale factor.
        keep_ratio: Forwarded to ``SegResize``.
    """

    def __init__(
        self,
        base_height=1024,
        base_width=768,
        ratio_range=(0.4, 2.0),
        keep_ratio=True,
    ):
        super().__init__()
        self.base_height = base_height
        self.base_width = base_width
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
        # Reused across calls; its target size is overwritten per sample.
        self.resizer = SegResize(
            height=self.base_height,
            width=self.base_width,
            keep_ratio=keep_ratio,
        )

    def transform(self, results: dict) -> dict:
        """Draw a scale, retarget the inner resizer and apply it."""
        low, high = self.ratio_range
        scale = low + (high - low) * np.random.random_sample()

        resizer = self.resizer
        resizer.height = int(self.base_height * scale)
        resizer.width = int(self.base_width * scale)
        return resizer.transform(results)

    def __repr__(self):
        fields = ", ".join(
            [
                f"base_height={self.base_height}",
                f"base_width={self.base_width}",
                f"ratio_range={self.ratio_range}",
                f"keep_ratio={self.keep_ratio}",
            ]
        )
        return f"{self.__class__.__name__}({fields})"
@TRANSFORMS.register_module()
class SegResize(BaseTransform):
    """Resize the image (and, outside test mode, the segmentation map).

    Args:
        height: Target height in pixels.
        width: Target width in pixels.
        keep_ratio: If truthy, scale uniformly so the result fits inside
            ``width x height`` while preserving the aspect ratio; otherwise
            resize to exactly ``width x height``.
        test_mode: If truthy, ``results["gt_seg"]`` is left at its original
            resolution.
    """

    def __init__(
        self,
        height=1024,
        width=768,
        keep_ratio=False,
        test_mode: bool = False,
    ):
        super().__init__()
        self.height = height
        self.width = width
        self.keep_ratio = keep_ratio
        self.test_mode = test_mode

    def transform(self, results: dict) -> dict:
        """Resize ``results["img"]`` and, when training, ``results["gt_seg"]``.

        Args:
            results: Sample dict holding ``"img"`` and optionally ``"gt_seg"``.

        Returns:
            The same dict with the resized entries written back.
        """
        img = results["img"]
        h, w = img.shape[:2]

        # Plain truthiness instead of ``is True`` so that truthy non-bool
        # flags (e.g. 1 or numpy.bool_) are honoured rather than ignored.
        if self.keep_ratio:
            scale_factor = min(self.width / w, self.height / h)
            new_w = int(round(w * scale_factor))
            new_h = int(round(h * scale_factor))
        else:
            new_w = int(self.width)
            new_h = int(self.height)

        # cv2.resize rejects an empty destination size; extreme aspect ratios
        # or very small targets could otherwise round a side down to zero.
        new_w = max(new_w, 1)
        new_h = max(new_h, 1)
        dsize = (new_w, new_h)

        # INTER_AREA when the output is narrower than the input, INTER_CUBIC
        # otherwise. Only the width is compared.
        img_interpolation = cv2.INTER_AREA if new_w < w else cv2.INTER_CUBIC
        results["img"] = cv2.resize(img, dsize, interpolation=img_interpolation)

        # Labels are resized only when training; nearest-neighbour keeps the
        # label ids intact.
        if "gt_seg" in results and not self.test_mode:
            results["gt_seg"] = cv2.resize(
                results["gt_seg"], dsize, interpolation=cv2.INTER_NEAREST
            )

        return results

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"height={self.height}, "
            f"width={self.width}, "
            f"keep_ratio={self.keep_ratio})"
        )
@TRANSFORMS.register_module()
class SegRandomCrop(BaseTransform):
    """Randomly crop a fixed-size window from image and segmentation map.

    The transform is skipped when ``random.random() > prob``. Inputs smaller
    than the crop size are first padded on the bottom/right (image with 0,
    segmentation map with ``ignore_index``). When ``cat_max_ratio < 1`` the
    crop window is re-drawn up to 10 times until no single non-ignored label
    occupies ``cat_max_ratio`` or more of the window.

    Args:
        crop_height: Height of the crop window.
        crop_width: Width of the crop window.
        prob: Probability of applying the crop.
        cat_max_ratio: Upper bound on the share of the dominant label.
        ignore_index: Label id excluded from the dominance check and used to
            pad the segmentation map.
    """

    def __init__(
        self,
        crop_height=1024,
        crop_width=768,
        prob=0.5,
        cat_max_ratio=0.75,
        ignore_index=255,
    ):
        super().__init__()
        self.crop_height = crop_height
        self.crop_width = crop_width
        self.prob = prob
        self.cat_max_ratio = cat_max_ratio
        self.ignore_index = ignore_index

    def _generate_crop_bbox(self, img: np.ndarray) -> tuple:
        """Draw a random ``(y1, y2, x1, x2)`` crop window for ``img``."""
        max_top = max(img.shape[0] - self.crop_height, 0)
        max_left = max(img.shape[1] - self.crop_width, 0)
        # Row offset is drawn before the column offset.
        top = np.random.randint(0, max_top + 1)
        left = np.random.randint(0, max_left + 1)
        return top, top + self.crop_height, left, left + self.crop_width

    def _crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray:
        """Slice ``img`` to the ``(y1, y2, x1, x2)`` window ``crop_bbox``."""
        y1, y2, x1, x2 = crop_bbox
        return img[y1:y2, x1:x2, ...]

    def _has_label_balance(self, seg: np.ndarray, crop_bbox: tuple) -> bool:
        """Tell whether the window holds >1 label with none too dominant."""
        labels, counts = np.unique(self._crop(seg, crop_bbox), return_counts=True)
        # Ignored pixels take no part in the dominance ratio.
        counts = counts[labels != self.ignore_index]
        return len(counts) > 1 and np.max(counts) / np.sum(counts) < self.cat_max_ratio

    def transform(self, results: dict) -> dict:
        """Randomly crop ``results["img"]`` and ``results["gt_seg"]``."""
        if random.random() > self.prob:
            return results

        image, seg = results["img"], results["gt_seg"]
        height, width = image.shape[:2]

        # Grow undersized inputs so that a full window always fits.
        missing_h = max(self.crop_height - height, 0)
        missing_w = max(self.crop_width - width, 0)
        if missing_h > 0 or missing_w > 0:
            image = cv2.copyMakeBorder(
                image, 0, missing_h, 0, missing_w, cv2.BORDER_CONSTANT, value=0
            )
            seg = cv2.copyMakeBorder(
                seg,
                0,
                missing_h,
                0,
                missing_w,
                cv2.BORDER_CONSTANT,
                value=self.ignore_index,
            )

        window = self._generate_crop_bbox(image)
        if self.cat_max_ratio < 1.0:
            # At most 10 checks; a window drawn after the 10th failed check is
            # used as-is without being validated.
            redraws = 0
            while redraws < 10 and not self._has_label_balance(seg, window):
                window = self._generate_crop_bbox(image)
                redraws += 1

        results["img"] = self._crop(image, window)
        results["gt_seg"] = self._crop(seg, window)
        return results

    def __repr__(self):
        fields = ", ".join(
            [
                f"crop_height={self.crop_height}",
                f"crop_width={self.crop_width}",
                f"prob={self.prob}",
                f"cat_max_ratio={self.cat_max_ratio}",
                f"ignore_index={self.ignore_index}",
            ]
        )
        return f"{self.__class__.__name__}({fields})"
@TRANSFORMS.register_module()
class SegRandomBackground(BaseTransform):
    """Randomly replace the background behind the foreground mask.

    A ``.jpg`` file is picked at random from ``background_images_root``,
    fitted to the sample's resolution and composited behind the foreground
    (``gt_seg > 0`` or ``mask > 0``). The foreground colours are then shifted
    towards the new background statistics with Reinhard colour transfer.

    Args:
        prob: Probability of applying the transform.
        skip_key: If this key is present and truthy in ``results``, the sample
            is returned unchanged.
        background_images_root: Directory listed once at construction time for
            ``.jpg`` background images.
    """

    def __init__(
        self,
        prob: float = 0.25,
        skip_key: str = "is_itw",
        background_images_root: str = "",
    ):
        super().__init__()
        self.prob = prob
        self.skip_key = skip_key
        self.background_images_root = background_images_root
        # Sorted so the candidate list does not depend on directory order.
        self.background_images = sorted(
            image_name
            for image_name in os.listdir(background_images_root)
            if image_name.endswith(".jpg")
        )

    def transform(self, results: dict) -> dict:
        """Composite ``results["img"]`` over a random background image.

        Args:
            results: Sample dict holding ``"img"`` (BGR) and either
                ``"gt_seg"`` or ``"mask"``.

        Returns:
            The same dict; ``"img"`` is replaced when the transform fires.
            The sample is returned unchanged (with a warning) when no mask,
            no background image, or no readable background file is available.
        """
        if random.random() > self.prob:
            return results

        if self.skip_key in results and results[self.skip_key]:
            return results

        image = results["img"]  # BGR image

        # NOTE(review): every label > 0 counts as foreground here, including
        # 255, which other transforms in this file use as an ignore/pad value
        # - confirm this is intended.
        if "gt_seg" in results:
            mask = (results["gt_seg"] > 0).astype(np.uint8)
        elif "mask" in results:
            mask = (results["mask"] > 0).astype(np.uint8)
        else:
            warnings.warn(
                "foreground mask not found in results, skip random background!"
            )
            return results

        # random.choice raises IndexError on an empty sequence.
        if not self.background_images:
            warnings.warn(
                f"no .jpg background images found in "
                f"{self.background_images_root!r}, skip random background!"
            )
            return results

        background_image_path = os.path.join(
            self.background_images_root, random.choice(self.background_images)
        )
        background_image = cv2.imread(background_image_path)  # BGR image
        # cv2.imread signals an unreadable file by returning None.
        if background_image is None:
            warnings.warn(
                f"failed to read background image {background_image_path!r}, "
                f"skip random background!"
            )
            return results

        image_height, image_width = image.shape[:2]
        background_image = self._fit_background(
            background_image, image_height, image_width
        )

        # Use the segmentation mask as an alpha channel (values 0 or 1).
        alpha_norm = mask.astype(np.float32)
        alpha_mask = np.stack([alpha_norm] * 3, axis=-1)
        composite = alpha_mask * image + (1 - alpha_mask) * background_image
        composite = composite.astype(np.uint8)

        # Pull the foreground colours towards the new background.
        results["img"] = self.reinhard_alpha(composite, alpha_norm)
        return results

    def _fit_background(self, background_image, image_height, image_width):
        """Scale to the image height, centre-crop the width, force exact size."""
        background_height, background_width = background_image.shape[:2]

        # Clamp to one pixel: a very tall, thin background could otherwise
        # truncate to width 0, which cv2.resize rejects.
        scaled_width = max(
            int(image_height * background_width / background_height), 1
        )
        background_image = cv2.resize(background_image, (scaled_width, image_height))

        if scaled_width > image_width:
            start_x = (scaled_width - image_width) // 2
            background_image = background_image[:, start_x : start_x + image_width]

        # Backgrounds narrower than the image get stretched here.
        if (
            background_image.shape[0] != image_height
            or background_image.shape[1] != image_width
        ):
            background_image = cv2.resize(background_image, (image_width, image_height))
        return background_image

    def reinhard_alpha(self, comp_img, alpha_mask):
        """Reinhard colour transfer from background to foreground.

        Paper: https://www.cs.tau.ac.il/~turkel/imagepapers/ColorTransfer.pdf

        Args:
            comp_img: BGR ``uint8`` composite image.
            alpha_mask: 2-D foreground weights in ``[0, 1]``.

        Returns:
            A ``uint8`` image in which foreground pixels are colour-shifted
            and background pixels are taken from ``comp_img``. If either the
            foreground or the background has zero total weight, a copy of
            ``comp_img`` is returned.
        """
        bg_weights = 1 - alpha_mask
        fg_weights = alpha_mask

        # With an empty foreground or background the weighted statistics are
        # 0/0 = NaN, and the NaN scale/offset would corrupt the foreground
        # pixels. There is nothing to match in that case.
        if fg_weights.sum() <= 0 or bg_weights.sum() <= 0:
            return comp_img.copy()

        comp_lab = cv2.cvtColor(comp_img, cv2.COLOR_BGR2Lab)

        bg_mean, bg_std = self.weighted_mean_std(comp_lab, bg_weights)
        fg_mean, fg_std = self.weighted_mean_std(comp_lab, fg_weights)

        # Avoid division by zero for a perfectly flat foreground.
        fg_std = np.maximum(fg_std, 1e-6)
        ratio = (bg_std / fg_std).reshape(-1)
        offset = (bg_mean - fg_mean * bg_std / fg_std).reshape(-1)

        # Per-channel affine map in Lab space, then back to BGR.
        trans_lab = cv2.convertScaleAbs(comp_lab * ratio + offset)
        trans_img = cv2.cvtColor(trans_lab, cv2.COLOR_Lab2BGR)

        # Keep the transferred colours on the foreground only.
        alpha_mask_3d = np.repeat(alpha_mask[:, :, np.newaxis], 3, axis=2)
        trans_comp = (
            trans_img * alpha_mask_3d + comp_img * (1 - alpha_mask_3d)
        ).astype(np.uint8)
        return trans_comp

    def weighted_mean_std(self, img, weights):
        """Per-channel weighted mean and standard deviation of ``img``.

        Args:
            img: ``(H, W, C)`` array.
            weights: ``(H, W)`` non-negative weights. Their sum must be
                non-zero, otherwise the result is NaN.

        Returns:
            ``(mean, std)``, each of shape ``(C,)``.
        """
        # Replicate the weights across channels to match img's shape.
        weights_3d = np.repeat(weights[:, :, np.newaxis], img.shape[2], axis=2)

        total_weights = np.sum(weights_3d, axis=(0, 1))
        mean = np.sum(img * weights_3d, axis=(0, 1)) / total_weights

        variance = np.sum(((img - mean) ** 2) * weights_3d, axis=(0, 1)) / total_weights
        std = np.sqrt(variance)
        return mean, std