Spaces:
Running on Zero
Running on Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| import random | |
| import warnings | |
| import cv2 | |
| import numpy as np | |
| from sapiens.engine.datasets import BaseTransform, to_tensor | |
| from sapiens.registry import TRANSFORMS | |
| class SegRandomRotate(BaseTransform): | |
| def __init__( | |
| self, | |
| prob=0.5, | |
| degree=60, | |
| pad_val=0, | |
| seg_pad_val=255, | |
| ): | |
| super().__init__() | |
| self.prob = prob | |
| assert prob >= 0 and prob <= 1 | |
| assert degree > 0, f"degree {degree} should be positive" | |
| self.degree = (-degree, degree) | |
| assert len(self.degree) == 2, ( | |
| f"degree {self.degree} should be a tuple of (min, max)" | |
| ) | |
| self.pad_val = pad_val | |
| self.seg_pad_val = seg_pad_val | |
| def transform(self, results: dict) -> dict: | |
| if random.random() > self.prob: | |
| return results | |
| degree = random.uniform(min(*self.degree), max(*self.degree)) | |
| img = results["img"] | |
| gt_seg = results["gt_seg"] | |
| h, w = img.shape[:2] | |
| center = (w / 2, h / 2) | |
| M = cv2.getRotationMatrix2D(center, degree, 1.0) | |
| results["img"] = cv2.warpAffine( | |
| img, M, (w, h), flags=cv2.INTER_LINEAR, borderValue=self.pad_val | |
| ) | |
| # Rotate the segmentation map | |
| results["gt_seg"] = cv2.warpAffine( | |
| gt_seg, | |
| M, | |
| (w, h), | |
| flags=cv2.INTER_NEAREST, | |
| borderValue=self.seg_pad_val, | |
| ) | |
| return results | |
| def __repr__(self): | |
| repr_str = self.__class__.__name__ | |
| repr_str += ( | |
| f"(prob={self.prob}, " | |
| f"degree={self.degree}, " | |
| f"pad_val={self.pal_val}, " | |
| f"seg_pad_val={self.seg_pad_val}, " | |
| ) | |
| return repr_str | |
| class SegRandomHorizontalFlip(BaseTransform): | |
| def __init__(self, prob=0.5, swap_seg_labels=None): | |
| super().__init__() | |
| self.prob = prob | |
| self.swap_seg_labels = swap_seg_labels | |
| def transform(self, results: dict) -> dict: | |
| if random.random() > self.prob: | |
| return results | |
| img = results["img"] | |
| gt_seg = results["gt_seg"] | |
| img = cv2.flip(img, 1) | |
| gt_seg = cv2.flip(gt_seg, 1) | |
| temp = gt_seg.copy() | |
| if self.swap_seg_labels is not None: | |
| for pair in self.swap_seg_labels: | |
| assert len(pair) == 2 | |
| gt_seg[temp == pair[0]] = pair[1] | |
| gt_seg[temp == pair[1]] = pair[0] | |
| results["img"] = img | |
| results["gt_seg"] = gt_seg | |
| return results | |
| class SegPackInputs(BaseTransform): | |
| def __init__( | |
| self, | |
| test_mode: bool = False, | |
| meta_keys=( | |
| "img_path", | |
| "ori_shape", | |
| "img_shape", | |
| "pad_shape", | |
| "flip", | |
| ), | |
| ): | |
| super().__init__() | |
| self.test_mode = test_mode | |
| self.meta_keys = meta_keys | |
| def transform(self, results: dict) -> dict: | |
| packed_results = dict() | |
| if "img" in results: | |
| img = results["img"] | |
| if len(img.shape) < 3: | |
| img = np.expand_dims(img, -1) | |
| if not img.flags.c_contiguous: | |
| img = to_tensor(np.ascontiguousarray(img.transpose(2, 0, 1))) | |
| else: | |
| img = img.transpose(2, 0, 1) | |
| img = to_tensor(img).contiguous() | |
| packed_results["inputs"] = img | |
| data_sample = dict() | |
| if "gt_seg" in results: | |
| assert len(results["gt_seg"].shape) == 2 # H x W | |
| mask = (results["gt_seg"] > 0) * (results["gt_seg"] != 255) | |
| if ( | |
| mask.sum() / (mask.shape[0] * mask.shape[1]) < 0.01 | |
| and self.test_mode == False | |
| ): | |
| return None | |
| data_sample["gt_seg"] = to_tensor( | |
| results["gt_seg"][None, ...].astype(np.int64) | |
| ) | |
| img_meta = {} | |
| for key in self.meta_keys: | |
| if key in results: | |
| if isinstance(results[key], (int, float)): | |
| img_meta[key] = np.float32(results[key]) | |
| elif isinstance(results[key], np.ndarray): | |
| img_meta[key] = results[key].astype(np.float32) | |
| else: | |
| img_meta[key] = results[key] | |
| data_sample["meta"] = img_meta | |
| packed_results["data_samples"] = data_sample | |
| return packed_results | |
| def __repr__(self) -> str: | |
| repr_str = self.__class__.__name__ | |
| repr_str += f"(meta_keys={self.meta_keys})" | |
| return repr_str | |
| class SegRandomResize(BaseTransform): | |
| def __init__( | |
| self, | |
| base_height=1024, | |
| base_width=768, | |
| ratio_range=(0.4, 2.0), | |
| keep_ratio=True, | |
| ): | |
| super().__init__() | |
| self.base_height = base_height | |
| self.base_width = base_width | |
| self.ratio_range = ratio_range | |
| self.keep_ratio = keep_ratio | |
| self.resizer = SegResize( | |
| height=self.base_height, width=self.base_width, keep_ratio=keep_ratio | |
| ) | |
| def transform(self, results: dict) -> dict: | |
| min_ratio, max_ratio = self.ratio_range | |
| ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio | |
| self.resizer.height = int(self.base_height * ratio) | |
| self.resizer.width = int(self.base_width * ratio) | |
| return self.resizer.transform(results) | |
| def __repr__(self): | |
| return ( | |
| f"{self.__class__.__name__}(" | |
| f"base_height={self.base_height}, " | |
| f"base_width={self.base_width}, " | |
| f"ratio_range={self.ratio_range}, " | |
| f"keep_ratio={self.keep_ratio})" | |
| ) | |
| class SegResize(BaseTransform): | |
| def __init__( | |
| self, | |
| height=1024, | |
| width=768, | |
| keep_ratio=False, | |
| test_mode: bool = False, | |
| ): | |
| super().__init__() | |
| self.height = height | |
| self.width = width | |
| self.keep_ratio = keep_ratio | |
| self.test_mode = test_mode | |
| def transform(self, results: dict) -> dict: | |
| img = results["img"] | |
| h, w = img.shape[:2] | |
| target_height = self.height | |
| target_width = self.width | |
| if self.keep_ratio is True: | |
| scale_factor = min(target_width / w, target_height / h) | |
| new_w = int(round(w * scale_factor)) | |
| new_h = int(round(h * scale_factor)) | |
| else: | |
| new_w = target_width | |
| new_h = target_height | |
| dsize = (new_w, new_h) | |
| # Use INTER_AREA for shrinking and INTER_CUBIC for enlarging | |
| # to get antialiased results. | |
| img_interpolation = cv2.INTER_AREA if new_w < w else cv2.INTER_CUBIC | |
| # Update the results dictionary | |
| results["img"] = cv2.resize(img, dsize, interpolation=img_interpolation) | |
| ## resize gt seg if training | |
| if "gt_seg" in results and self.test_mode is False: | |
| gt_seg = results["gt_seg"] | |
| results["gt_seg"] = cv2.resize( | |
| gt_seg, dsize, interpolation=cv2.INTER_NEAREST | |
| ) | |
| return results | |
| def __repr__(self): | |
| return ( | |
| f"{self.__class__.__name__}(" | |
| f"height={self.height}, " | |
| f"width={self.width}, " | |
| f"keep_ratio={self.keep_ratio})" | |
| ) | |
| class SegRandomCrop(BaseTransform): | |
| def __init__( | |
| self, | |
| crop_height=1024, | |
| crop_width=768, | |
| prob=0.5, | |
| cat_max_ratio=0.75, | |
| ignore_index=255, | |
| ): | |
| super().__init__() | |
| self.crop_height = crop_height | |
| self.crop_width = crop_width | |
| self.prob = prob | |
| self.cat_max_ratio = cat_max_ratio | |
| self.ignore_index = ignore_index | |
| def _generate_crop_bbox(self, img: np.ndarray) -> tuple: | |
| """Randomly get a crop bounding box.""" | |
| margin_h = max(img.shape[0] - self.crop_height, 0) | |
| margin_w = max(img.shape[1] - self.crop_width, 0) | |
| offset_h = np.random.randint(0, margin_h + 1) | |
| offset_w = np.random.randint(0, margin_w + 1) | |
| crop_y1, crop_y2 = offset_h, offset_h + self.crop_height | |
| crop_x1, crop_x2 = offset_w, offset_w + self.crop_width | |
| return crop_y1, crop_y2, crop_x1, crop_x2 | |
| def _crop(self, img: np.ndarray, crop_bbox: tuple) -> np.ndarray: | |
| """Crop from ``img``""" | |
| crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox | |
| img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...] | |
| return img | |
| def transform(self, results: dict) -> dict: | |
| """Transform function to randomly crop images and segmentation maps.""" | |
| if random.random() > self.prob: | |
| return results | |
| img = results["img"] | |
| gt_seg = results["gt_seg"] | |
| h, w = img.shape[:2] | |
| # Pad the image if it's smaller than the crop size | |
| pad_h = max(self.crop_height - h, 0) | |
| pad_w = max(self.crop_width - w, 0) | |
| if pad_h > 0 or pad_w > 0: | |
| img = cv2.copyMakeBorder( | |
| img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0 | |
| ) | |
| gt_seg = cv2.copyMakeBorder( | |
| gt_seg, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_index | |
| ) | |
| padded_img = img | |
| padded_gt_seg = gt_seg | |
| crop_bbox = self._generate_crop_bbox(padded_img) | |
| if self.cat_max_ratio < 1.0: | |
| # Repeat 10 times to find a valid crop | |
| for _ in range(10): | |
| seg_temp = self._crop(padded_gt_seg, crop_bbox) | |
| labels, cnt = np.unique(seg_temp, return_counts=True) | |
| # Filter out the ignore_index | |
| cnt = cnt[labels != self.ignore_index] | |
| if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < self.cat_max_ratio: | |
| break # Found a valid crop | |
| crop_bbox = self._generate_crop_bbox(padded_img) | |
| # Crop the image and segmentation map | |
| results["img"] = self._crop(padded_img, crop_bbox) | |
| results["gt_seg"] = self._crop(padded_gt_seg, crop_bbox) | |
| return results | |
| def __repr__(self): | |
| return ( | |
| f"{self.__class__.__name__}(" | |
| f"crop_height={self.crop_height}, " | |
| f"crop_width={self.crop_width}, " | |
| f"prob={self.prob}, " | |
| f"cat_max_ratio={self.cat_max_ratio}, " | |
| f"ignore_index={self.ignore_index})" | |
| ) | |
| class SegRandomBackground(BaseTransform): | |
| def __init__( | |
| self, | |
| prob: float = 0.25, | |
| skip_key: str = "is_itw", | |
| background_images_root: str = "", | |
| ): | |
| super().__init__() | |
| self.prob = prob | |
| self.skip_key = skip_key | |
| self.background_images_root = background_images_root | |
| self.background_images = sorted( | |
| [ | |
| image_name | |
| for image_name in os.listdir(background_images_root) | |
| if image_name.endswith(".jpg") | |
| ] | |
| ) | |
| def transform(self, results: dict) -> dict: | |
| if random.random() > self.prob: | |
| return results | |
| if self.skip_key in results and results[self.skip_key]: | |
| return results | |
| image = results["img"] ## bgr image | |
| if "gt_seg" in results: | |
| gt_seg = results["gt_seg"] | |
| mask = (gt_seg > 0).astype(np.uint8) | |
| elif "mask" in results: | |
| mask = results["mask"] | |
| mask = (mask > 0).astype(np.uint8) | |
| else: | |
| warnings.warn( | |
| f"foreground mask not found in results, skip random background!" | |
| ) | |
| return results | |
| background_image_path = os.path.join( | |
| self.background_images_root, random.choice(self.background_images) | |
| ) | |
| background_image = cv2.imread(background_image_path) ## bgr image | |
| ##----------------------------- | |
| background_height = background_image.shape[0] | |
| background_width = background_image.shape[1] | |
| image_height = image.shape[0] | |
| image_width = image.shape[1] | |
| new_background_height = image_height | |
| new_background_width = int( | |
| new_background_height * background_width / background_height | |
| ) | |
| background_image = cv2.resize( | |
| background_image, (new_background_width, new_background_height) | |
| ) | |
| # Crop the background image to the width of the original image | |
| if new_background_width > image_width: | |
| start_x = (new_background_width - image_width) // 2 | |
| end_x = start_x + image_width | |
| background_image = background_image[:, start_x:end_x] | |
| if ( | |
| background_image.shape[0] != image_height | |
| or background_image.shape[1] != image_width | |
| ): | |
| background_image = cv2.resize(background_image, (image_width, image_height)) | |
| # Use the segmentation mask as an alpha channel. | |
| alpha_norm = mask.astype(np.float32) # Values 0 or 1. | |
| alpha_mask = np.stack([alpha_norm] * 3, axis=-1) | |
| composite = alpha_mask * image + (1 - alpha_mask) * background_image | |
| composite = composite.astype(np.uint8) | |
| # Apply color transfer using the Reinhard algorithm. | |
| composite = self.reinhard_alpha(composite, alpha_norm) | |
| results["img"] = composite | |
| return results | |
| def reinhard_alpha(self, comp_img, alpha_mask): | |
| """ | |
| # Reinhard color transfer algorithm with alpha mask support | |
| # paper: https://www.cs.tau.ac.il/~turkel/imagepapers/ColorTransfer.pdf | |
| # alpha mask in range [0, 1] | |
| """ | |
| # Convert to LAB color space | |
| comp_lab = cv2.cvtColor(comp_img, cv2.COLOR_BGR2Lab) | |
| # Calculate weighted mean and std for background and foreground | |
| bg_weights = 1 - alpha_mask | |
| fg_weights = alpha_mask | |
| bg_mean, bg_std = self.weighted_mean_std(comp_lab, bg_weights) | |
| fg_mean, fg_std = self.weighted_mean_std(comp_lab, fg_weights) | |
| # Avoid division by zero | |
| fg_std = np.maximum(fg_std, 1e-6) | |
| ratio = (bg_std / fg_std).reshape(-1) | |
| offset = (bg_mean - fg_mean * bg_std / fg_std).reshape(-1) | |
| # Apply color transfer | |
| trans_lab = cv2.convertScaleAbs(comp_lab * ratio + offset) | |
| trans_img = cv2.cvtColor(trans_lab, cv2.COLOR_Lab2BGR) | |
| # Blend the transferred image with the original image using the alpha mask | |
| alpha_mask_3d = np.repeat(alpha_mask[:, :, np.newaxis], 3, axis=2) | |
| trans_comp = ( | |
| trans_img * alpha_mask_3d + comp_img * (1 - alpha_mask_3d) | |
| ).astype(np.uint8) | |
| return trans_comp | |
| def weighted_mean_std(self, img, weights): | |
| # Ensure weights have the same shape as img | |
| weights_3d = np.repeat(weights[:, :, np.newaxis], img.shape[2], axis=2) | |
| # Calculate weighted mean | |
| total_weights = np.sum(weights_3d, axis=(0, 1)) | |
| mean = np.sum(img * weights_3d, axis=(0, 1)) / total_weights | |
| # Calculate weighted standard deviation | |
| variance = np.sum(((img - mean) ** 2) * weights_3d, axis=(0, 1)) / total_weights | |
| std = np.sqrt(variance) | |
| return mean, std | |