| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| from typing import List, Tuple, Union |
| import numpy as np |
| import torch |
| from PIL import Image |
| from torchvision.transforms import RandomResizedCrop |
| from torchvision.transforms.functional import InterpolationMode, to_tensor |
|
|
|
|
class BucketResize:
    """Resize images to the pre-computed size bucket with the closest aspect ratio.

    A bucket is a ``(height, width)`` pair whose sides are multiples of
    ``stride`` and whose area approximates ``max_area`` (see
    :meth:`init_buckets`). One ``RandomResizedCrop`` is built per bucket with
    ``scale=(1, 1)`` and a single-valued ``ratio`` range, so the transform has
    no scale/ratio randomness left and always outputs the bucket size.

    Args:
        max_area: Target pixel area each bucket should approximate.
        interpolation: Interpolation mode handed to ``RandomResizedCrop``.
            NOTE(review): torchvision documents LANCZOS as a PIL-only mode;
            tensor inputs probably need BILINEAR/BICUBIC - confirm.
        aspect_ratios: Bucket aspect ratios as ``"W:H"`` strings, e.g. ``"16:9"``.
        stride: Required divisor of the bucket sides; either one int used for
            both sides or a ``(height_stride, width_stride)`` pair.

    Raises:
        ValueError: If ``aspect_ratios`` or ``stride`` is missing or empty.
    """

    def __init__(
        self,
        max_area: float,
        interpolation: InterpolationMode = InterpolationMode.LANCZOS,
        aspect_ratios: List[str] = None,
        stride: Union[int, Tuple[int, int]] = None,
    ):
        # Explicit raise instead of `assert`: asserts are removed under
        # `python -O`, which would let a bad configuration slip through.
        if not (aspect_ratios and stride):
            raise ValueError("`aspect_ratios` or `stride` not given!")

        self.max_area = max_area
        self.interpolation = interpolation

        self.buckets, self.bucket_ratios = self.init_buckets(aspect_ratios, max_area, stride)
        # One resizer per bucket, keyed by the (height, width) tuple.
        self.bucket_resize = {
            bucket: RandomResizedCrop(
                size=(bucket[0], bucket[1]),
                scale=(1, 1),
                ratio=(bucket_ratio, bucket_ratio),
                interpolation=self.interpolation,
            )
            for bucket, bucket_ratio in zip(self.buckets, self.bucket_ratios)
        }

    def __call__(self, image: Union[torch.Tensor, Image.Image, List[Image.Image]]):
        """Resize ``image`` to its nearest bucket.

        Args:
            image: A tensor (its last two dimensions are read as height and
                width), a PIL image, or a non-empty list of PIL images. For a
                list, the bucket is chosen from the first image's size and
                applied to every element.

        Returns:
            For a list input, the stacked ``to_tensor`` results. For a PIL
            image, its ``to_tensor`` conversion. For a tensor, whatever the
            resizer returns.

        Raises:
            NotImplementedError: For any other input, including an empty list.
        """
        # Evaluated once; the length check keeps `image[0]` from raising a bare
        # IndexError on an empty list.
        is_image_list = (
            isinstance(image, list) and len(image) > 0 and isinstance(image[0], Image.Image)
        )

        if isinstance(image, torch.Tensor):
            height, width = image.shape[-2:]
        elif isinstance(image, Image.Image):
            width, height = image.size  # PIL reports (width, height)
        elif is_image_list:
            width, height = image[0].size
        else:
            raise NotImplementedError(
                "Expected a torch.Tensor, a PIL image or a non-empty list of PIL images, "
                f"got {type(image).__name__}"
            )

        bucket = self.find_nearest_bucket(width, height)
        resizer = self.bucket_resize[bucket]

        if is_image_list:
            return torch.stack([to_tensor(resizer(_image)) for _image in image])

        image = resizer(image)
        if isinstance(image, Image.Image):
            image = to_tensor(image)
        return image

    def find_nearest_bucket(self, width, height):
        """Return the ``(height, width)`` bucket whose ratio is closest to ``width / height``."""
        image_ratio = width / height
        diff = np.abs(image_ratio - self.bucket_ratios)
        index = diff.argmin()
        return self.buckets[index]

    @staticmethod
    def init_buckets(aspect_ratio_names, max_area, stride):
        """Build one bucket per requested aspect ratio.

        For every ``"W:H"`` name, two candidate sizes are computed (width
        snapped to the stride first, or height snapped first). The candidate
        whose ratio is closer to the requested one wins; on an exact tie the
        one whose area is closer to ``max_area`` wins, and if that also ties
        the width-first candidate is kept.

        The stride is the value both sides must be divisible by (per the
        original note: the VAE down-sampling factor times the patch size).

        Args:
            aspect_ratio_names: Iterable of ``"W:H"`` strings.
            max_area: Target pixel area.
            stride: Int, or ``(height_stride, width_stride)`` tuple/list.

        Returns:
            ``(buckets, bucket_ratios)``: a list of ``(height, width)`` tuples
            and a NumPy array with the matching ``width / height`` ratios.
        """
        if not isinstance(stride, (tuple, list)):
            stride = (stride, stride)
        height_factor, width_factor = stride

        def snap(length, factor):
            # Round to the nearest multiple of `factor`, but never below one
            # multiple: a zero-length side would make the ratio undefined
            # (ZeroDivisionError) or yield an unusable empty bucket.
            return max(1, round(length / factor)) * factor

        buckets, bucket_ratios = [], []
        for name in aspect_ratio_names:
            w, h = (int(v) for v in name.split(":"))
            aspect_ratio = w / h

            # Candidate 1: snap the width first, then derive the height.
            width1 = snap(math.sqrt(max_area * aspect_ratio), width_factor)
            height1 = snap(width1 / aspect_ratio, height_factor)

            # Candidate 2: snap the height first, then derive the width.
            height2 = snap(math.sqrt(max_area / aspect_ratio), height_factor)
            width2 = snap(height2 * aspect_ratio, width_factor)

            # Order by (ratio error, area error); `min` keeps the first
            # candidate on a full tie, matching the width-first preference.
            bucket_width, bucket_height = min(
                ((width1, height1), (width2, height2)),
                key=lambda wh: (
                    abs(wh[0] / wh[1] - aspect_ratio),
                    abs(wh[0] * wh[1] - max_area),
                ),
            )

            buckets.append((bucket_height, bucket_width))
            bucket_ratios.append(bucket_width / bucket_height)

        return buckets, np.array(bucket_ratios)
