| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| from typing import Literal |
| from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Resize |
|
|
| from .area_resize import AreaResize |
| from .bucket_resize import BucketResize |
|
|
def NaResize(
    resolution: int,
    mode: Literal["area", "square", "bucket"],
    downsample_only: bool,
    interpolation: InterpolationMode = InterpolationMode.BICUBIC,
    **kwargs,
):
    """Build the resize transform selected by ``mode``.

    Args:
        resolution: Target side length. The ``"area"`` and ``"bucket"``
            strategies square it and pass it as ``max_area``; the
            ``"square"`` strategy uses it directly as both the ``Resize``
            size and the ``CenterCrop`` size.
        mode: Which strategy to build:
            - ``"area"``: a bare ``AreaResize`` transform.
            - ``"square"``: ``Resize`` then ``CenterCrop``, wrapped in
              ``Compose`` (torchvision: an int ``size`` rescales the
              shorter edge, so the crop yields a square output).
            - ``"bucket"``: a ``BucketResize`` wrapped in ``Compose``.
        downsample_only: Forwarded to ``AreaResize`` only; the ``"square"``
            and ``"bucket"`` strategies do not use it.
        interpolation: Interpolation mode forwarded to every resize op.
        **kwargs: Read only by the ``"bucket"`` strategy:
            ``aspect_ratios`` (default
            ``["21:9", "16:9", "4:3", "1:1", "3:4", "9:16"]``) and
            ``stride`` (default ``16``). Any other key is ignored.

    Returns:
        The constructed transform: an ``AreaResize`` instance for
        ``"area"``, otherwise a torchvision ``Compose``.

    Raises:
        ValueError: If ``mode`` is not one of the supported strategies.
    """
    if mode == "square":
        # resolution is used as-is here (no squaring), matching Resize/CenterCrop.
        scale_then_crop = [
            Resize(size=resolution, interpolation=interpolation),
            CenterCrop(resolution),
        ]
        return Compose(scale_then_crop)

    if mode == "area":
        return AreaResize(
            max_area=resolution**2,
            downsample_only=downsample_only,
            interpolation=interpolation,
        )

    if mode == "bucket":
        # A new default list is built on every call, so callers never share it.
        bucket_op = BucketResize(
            max_area=resolution**2,
            interpolation=interpolation,
            aspect_ratios=kwargs.get(
                "aspect_ratios", ["21:9", "16:9", "4:3", "1:1", "3:4", "9:16"]
            ),
            stride=kwargs.get("stride", 16),
        )
        return Compose([bucket_op])

    raise ValueError(f"Unknown resize mode: {mode}")
|
|