|
|
|
|
| |
| |
| |
|
|
def check_buckets(max_area: int, aspect_ratios: List[str], stride: int):
    """Print the buckets that ``BucketResize.init_buckets`` produces for a configuration.

    Args:
        max_area (int): Target total pixel area.
        aspect_ratios (List[str]): Target aspect ratios, e.g. ``["1:1", "4:3"]``.
        stride (int): Stride that both height and width must be a multiple of.
    """
    # Echo the configuration being inspected.
    banner = "--- Checking Configuration ---"
    print(banner)
    print(f"Max Area: {max_area} | Aspect Ratios: {aspect_ratios} | Stride: {stride}")
    print("-" * 35)

    sizes, ratios = BucketResize.init_buckets(aspect_ratios, max_area, stride)

    # One line per bucket: size, achieved ratio and resulting area.
    print("Generated Buckets (Height, Width) and Ratios:")
    for size, ratio in zip(sizes, ratios):
        height, width = size
        area = height * width
        print(f" - Bucket: ({height:4d}, {width:4d}) | Ratio: {ratio:.4f} | Area: {area}")
    print("\n")
|
|
|
|
if __name__ == "__main__":
    # Demo run: print the buckets for a 224x224 pixel budget with stride 28.
    demo_config = {
        "max_area": 224 * 224,
        "aspect_ratios": ["21:9", "1:1", "4:3", "3:4", "9:16", "16:9"],
        "stride": 28,
    }
    check_buckets(**demo_config)
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|